diff --git a/ge/common/profiling/profiling_manager.cc b/ge/common/profiling/profiling_manager.cc index aad2bbe3..1fc4dba6 100644 --- a/ge/common/profiling/profiling_manager.cc +++ b/ge/common/profiling/profiling_manager.cc @@ -27,6 +27,8 @@ namespace { const char *const kTrainingTrace = "training_trace"; const char *const kFpPoint = "fp_point"; const char *const kBpPoint = "bp_point"; + +#ifdef DAVINCI_SUPPORT_PROFILING const size_t kReportMaxLen = 2048; const int32_t kMaxDeviceNum = 256; const std::string kConfigNumsdev = "devNums"; @@ -35,6 +37,7 @@ const std::string kProfStart = "prof_start"; const std::string kProfStop = "prof_stop"; const std::string kProfModelSubscribe = "prof_model_subscribe"; const std::string kProfModelUnsubscribe = "prof_model_cancel_subscribe"; +#endif } // namespace namespace ge { @@ -110,7 +113,7 @@ ge::Status ProfilingManager::InitFromOptions(const Options &options, MsprofGeOpt } // enable profiling by env is_execute_profiling_ = true; - GELOGI("The profiling in env is %s, %s", env_profiling_mode, prof_conf.options); + GELOGI("The profiling in env is %s, %s", env_profiling_mode, prof_conf.options); } if (!is_execute_profiling_) { @@ -186,7 +189,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::StopProf GELOGW("Call rtProfilerStop failed, ret:%d", rt_ret); } } - + // stop profiling if (prof_cb_.msprofCtrlCallback == nullptr) { GELOGE(ge::PARAM_INVALID, "MsprofCtrlCallback callback is nullptr."); @@ -801,7 +804,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::CallMs if (prof_cb_.msprofReporterCallback == nullptr) { GELOGE(ge::PARAM_INVALID, "MsprofReporterCallback callback is nullptr."); return ge::PARAM_INVALID; - } + } return prof_cb_.msprofReporterCallback( static_cast(MsprofReporterModuleId::MSPROF_MODULE_FRAMEWORK), static_cast(MsprofReporterCallbackType::MSPROF_REPORTER_REPORT), @@ -853,7 +856,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::GetFpBpP return; } } - + return; } diff --git a/ge/graph/passes/subgraph_const_migration_pass.cc b/ge/graph/passes/subgraph_const_migration_pass.cc index 579b2424..f131942c 100644 --- a/ge/graph/passes/subgraph_const_migration_pass.cc +++ b/ge/graph/passes/subgraph_const_migration_pass.cc @@ -20,11 +20,12 @@ #include "graph/passes/folding_pass.h" namespace ge { -constexpr uint32_t kDataOutIndex = 0; +constexpr uint32_t kZeroIndex = 0; constexpr uint32_t kCaseInputBase = 1; constexpr uint32_t kInvalidParent = 0x7fffffffU; +const string kMbatchNodeNameMark = "_ascend_mbatch_batch_"; -bool IsSameOpNode(const NodePtr &src_node, const NodePtr &dst_node) { +bool IsSameConstNode(const NodePtr &src_node, const NodePtr &dst_node) { if ((src_node == nullptr) && (dst_node == nullptr)) { return true; } @@ -37,35 +38,9 @@ bool IsSameOpNode(const NodePtr &src_node, const NodePtr &dst_node) { return false; } - if ((src_node->GetInControlNodes().size() != dst_node->GetInControlNodes().size()) || - (src_node->GetOutDataNodesSize() != dst_node->GetOutDataNodesSize())) { - return false; - } - - set related_parent; - const auto in_nodes = src_node->GetInControlNodes(); - for (uint32_t i = 0; i < in_nodes.size(); ++i) { - const auto owner_node = in_nodes.at(i); - uint32_t parent_index = 0; - if (!AttrUtils::GetInt(owner_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { - return false; - } - - related_parent.insert(parent_index); - } - - for (const auto &in_node : dst_node->GetInControlNodes()) { - uint32_t parent_index = 0; - if (!AttrUtils::GetInt(in_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { - return false; - } - - if (related_parent.count(parent_index) == 0) { - return false; - } - } - - return true; + const GeTensorDesc &src_desc = src_node->GetOpDesc()->GetOutputDesc(kZeroIndex); + const GeTensorDesc &dst_desc = dst_node->GetOpDesc()->GetOutputDesc(kZeroIndex); + return (src_desc == dst_desc); } /*********************************************************************************************************************** @@ -89,12 +64,12 @@ bool IsSameOpNode(const NodePtr &src_node, const NodePtr &dst_node) { +-----------+ +-----------+ +-----------+ +-----------+ +-----------+ +-----------+ +-----------+ | Data | | Data | | Data | | Data | | Data | | Data | | Conv2D | +-----------+ +-----------+ +-----------+ +-----------+ +-----------+ +-----------+ +-----------+ - \ \ | / / | | - \ \ | / / | | - \ \ | / / | | - \ \ | / / | | - \ +-----------+ / | +-----------+ - +---------------| Const |----------------+ | | Pooling | + \ \ | / / | | +-----------+ + \ \ | / / | | | Const | + \ \ | / / | | +-----------+ + \ \ | / / | | / + \ +-----------+ / | +-----------+ / + +---------------| Const |----------------+ | | Pooling |-----+ +-----------+ | +-----------+ \ | / \ | / @@ -126,28 +101,26 @@ Status SubgraphConstMigrationPass::Run(ComputeGraphPtr graph) { continue; } - do { - migration_append_ = false; - map> graph_datas; - if (ClassifyDataNodes(graph, func_desc, graph_datas) != SUCCESS) { - return FAILED; - } + map> all_const_nodes; + map> all_data_nodes; + if (ClassifyGraphNodes(graph, func_desc, all_const_nodes, all_data_nodes) != SUCCESS) { + return FAILED; + } - if (graph_datas.empty()) { - GELOGW("Graph: %s subgraph is empty", graph->GetName().c_str()); - break; - } + if (all_const_nodes.empty()) { + GELOGW("Graph: %s subgraph is empty", graph->GetName().c_str()); + break; + } - // {subgraph0, {{1, Data}, {2, Data}, {3, Data}, {4, Data}, ..., {n, Data}}} - // {subgraph1, {{1, Data}, {2, Data}, {3, Data}, {4, Data}, ..., {n, Data}}} - // {subgraph2, {{1, Data}, {2, Data}, {3, Data}, {4, Data}, ..., {n, Data}}} - const auto base_nodes = graph_datas.begin()->second; // Need copy. - for (const auto &node_item : base_nodes) { - if (GraphNodeMigration(graph, node, graph_datas, node_item.second, node_item.first) != SUCCESS) { - return FAILED; - } + // {subgraph0, {{key1, Const}, {key2, Const}, {key3, Const}, {key4, Const}, ..., {keyn, Const}}} + // {subgraph1, {{key1, Const}, {key2, Const}, {key3, Const}, {key4, Const}, ..., {keyn, Const}}} + // {subgraph2, {{key1, Const}, {key2, Const}, {key3, Const}, {key4, Const}, ..., {keyn, Const}}} + const auto &const_nodes = all_const_nodes.begin()->second; + for (const auto &item : const_nodes) { + if (GraphNodeMigration(graph, node, all_const_nodes, all_data_nodes, item.second, item.first) != SUCCESS) { + return FAILED; } - } while (migration_append_); + } } return SUCCESS; @@ -155,14 +128,16 @@ Status SubgraphConstMigrationPass::Run(ComputeGraphPtr graph) { /// /// @ingroup ge -/// @brief Get all Data nodes for all subgraph. +/// @brief Get all Const/Data nodes for all subgraph. /// @param [in] graph: Root compute graph. /// @param [in] func_desc: functional OpDesc of Case. -/// @param [out] graph_datas: Data groups of subgraph. +/// @param [out] all_const_nodes: Const groups of subgraph. +/// @param [out] all_data_nodes: Data groups of subgraph. /// @return 0: SUCCESS / others: FAILED /// -Status SubgraphConstMigrationPass::ClassifyDataNodes(const ComputeGraphPtr &graph, const OpDescPtr &func_desc, - map> &graph_datas) { +Status SubgraphConstMigrationPass::ClassifyGraphNodes(const ComputeGraphPtr &graph, const OpDescPtr &func_desc, + map> &all_const_nodes, + map> &all_data_nodes) { for (const auto &name : func_desc->GetSubgraphInstanceNames()) { const auto &subgraph = graph->GetSubgraph(name); if (subgraph == nullptr) { @@ -170,32 +145,47 @@ Status SubgraphConstMigrationPass::ClassifyDataNodes(const ComputeGraphPtr &grap return GE_GRAPH_EMPTY_SUBGRAPH; } - auto &data_nodes = graph_datas[subgraph]; - for (auto &data : subgraph->GetDirectNode()) { - if (data->GetType() != DATA) { - continue; - } + auto &data_nodes = all_data_nodes[subgraph]; + auto &const_nodes = all_const_nodes[subgraph]; + for (auto &node : subgraph->GetDirectNode()) { + if (node->GetType() == DATA) { + uint32_t parent_index = kInvalidParent; + if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { + return FAILED; + } - uint32_t parent_index = 0; - if (!AttrUtils::GetInt(data->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { - GELOGE(FAILED, "Parent index not found, name: %s", data->GetName().c_str()); - return FAILED; - } + data_nodes[parent_index] = node; + GELOGD("%s, index: %u, Data: %s", subgraph->GetName().c_str(), parent_index, node->GetName().c_str()); + } else if ((node->GetType() == CONSTANT) && (node->GetOutDataAnchor(kZeroIndex) != nullptr)) { + set peer_name_list; + const auto &out_anchor = node->GetOutDataAnchor(kZeroIndex); + for (const auto &in_anchor : out_anchor->GetPeerInDataAnchors()) { + const auto &peer_node = in_anchor->GetOwnerNode(); + // Trim subgraph node name prefix. + string node_full_name = peer_node->GetName(); + size_t pos = node_full_name.find(kMbatchNodeNameMark); + if (pos == string::npos) { + GELOGE(FAILED, "find: %s of multi-batch in node: %s", kMbatchNodeNameMark.c_str(), node_full_name.c_str()); + return FAILED; + } + + string fixed_name = node_full_name.substr(0, pos); + pos = node_full_name.find("_", pos + kMbatchNodeNameMark.length()); + if (pos != string::npos) { + fixed_name += node_full_name.substr(pos); + } + + peer_name_list.insert(fixed_name + ":" + std::to_string(in_anchor->GetIdx())); + } - data_nodes[parent_index] = data; - GELOGD("%s, Parent index: %u, Data: %s", subgraph->GetName().c_str(), parent_index, data->GetName().c_str()); - } - } + string key_of_const; + for (const string &name : peer_name_list) { + key_of_const += (key_of_const.empty() ? name : "_" + name); + } - auto iter = graph_datas.begin(); - if (iter == graph_datas.end()) { - return SUCCESS; - } - for (const auto &data_nodes : graph_datas) { - if (data_nodes.second.size() != iter->second.size()) { - GELOGE(FAILED, "Subgraph %s has invalid Data nodes[%zu != %zu]", - data_nodes.first->GetName().c_str(), data_nodes.second.size(), iter->second.size()); - return FAILED; + const_nodes[key_of_const] = node; + GELOGD("%s, Key: %s, Const: %s", subgraph->GetName().c_str(), key_of_const.c_str(), node->GetName().c_str()); + } } } @@ -204,36 +194,27 @@ Status SubgraphConstMigrationPass::ClassifyDataNodes(const ComputeGraphPtr &grap /// /// @ingroup ge -/// @brief Get all Data nodes for all subgraph. -/// @param [in] node: Const node of subgraph. -/// @param [out] inputs: parent index to Const. -/// @param [out] outputs: Data groups of subgraph. +/// @brief Get parent_index for Const node migration. +/// @param [in] all_data_nodes: Data groups of subgraph. +/// @param [in] const_node: Const node will process. +/// @param [out] parent_index: parent index for replace Data. /// @return true: SUCCESS / false: FAILED /// -bool SubgraphConstMigrationPass::GetAssociatedNodes(const NodePtr &node, map &inputs, - map &outputs) { - for (uint32_t i = 0; i < node->GetAllOutDataAnchorsSize(); ++i) { - outputs[i] = kInvalidParent; - } - - uint32_t out_index = 0; - const auto in_nodes = node->GetInAllNodes(); - for (size_t i = 0; i < in_nodes.size(); ++i) { - const auto owner_node = in_nodes.at(i); - if (owner_node->GetType() != DATA) { +bool SubgraphConstMigrationPass::GetAssociatedNodes(const map> &all_data_nodes, + const NodePtr &const_node, uint32_t &parent_index) { + for (const auto in_node : const_node->GetInAllNodes()) { + if (in_node->GetType() != DATA) { return false; } - uint32_t parent_index = 0; - if (!AttrUtils::GetInt(owner_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { + uint32_t node_index = 0; + if (!AttrUtils::GetInt(in_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, node_index)) { return false; } // Input Data feed other Node, need add new Data. - inputs[i] = parent_index; - if ((out_index == outputs.size()) && owner_node->GetOutDataNodes().empty()) { - outputs[out_index] = parent_index; - ++out_index; + if ((parent_index == kInvalidParent) && in_node->GetOutDataNodes().empty()) { + parent_index = node_index; } } @@ -242,43 +223,26 @@ bool SubgraphConstMigrationPass::GetAssociatedNodes(const NodePtr &node, map> &graph_datas, - const NodePtr &const_node, uint32_t parent_index, size_t index) { - auto it = graph_datas.begin(); - for (++it; it != graph_datas.end(); ++it) { - const auto &data_nodes = it->second; - auto data_it = data_nodes.find(parent_index); - if (data_it == data_nodes.end()) { - GELOGE(FAILED, "Data: %s not fount, index: %u", const_node->GetName().c_str(), parent_index); - return false; - } - - const auto &work_data = data_it->second; - const auto &out_anchor = work_data->GetOutControlAnchor(); - const auto &in_anchors = out_anchor->GetPeerInControlAnchors(); - if (in_anchors.size() <= index || in_anchors.at(index) == nullptr) { - GELOGW("Node anchors not same, Data: %s -> %s anchor size: %zu, index: %zu", - work_data->GetName().c_str(), const_node->GetName().c_str(), in_anchors.size(), index); - return false; - } - - const auto &in_anchor = in_anchors.at(index); - const auto &work_node = in_anchor->GetOwnerNode(); - if (work_node == nullptr) { - GELOGE(FAILED, "Data: %s not found, parent: %u, index: %zu", const_node->GetName().c_str(), parent_index, index); +bool SubgraphConstMigrationPass::IsParallelNodeSame(const map> &all_const_nodes, + const NodePtr &const_node, const string &node_key) { + auto it = all_const_nodes.begin(); + for (++it; it != all_const_nodes.end(); ++it) { + const auto &const_nodes = it->second; + auto node_it = const_nodes.find(node_key); + if (node_it == const_nodes.end()) { + GELOGW("Const node: %s not fount, key: %s", const_node->GetName().c_str(), node_key.c_str()); return false; } - if (!IsSameOpNode(const_node, work_node)) { - GELOGI("OpDesc not same: %s %s, parent: %u, index: %zu", - const_node->GetName().c_str(), work_node->GetName().c_str(), parent_index, index); + const auto &work_node = node_it->second; + if (!IsSameConstNode(const_node, work_node)) { + GELOGI("Not same: %s %s, key: %s", const_node->GetName().c_str(), work_node->GetName().c_str(), node_key.c_str()); return false; } } @@ -291,51 +255,34 @@ bool SubgraphConstMigrationPass::IsParallelNodeSame(const map> &graph_datas, - const NodePtr &data_node, uint32_t parent_index) { - bool can_extrapolation = false; - do { - can_extrapolation = false; - const auto &out_anchor = data_node->GetOutControlAnchor(); - const auto &in_anchors = out_anchor->GetPeerInControlAnchors(); - for (size_t i = in_anchors.size(); i > 0; --i) { - const auto &in_anchor = in_anchors.at(i - 1); - const auto &work_node = in_anchor->GetOwnerNode(); - GELOGD("Data: %s, node: %s, parent: %u, index: %zu", - data_node->GetName().c_str(), work_node->GetName().c_str(), parent_index, i); - if (work_node->GetType() != CONSTANT) { - continue; - } - - // Get associated Data, if Data feed other nodes, need append new Data. - map inputs; - map outputs; - if (!GetAssociatedNodes(work_node, inputs, outputs)) { - continue; - } + const map> &all_const_nodes, + map> &all_data_nodes, + const NodePtr &const_node, const string &node_key) { + if (!IsParallelNodeSame(all_const_nodes, const_node, node_key)) { + return SUCCESS; + } - if (!IsParallelNodeSame(graph_datas, work_node, parent_index, i - 1)) { - continue; - } + // Get associated Data, if Data feed other nodes, need append new Data. + uint32_t parent_index = kInvalidParent; + if (!GetAssociatedNodes(all_data_nodes, const_node, parent_index)) { + return SUCCESS; + } - GELOGI("Move node: %s, parent: %u, index: %zu", work_node->GetName().c_str(), parent_index, i); - if (AppendParallelNode(graph_datas, func_node, outputs) != SUCCESS) { - return FAILED; - } + GELOGI("Move node: %s, parent index: %u", const_node->GetName().c_str(), parent_index); + if (AppendParallelNode(func_node, parent_index, all_data_nodes) != SUCCESS) { + return FAILED; + } - if (MoveNodeToParent(graph, func_node, graph_datas, parent_index, i - 1, inputs, outputs) != SUCCESS) { - return FAILED; - } - can_extrapolation = true; - break; - } - } while (can_extrapolation); + if (MoveNodeToParent(graph, func_node, all_const_nodes, all_data_nodes, node_key, parent_index) != SUCCESS) { + return FAILED; + } return SUCCESS; } @@ -343,114 +290,100 @@ Status SubgraphConstMigrationPass::GraphNodeMigration(const ComputeGraphPtr &gra /// /// @ingroup ge /// @brief Append Input Tensor for functional node. -/// @param [in] graph_nodes: Data groups of subgraph. /// @param [in] func_node: functional Node of Case. -/// @param [in] outputs: Parent index of Node output. +/// @param [in/out] parent_index: Parent index for migration. +/// @param [in/out] all_data_nodes: Data groups of subgraph. /// @return 0: SUCCESS / others: FAILED /// -Status SubgraphConstMigrationPass::AppendParallelNode(map> &graph_datas, - const NodePtr &func_node, map &outputs) { +Status SubgraphConstMigrationPass::AppendParallelNode(const NodePtr &func_node, uint32_t &parent_index, + map> &all_data_nodes) { // If outputs index invalid, add Data and Input Tensor. - for (auto &item : outputs) { - if (item.second != kInvalidParent) { - continue; - } - - // Add Data to subgraph. - map append_num; - for (auto &groups : graph_datas) { - const auto &subgraph = groups.first; - auto &data_nodes = groups.second; - - item.second = func_node->GetAllInDataAnchorsSize() + append_num[subgraph]; // Update to valid parent index. - const auto data_name = subgraph->GetName() + "_data_" + std::to_string(item.second); - - OpDescBuilder op_builder(data_name, DATA); - const OpDescPtr op_desc = op_builder.AddInput("x").AddOutput("y").Build(); - if (op_desc == nullptr) { - GELOGE(OUT_OF_MEMORY, "Create multi-batch subgraph data desc failed"); - return OUT_OF_MEMORY; - } + if (parent_index != kInvalidParent) { + return SUCCESS; + } - uint32_t data_index = item.second - kCaseInputBase; - if (!AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, data_index)) { - GELOGE(FAILED, "Parent index not found, name: %s", op_desc->GetName().c_str()); - return FAILED; - } + // Add Data to subgraph. + parent_index = func_node->GetAllInDataAnchorsSize(); // Update to valid parent index. + for (auto &item : all_data_nodes) { + const auto &subgraph = item.first; + const auto data_name = subgraph->GetName() + "_data_" + std::to_string(parent_index); + OpDescBuilder op_builder(data_name, DATA); + const auto op_desc = op_builder.AddInput("x").AddOutput("y").Build(); + if (op_desc == nullptr) { + GELOGE(OUT_OF_MEMORY, "Create multi-batch subgraph data desc failed"); + return OUT_OF_MEMORY; + } - if (!AttrUtils::SetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, item.second)) { - GELOGE(FAILED, "Parent index not found, name: %s", op_desc->GetName().c_str()); - return FAILED; - } + uint32_t data_index = parent_index - kCaseInputBase; + if (!AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, data_index)) { + GELOGE(FAILED, "Parent index not found, name: %s", op_desc->GetName().c_str()); + return FAILED; + } - append_num[subgraph]++; - data_nodes[item.second] = subgraph->AddNode(op_desc); - GELOGI("Add Node: %s, parent index: %u", op_desc->GetName().c_str(), item.second); + if (!AttrUtils::SetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { + GELOGE(FAILED, "Parent index not found, name: %s", op_desc->GetName().c_str()); + return FAILED; } - // Add InputTensor to functional Node. - NodeUtils::AppendInputAnchor(func_node, item.second + 1); + item.second[parent_index] = subgraph->AddNode(op_desc); + GELOGI("Add Node: %s, parent index: %u", op_desc->GetName().c_str(), parent_index); } + // Add InputTensor to functional Node. + NodeUtils::AppendInputAnchor(func_node, parent_index + 1); return SUCCESS; } /// /// @ingroup ge -/// @brief Delete Node from all subgraph. -/// @param [in] graph_nodes: Data groups of subgraph. -/// @param [in] detach: Node will move to parent. -/// @param [in] outputs: Parent index of Node output. +/// @brief Delete Node from subgraph. +/// @param [in] graph: subgraph for process. +/// @param [in] const_node: Node will move to parent. +/// @param [in] data_node: Place holder for Const. /// @return 0: SUCCESS / others: FAILED /// -Status SubgraphConstMigrationPass::DetachParallelNode(const map &graph_datas, const NodePtr &detach, - const map &outputs) { +Status SubgraphConstMigrationPass::DetachParallelNode(const ComputeGraphPtr &graph, const NodePtr &const_node, + const NodePtr &data_node) { // Break Data and Move node. - const auto &in_anchor = detach->GetInControlAnchor(); - const auto &out_anchors = in_anchor->GetPeerOutControlAnchors(); - for (size_t i = out_anchors.size(); i > 0; --i) { - const auto &out_anchor = out_anchors.at(i - 1); + const auto &in_anchor = const_node->GetInControlAnchor(); + const auto out_anchors = in_anchor->GetPeerOutControlAnchors(); + for (const auto out_anchor : out_anchors) { GE_CHK_GRAPH_STATUS_RET(GraphUtils::RemoveEdge(out_anchor, in_anchor), "Remove edge failed"); - const auto &owner_node = out_anchor->GetOwnerNode(); - GELOGI("Remove Edge: %s %s", owner_node->GetName().c_str(), detach->GetName().c_str()); - } - - // Break Move and follow, Link Data and follow. - for (uint32_t i = 0; i < detach->GetAllOutDataAnchorsSize(); ++i) { - auto it_idx = outputs.find(i); - if (it_idx == outputs.end()) { - GELOGE(FAILED, "Node: %s parent index %u not found", detach->GetName().c_str(), i); - return FAILED; - } - - auto it_data = graph_datas.find(it_idx->second); - if (it_data == graph_datas.end()) { - GELOGE(FAILED, "Node: %s parent index %u not found", detach->GetName().c_str(), i); - return FAILED; + const auto owner_node = out_anchor->GetOwnerNode(); + GELOGI("Remove Edge: %s %s", owner_node->GetName().c_str(), const_node->GetName().c_str()); + if (owner_node->GetInAllNodes().empty() && owner_node->GetOutAllNodes().empty() && owner_node != data_node) { + graph->RemoveNode(owner_node); } + } - const auto &data_node = it_data->second; - const auto &out_anchor = detach->GetOutDataAnchor(i); + const auto &ctrl_anchor = const_node->GetOutControlAnchor(); + const auto ctrl_anchors = ctrl_anchor->GetPeerInControlAnchors(); + for (const auto in_anchor : ctrl_anchors) { + GE_CHK_GRAPH_STATUS_RET(GraphUtils::RemoveEdge(ctrl_anchor, in_anchor), "Remove edge failed"); + GELOGI("Remove Edge: %s %s", const_node->GetName().c_str(), in_anchor->GetOwnerNode()->GetName().c_str()); - const auto &out_desc = detach->GetOpDesc()->GetOutputDesc(i); - const auto &data_desc = data_node->GetOpDesc(); - (void)data_desc->UpdateInputDesc(kDataOutIndex, out_desc); // Set Data Input to new connect Node. - (void)data_desc->UpdateOutputDesc(kDataOutIndex, out_desc); // Set Data Output to new connect Node. + GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(data_node->GetOutControlAnchor(), in_anchor), "Add edge failed"); + GELOGI("Add Edge: %s %s", data_node->GetName().c_str(), in_anchor->GetOwnerNode()->GetName().c_str()); + } - for (const auto &in_anchor : out_anchor->GetPeerInDataAnchors()) { - if (in_anchor == nullptr) { - continue; - } - GE_CHK_GRAPH_STATUS_RET(GraphUtils::RemoveEdge(out_anchor, in_anchor), "Remove edge failed"); - const auto &owner_node = in_anchor->GetOwnerNode(); - GELOGI("Remove Edge: %s %s", detach->GetName().c_str(), owner_node->GetName().c_str()); + // Break Move and follow, Link Data and follow. + const auto &out_anchor = const_node->GetOutDataAnchor(kZeroIndex); + const auto in_anchors =out_anchor->GetPeerInDataAnchors(); + for (const auto in_anchor : in_anchors) { + GE_CHK_GRAPH_STATUS_RET(GraphUtils::RemoveEdge(out_anchor, in_anchor), "Remove edge failed"); + GELOGI("Remove Edge: %s %s", const_node->GetName().c_str(), in_anchor->GetOwnerNode()->GetName().c_str()); - const auto &data_out_anchor = data_node->GetOutDataAnchor(kDataOutIndex); - GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(data_out_anchor, in_anchor), "Add edge failed"); - GELOGI("Add Edge: %s %s", data_node->GetName().c_str(), owner_node->GetName().c_str()); - } + GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(data_node->GetOutDataAnchor(kZeroIndex), in_anchor), "Add edge failed"); + GELOGI("Add Edge: %s %s", data_node->GetName().c_str(), in_anchor->GetOwnerNode()->GetName().c_str()); } + // Update Data op DataType. + const auto &const_desc = const_node->GetOpDesc(); + const auto &tensor_desc = const_desc->GetOutputDesc(kZeroIndex); + const auto &data_desc = data_node->GetOpDesc(); + (void)data_desc->UpdateInputDesc(kZeroIndex, tensor_desc); // Set Data Input to new connect Node. + (void)data_desc->UpdateOutputDesc(kZeroIndex, tensor_desc); // Set Data Output to new connect Node. + return SUCCESS; } @@ -459,47 +392,37 @@ Status SubgraphConstMigrationPass::DetachParallelNode(const map &inputs, - const map &outputs) { - GE_CHECK_NOTNULL(attach); - for (const auto item : inputs) { - if (item.second == kInvalidParent) { // Not connect, Skip. - continue; - } - - const auto &in_anchor = func_node->GetInDataAnchor(item.second); - const auto &out_anchor = in_anchor->GetPeerOutAnchor(); - const auto &owner_node = out_anchor->GetOwnerNode(); - const auto &in_control = attach->GetInControlAnchor(); - GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(owner_node->GetOutControlAnchor(), in_control), "Add edge failed"); - GELOGI("Add Edge: %s %s", owner_node->GetName().c_str(), attach->GetName().c_str()); + const NodePtr &const_node, uint32_t parent_index) { + GE_CHECK_NOTNULL(const_node); + if (parent_index == kInvalidParent) { + return INTERNAL_ERROR; } - for (const auto &item : outputs) { - const auto &func_desc = func_node->GetOpDesc(); - const auto &out_desc = attach->GetOpDesc()->GetOutputDesc(item.second); - (void)func_desc->UpdateInputDesc(item.second, out_desc); // Set Data Input to new connect Node. - - const auto &in_anchor = func_node->GetInDataAnchor(item.second); - const auto &out_anchor = in_anchor->GetPeerOutAnchor(); - if (out_anchor != nullptr) { - GE_CHK_GRAPH_STATUS_RET(GraphUtils::RemoveEdge(out_anchor, in_anchor), "Remove edge failed"); - const auto &owner_node = out_anchor->GetOwnerNode(); - GELOGI("Remove Edge: %s %s", owner_node->GetName().c_str(), func_node->GetName().c_str()); + const auto &func_desc = func_node->GetOpDesc(); + const auto &tensor_desc = const_node->GetOpDesc()->GetOutputDesc(kZeroIndex); + (void)func_desc->UpdateInputDesc(parent_index, tensor_desc); // Set Data Input to new connect Node. + + const auto &in_anchor = func_node->GetInDataAnchor(parent_index); + const auto &out_anchor = in_anchor->GetPeerOutAnchor(); + if (out_anchor != nullptr) { // Break useless old link. + GE_CHK_GRAPH_STATUS_RET(GraphUtils::RemoveEdge(out_anchor, in_anchor), "Remove edge failed"); + const auto owner_node = out_anchor->GetOwnerNode(); + GELOGI("Remove Edge: %s %s", owner_node->GetName().c_str(), func_node->GetName().c_str()); + if (owner_node->GetInAllNodes().empty() && owner_node->GetOutAllNodes().empty()) { + graph->RemoveNode(owner_node); } - GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(attach->GetOutDataAnchor(item.first), in_anchor), "Add edge failed"); - GELOGI("Add Edge: %s %s", attach->GetName().c_str(), func_node->GetName().c_str()); } + GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(const_node->GetOutDataAnchor(kZeroIndex), in_anchor), "Add edge failed"); + GELOGI("Add Edge: %s %s, index: %u", const_node->GetName().c_str(), func_node->GetName().c_str(), parent_index); - (void)graph->AddNode(attach); - (void)attach->SetOwnerComputeGraph(graph); - GELOGI("Add Node: %s %s", graph->GetName().c_str(), attach->GetName().c_str()); + (void)graph->AddNode(const_node); + (void)const_node->SetOwnerComputeGraph(graph); + GELOGI("Add Node: %s %s", graph->GetName().c_str(), const_node->GetName().c_str()); return SUCCESS; } @@ -515,43 +438,37 @@ Status SubgraphConstMigrationPass::AttachParallelNode(const ComputeGraphPtr &gra /// @return 0: SUCCESS / others: FAILED /// Status SubgraphConstMigrationPass::MoveNodeToParent(const ComputeGraphPtr &graph, const NodePtr &func_node, - const map> &graph_datas, - uint32_t parent_index, uint32_t index, - const map &inputs, - const map &outputs) { - if (inputs.empty()) { + const map> &all_const_nodes, + const map> &all_data_nodes, + const string &node_key, uint32_t parent_index) { + if (node_key.empty() || parent_index == kInvalidParent) { GELOGE(FAILED, "Graph: %s, inputs is empty", graph->GetName().c_str()); return FAILED; } NodePtr move_node; - for (auto &groups : graph_datas) { - const auto &subgraph = groups.first; - const auto &data_nodes = groups.second; - auto it = data_nodes.find(parent_index); - if (it == data_nodes.end()) { - GELOGE(FAILED, "Graph: %s, Data: %u node not found", subgraph->GetName().c_str(), parent_index); + for (auto &item : all_const_nodes) { + const auto &subgraph = item.first; + const auto it_const = item.second.find(node_key); + if (it_const == item.second.end()) { + GELOGE(FAILED, "Graph: %s, Const: %s node not found", subgraph->GetName().c_str(), node_key.c_str()); return FAILED; } + move_node = it_const->second; - const auto &base_data = it->second; - const auto &out_anchor = base_data->GetOutControlAnchor(); - const auto &in_anchors = out_anchor->GetPeerInControlAnchors(); - if (in_anchors.size() <= index || in_anchors.at(index) == nullptr) { - GELOGE(FAILED, "Data: %s, anchor size: %zu, index: %u not found", - base_data->GetName().c_str(), in_anchors.size(), index); + const auto it_nodes = all_data_nodes.find(subgraph); + if (it_nodes == all_data_nodes.end()) { + GELOGE(FAILED, "Graph: %s, Const: %s node not found", subgraph->GetName().c_str(), node_key.c_str()); return FAILED; } - - const auto &in_anchor = in_anchors.at(index); - move_node = in_anchor->GetOwnerNode(); - if (move_node == nullptr) { - GELOGE(FAILED, "Data: %s not found, index: %u", base_data->GetName().c_str(), parent_index); + const auto it_data = it_nodes->second.find(parent_index); + if (it_data == it_nodes->second.end()) { + GELOGE(FAILED, "Graph: %s, Const: %s node not found", subgraph->GetName().c_str(), node_key.c_str()); return FAILED; } - if (DetachParallelNode(data_nodes, move_node, outputs) != SUCCESS) { - GELOGE(FAILED, "Data: %s not found, index: %u", base_data->GetName().c_str(), parent_index); + if (DetachParallelNode(subgraph, move_node, it_data->second) != SUCCESS) { + GELOGE(FAILED, "Data: %s not found, index: %u", move_node->GetName().c_str(), parent_index); return FAILED; } @@ -559,11 +476,10 @@ Status SubgraphConstMigrationPass::MoveNodeToParent(const ComputeGraphPtr &graph GELOGI("Remove Node: %s %s", subgraph->GetName().c_str(), move_node->GetName().c_str()); } - if (AttachParallelNode(graph, func_node, move_node, inputs, outputs) != SUCCESS) { + if (AttachParallelNode(graph, func_node, move_node, parent_index) != SUCCESS) { return FAILED; } - migration_append_ = true; return SUCCESS; } } // namespace ge diff --git a/ge/graph/passes/subgraph_const_migration_pass.h b/ge/graph/passes/subgraph_const_migration_pass.h index 3c087852..d93da839 100755 --- a/ge/graph/passes/subgraph_const_migration_pass.h +++ b/ge/graph/passes/subgraph_const_migration_pass.h @@ -36,50 +36,54 @@ class SubgraphConstMigrationPass : public GraphPass { private: /// /// @ingroup ge - /// @brief Get all Data nodes for all subgraph. + /// @brief Get all Const/Data nodes for all subgraph. /// @param [in] graph: Root compute graph. /// @param [in] func_desc: functional OpDesc of Case. - /// @param [out] graph_datas: Data groups of subgraph. + /// @param [out] all_const_nodes: Const groups of subgraph. + /// @param [out] all_data_nodes: Data groups of subgraph. /// @return 0: SUCCESS / others: FAILED /// - Status ClassifyDataNodes(const ComputeGraphPtr &graph, const OpDescPtr &func_desc, - map> &graph_datas); + Status ClassifyGraphNodes(const ComputeGraphPtr &graph, const OpDescPtr &func_desc, + map> &all_const_nodes, + map> &all_data_nodes); /// /// @ingroup ge - /// @brief Get all Data nodes for all subgraph. - /// @param [in] node: Const node of subgraph. - /// @param [in] func_desc: functional OpDesc of Case. - /// @param [out] graph_nodes: Data groups of subgraph. + /// @brief Get parent_index for Const node migration. + /// @param [in] all_data_nodes: Data groups of subgraph. + /// @param [in] const_node: Const node will process. + /// @param [out] parent_index: parent index for replace Data. /// @return true: SUCCESS / false: FAILED /// - bool GetAssociatedNodes(const NodePtr &node, map &inputs, map &outputs); + bool GetAssociatedNodes(const map> &all_data_nodes, + const NodePtr &const_node, uint32_t &parent_index); /// /// @ingroup ge - /// @brief Get all Data nodes for all subgraph. - /// @param [in] graph_nodes: Data groups of subgraph. - /// @param [in] data_base: Data Node for migration. - /// @param [in] data_idx: Data groups of subgraph. - /// @param [in] data_idx: Data groups of subgraph. + /// @brief Check parallel node is same for all subgraph. + /// @param [in] all_const_nodes: Const groups of subgraph. + /// @param [in] const_node: Const Node for migration. + /// @param [in] node_key: Key of Const node. /// @return true: Same / false: not same /// - bool IsParallelNodeSame(const map> &graph_nodes, - const NodePtr &const_node, uint32_t parent_index, size_t index); + bool IsParallelNodeSame(const map> &all_const_nodes, + const NodePtr &const_node, const string &node_key); /// /// @ingroup ge /// @brief Migration subgraph Node to Root /// @param [in] graph: Root compute graph. /// @param [in] func_node: functional Node of Case. - /// @param [in] graph_nodes: Data groups of subgraph. - /// @param [in] data_base: Data Node for migration. - /// @param [in] data_idx: Data groups of subgraph. + /// @param [in] all_const_nodes: Const groups of subgraph. + /// @param [in] all_data_nodes: Data groups of subgraph. + /// @param [in] const_node: Const Node for migration. + /// @param [in] node_key: Key of Const node for migration. /// @return 0: SUCCESS / others: FAILED /// Status GraphNodeMigration(const ComputeGraphPtr &graph, const NodePtr &func_node, - map> &graph_nodes, - const NodePtr &data_base, uint32_t data_idx); + const map> &all_const_nodes, + map> &all_data_nodes, + const NodePtr &const_node, const string &node_key); /// /// @ingroup ge @@ -93,46 +97,42 @@ class SubgraphConstMigrationPass : public GraphPass { /// @return 0: SUCCESS / others: FAILED /// Status MoveNodeToParent(const ComputeGraphPtr &graph, const NodePtr &func_node, - const map> &graph_nodes, - uint32_t parent_index, uint32_t anchor_idx, - const map &inputs, const map &outputs); + const map> &all_const_nodes, + const map> &all_data_nodes, + const string &node_key, uint32_t parent_index); /// /// @ingroup ge /// @brief Append Input Tensor for functional node. - /// @param [in] graph_nodes: Data groups of subgraph. - /// @param [in] func_node: functional Node of Case. - /// @param [in] outputs: Parent index of Node output. + /// @param [in] graph_nodes: Const groups of subgraph. + /// @param [in/out] parent_index: Parent index for migration. + /// @param [in/out] all_data_nodes: Data groups of subgraph. /// @return 0: SUCCESS / others: FAILED /// - Status AppendParallelNode(map> &graph_nodes, - const NodePtr &func_node, map &outputs); + Status AppendParallelNode(const NodePtr &func_node, uint32_t &parent_index, + map> &all_data_nodes); /// /// @ingroup ge - /// @brief Delete Node from all subgraph. - /// @param [in] graph_nodes: Data groups of subgraph. - /// @param [in] detach: Node will move to parent. - /// @param [in] outputs: Parent index of Node output. + /// @brief Delete Node from subgraph. + /// @param [in] graph: subgraph for process. + /// @param [in] const_node: Node will move to parent. + /// @param [in] data_node: Place holder for Const. /// @return 0: SUCCESS / others: FAILED /// - Status DetachParallelNode(const map &graph_datas, const NodePtr &detach, - const map &outputs); + Status DetachParallelNode(const ComputeGraphPtr &graph, const NodePtr &const_node, const NodePtr &data_node); /// /// @ingroup ge /// @brief Move Node to Parent Graph. /// @param [in] graph: Parent compute graph. /// @param [in] func_node: functional Node of Case. - /// @param [in] attach: Node will move to parent. - /// @param [in] inputs: Parent index of Node input. - /// @param [in] outputs: Parent index of Node output. + /// @param [in] const_node: Node will move to parent. + /// @param [in] parent_index: Parent index of Node input. /// @return 0: SUCCESS / others: FAILED /// - Status AttachParallelNode(const ComputeGraphPtr &graph, const NodePtr &func_node, const NodePtr &attach, - const map &inputs, const map &outputs); - - bool migration_append_{false}; + Status AttachParallelNode(const ComputeGraphPtr &graph, const NodePtr &func_node, + const NodePtr &const_node, uint32_t parent_index); }; } // namespace ge #endif // GE_COMMON_SUBGRAPH_CONST_MIGRATION_H_ \ No newline at end of file