From 48885521f9390edd603bf7c096a8cc26295cb120 Mon Sep 17 00:00:00 2001 From: zhangxiaokun Date: Fri, 7 May 2021 22:24:50 +0800 Subject: [PATCH] Fix DSP and MemcpyAsyncNodeTask --- ge/graph/partition/dynamic_shape_partition.cc | 2 +- .../mark_force_unknown_for_cond_pass.cc | 40 +++++++++++++------ .../passes/mark_force_unknown_for_cond_pass.h | 10 ++++- ge/hybrid/node_executor/rts/rts_node_task.cc | 23 ++++++----- 4 files changed, 52 insertions(+), 23 deletions(-) diff --git a/ge/graph/partition/dynamic_shape_partition.cc b/ge/graph/partition/dynamic_shape_partition.cc index 0f2a34f4..516d06d1 100755 --- a/ge/graph/partition/dynamic_shape_partition.cc +++ b/ge/graph/partition/dynamic_shape_partition.cc @@ -357,7 +357,7 @@ void DynamicShapePartitioner::MergeClustersUnknownShape() { continue; } for (const auto &in_cluster : cluster->Inputs()) { - if (!in_cluster->IsUnknownShape()) { + if (!in_cluster->IsUnknownShape() || in_cluster->IsControlFlow()) { continue; } auto merged_clusters = cluster->MergeAllPathFrom(in_cluster); diff --git a/ge/graph/passes/mark_force_unknown_for_cond_pass.cc b/ge/graph/passes/mark_force_unknown_for_cond_pass.cc index d0b9af7e..6729a647 100644 --- a/ge/graph/passes/mark_force_unknown_for_cond_pass.cc +++ b/ge/graph/passes/mark_force_unknown_for_cond_pass.cc @@ -37,6 +37,7 @@ inline bool IsMergeInLoop(const NodePtr &node) { Status MarkForceUnknownForCondPass::Run(ComputeGraphPtr graph) { GELOGD("MarkForceUnknownForCondPass Enter"); + std::map> switch_groups; for (const auto &node : graph->GetDirectNode()) { std::string node_type; GE_CHK_STATUS_RET(GetOriginalType(node, node_type), "Get original type failed."); @@ -44,20 +45,15 @@ Status MarkForceUnknownForCondPass::Run(ComputeGraphPtr graph) { continue; } - const auto op_desc = node->GetOpDesc(); - if (!op_desc->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE) && !IsUnknownShapeTensor(op_desc->GetOutputDesc(0))) { - GELOGI("Merge[%s] has known shape, no need check switch", node->GetName().c_str()); - continue; - } - const auto &all_in_nodes = node->GetInDataNodes(); if (std::any_of(all_in_nodes.begin(), all_in_nodes.end(), IsMergeInLoop)) { continue; // LoopCond marked in NextIterationPass. } - MarkUnknownForSwitch(node); + MarkUnknownForSwitch(node, switch_groups[node]); } + MarkUnknownForSwitch(switch_groups); GELOGD("MarkForceUnknownForCondPass Leave"); return SUCCESS; } @@ -65,13 +61,12 @@ Status MarkForceUnknownForCondPass::Run(ComputeGraphPtr graph) { /// /// @brief Mark force unknown shape for Switch node /// @param [in] merge node +/// @param [out] switch group /// @return /// -void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node) { +void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std::vector &switch_group) { // Switch --> {Switch --> Merge} --> Merge - std::vector switch_group; std::unordered_set nodes_seen; - std::queue> search_queue({{node, 0}}); while (!search_queue.empty()) { const auto dst_node = search_queue.front().first; @@ -117,9 +112,30 @@ void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node) { } } } +} - for (const auto &n : switch_group) { - MarkForceUnknownShape(n, true); +/// +/// @brief Mark force unknown shape for Switch node +/// @param [in] switch groups +/// @return +/// +void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const std::map> &switch_groups) { + std::function callback = [](const NodePtr &n) { + return n->GetOpDesc()->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE); + }; + + for (const auto &group : switch_groups) { + const auto &node = group.first; + const auto &switch_group = group.second; + const auto &op_desc = node->GetOpDesc(); + if (IsUnknownShapeTensor(op_desc->GetOutputDesc(0)) || op_desc->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE) || + std::any_of(switch_group.begin(), switch_group.end(), callback)) { + GELOGI("Mark [%s] as force unknown shape", node->GetName().c_str()); + MarkForceUnknownShape(node, true); + for (const auto &n : switch_group) { + MarkForceUnknownShape(n, true); + } + } } } } // namespace ge diff --git a/ge/graph/passes/mark_force_unknown_for_cond_pass.h b/ge/graph/passes/mark_force_unknown_for_cond_pass.h index 65e09394..528a8fdc 100644 --- a/ge/graph/passes/mark_force_unknown_for_cond_pass.h +++ b/ge/graph/passes/mark_force_unknown_for_cond_pass.h @@ -28,9 +28,17 @@ class MarkForceUnknownForCondPass : public GraphPass { /// /// @brief Mark force unknown shape for Switch node /// @param [in] merge node + /// @param [out] switch group /// @return /// - void MarkUnknownForSwitch(const NodePtr &node); + void MarkUnknownForSwitch(const NodePtr &node, std::vector &switch_group); + + /// + /// @brief Mark force unknown shape for Switch node + /// @param [in] switch groups + /// @return + /// + void MarkUnknownForSwitch(const std::map> &switch_groups); }; } // namespace ge #endif // GE_GRAPH_PASSES_MARK_FORCE_UNKNOWN_FOR_COND_PASS_H_ diff --git a/ge/hybrid/node_executor/rts/rts_node_task.cc b/ge/hybrid/node_executor/rts/rts_node_task.cc index 94566fc6..f6d6ddb6 100644 --- a/ge/hybrid/node_executor/rts/rts_node_task.cc +++ b/ge/hybrid/node_executor/rts/rts_node_task.cc @@ -169,17 +169,22 @@ Status StreamMergeNodeTask::ExecuteAsync(TaskContext &task_context, std::functio Status MemcpyAsyncNodeTask::ExecuteAsync(TaskContext &task_context, std::function done_callback) { GELOGD("[%s] Start to execute.", task_context.GetNodeName()); - const auto in_x = task_context.GetInput(0); // x - GE_CHECK_NOTNULL(in_x); - const auto out_y = task_context.MutableOutput(0); // value_index - GE_CHECK_NOTNULL(out_y); - - GELOGD("[%s] input size: %zu, output size: %zu", task_context.GetNodeName(), in_x->GetSize(), out_y->GetSize()); - if (in_x->GetSize() > 0 && out_y->GetSize() > 0) { - GE_CHK_RT_RET(rtMemcpyAsync(out_y->MutableData(), out_y->GetSize(), in_x->GetData(), in_x->GetSize(), + auto input_desc = task_context.MutableInputDesc(0); + GE_CHECK_NOTNULL(input_desc); + int64_t copy_size = 0; + GE_CHK_GRAPH_STATUS_RET(TensorUtils::GetTensorSizeInBytes(*input_desc, copy_size)); + // copy_size would not be negative since GetTensorSizeInBytes returned successfully. + if (copy_size > 0) { + const auto in_v = task_context.MutableInput(0); + const auto out_v = task_context.MutableOutput(0); + GE_CHECK_NOTNULL(in_v); + GE_CHECK_NOTNULL(out_v); + GELOGD("[%s] input size: %zu, output size: %zu, copy size: %ld", task_context.GetNodeName(), + in_v->GetSize(), out_v->GetSize(), copy_size); + GE_CHK_RT_RET(rtMemcpyAsync(out_v->MutableData(), out_v->GetSize(), in_v->GetData(), copy_size, RT_MEMCPY_DEVICE_TO_DEVICE, task_context.GetStream())); } else { - GELOGW("[%s] invalid copy size, src: %zu, dst: %zu", task_context.GetNodeName(), in_x->GetSize(), out_y->GetSize()); + GELOGW("[%s] invalid copy size: %ld", task_context.GetNodeName(), copy_size); } if (done_callback) {