diff --git a/ge/ge_local_engine/engine/host_cpu_engine.cc b/ge/ge_local_engine/engine/host_cpu_engine.cc index 488a5ee8..d9b67736 100755 --- a/ge/ge_local_engine/engine/host_cpu_engine.cc +++ b/ge/ge_local_engine/engine/host_cpu_engine.cc @@ -14,14 +14,14 @@ * limitations under the License. */ #include "ge_local_engine/engine/host_cpu_engine.h" -#include "graph/common/omg_util.h" #include "graph/utils/op_desc_utils.h" #include "graph/utils/tensor_adapter.h" +#include "graph/utils/node_utils.h" +#include "graph/utils/type_utils.h" #include "register/op_kernel_registry.h" #include "register/host_cpu_context.h" #include "common/ge/ge_util.h" #include "common/ge/plugin_manager.h" -#include "graph/utils/type_utils.h" #include "common/fp16_t.h" #include "common/math/math_util.h" @@ -123,10 +123,7 @@ bool HostCpuEngine::CheckSupported(const string &op_type) { } Status HostCpuEngine::FindOpKernel(const ge::NodePtr &node, std::unique_ptr &op_kernel) { - std::string op_type; - auto status = GetOriginalType(node, op_type); - GE_CHK_BOOL_EXEC_NOLOG(status == SUCCESS, return status); - + const std::string op_type = NodeUtils::GetNodeType(node); auto kernel = OpKernelRegistry::GetInstance().CreateHostCpuOp(op_type); if (kernel == nullptr) { GELOGD("Op of type %s is not supported by host cpu engine", op_type.c_str()); diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc index c050875e..f6de6ef0 100755 --- a/ge/hybrid/model/hybrid_model_builder.cc +++ b/ge/hybrid/model/hybrid_model_builder.cc @@ -289,10 +289,6 @@ Status HybridModelBuilder::GetOrCreateNodeItem(const NodePtr &node, NodeItem **n return SUCCESS; } - if (node->GetType() == MEMCPYASYNC) { // Convert MemcpyAsync to Identity. - node->GetOpDesc()->SetType(IDENTITY); - } - std::unique_ptr new_node; GE_CHK_STATUS_RET(NodeItem::Create(node, new_node), "[Invoke][Create] failed, model_name_:[%s]", GetGraphName()); GE_CHK_STATUS_RET_NOLOG(NodeExecutorManager::GetInstance().GetExecutor(*node, &new_node->node_executor)); diff --git a/ge/hybrid/model/node_item.cc b/ge/hybrid/model/node_item.cc index 5c3d7db3..250562ce 100644 --- a/ge/hybrid/model/node_item.cc +++ b/ge/hybrid/model/node_item.cc @@ -15,9 +15,7 @@ */ #include "hybrid/model/node_item.h" -#include -#include "framework/common/debug/log.h" -#include "graph/common/omg_util.h" + #include "graph/compute_graph.h" #include "graph/debug/ge_attr_define.h" #include "hybrid/executor/worker/shape_inference_engine.h" @@ -98,8 +96,7 @@ Status ParseFusedSubgraph(NodeItem &node_item) { GE_CHECK_NOTNULL(node); auto op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); - std::string node_type; - GE_CHK_STATUS_RET(GetOriginalType(node, node_type)); + const std::string node_type = NodeUtils::GetNodeType(node); if (node_type == DATA) { GE_CHK_GRAPH_STATUS_RET(ParseInputMapping(*node, *op_desc, *fused_subgraph)); } else if (node_type == kNodeTypeRetVal) { @@ -409,8 +406,8 @@ void NodeItem::SetDataSend(NodeItem *node_item, int anchor_index) { void NodeItem::SetCtrlSend(NodeItem *node_item, uint32_t switch_index) { if (switch_index < switch_groups_.size()) { - std::vector &switch_group = switch_groups_[switch_index]; - switch_group.emplace_back(node_item); + auto &switch_group = switch_groups_[switch_index]; + switch_group.emplace(node_item); } else { ctrl_send_.insert(node_item); } @@ -433,8 +430,8 @@ void NodeItem::SetMergeCtrl(NodeItem *node_item, uint32_t merge_index) { } // this is StreamMerge node, node_item is StreamActive node. - std::vector &switch_group = switch_groups_[merge_index]; - switch_group.emplace_back(node_item); + auto &switch_group = switch_groups_[merge_index]; + switch_group.emplace(node_item); node_item->ctrl_send_.emplace(this); GELOGI("Node[%s] will control node[%s]", node_item->NodeName().c_str(), NodeName().c_str()); diff --git a/ge/hybrid/model/node_item.h b/ge/hybrid/model/node_item.h index ec66f094..12775b00 100644 --- a/ge/hybrid/model/node_item.h +++ b/ge/hybrid/model/node_item.h @@ -155,7 +155,7 @@ struct NodeItem { std::map data_recv_; // Recv data notify from std::set ctrl_send_; // Send ctrl notify to std::set ctrl_recv_; // Recv ctrl notify from - std::vector> switch_groups_; // Send ctrl notify to + std::vector> switch_groups_; // Send ctrl notify to std::shared_ptr kernel_task; std::unique_ptr fused_subgraph; diff --git a/ge/hybrid/node_executor/hccl/hccl_node_executor.cc b/ge/hybrid/node_executor/hccl/hccl_node_executor.cc index b8819a42..3f887819 100644 --- a/ge/hybrid/node_executor/hccl/hccl_node_executor.cc +++ b/ge/hybrid/node_executor/hccl/hccl_node_executor.cc @@ -342,6 +342,7 @@ Status RdmaNodeTask::ExecuteAsync(TaskContext &context, std::function do GE_CHK_RT_RET(rtEventDestroy(evt)); } GELOGI("rdma callback success."); + return SUCCESS; }; HcclResult hccl_ret = HcomExecEnqueueRemoteAccess(context.GetNodeItem().NodeType(), addr_infos, callback); diff --git a/ge/hybrid/node_executor/rts/rts_node_executor.cc b/ge/hybrid/node_executor/rts/rts_node_executor.cc index 5cd971df..d52f56b9 100644 --- a/ge/hybrid/node_executor/rts/rts_node_executor.cc +++ b/ge/hybrid/node_executor/rts/rts_node_executor.cc @@ -17,13 +17,9 @@ #include "hybrid/node_executor/rts/rts_node_executor.h" #include "hybrid/node_executor/rts/rts_task_factory.h" -#include "framework/common/debug/log.h" #include "common/ge/ge_util.h" -#include "framework/common/types.h" -#include "graph/common/omg_util.h" #include "graph/utils/tensor_utils.h" #include "hybrid/model/hybrid_model.h" -#include "runtime/rt.h" namespace ge { namespace hybrid { @@ -133,8 +129,7 @@ Status ProfilingTraceNodeTask::ExecuteAsync(TaskContext &context, std::function< Status RtsNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr &task) const { GE_CHECK_NOTNULL(node); GELOGD("[%s] Load for local task.", node->GetName().c_str()); - std::string node_type; - GE_CHK_STATUS_RET(GetOriginalType(node, node_type), "Get original type failed."); + const std::string node_type = NodeUtils::GetNodeType(node); RtsNodeTaskPtr rts_task = RtsTaskFactory::GetInstance().Create(node_type); if (rts_task == nullptr) { GELOGE(UNSUPPORTED, "[%s] Unsupported RTS op type: %s", node->GetName().c_str(), node_type.c_str());