|
- /**
- * Copyright 2019-2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
- #include "graph/tuning_utils.h"
- #include "../debug/ge_util.h"
- #include "../debug/ge_op_types.h"
-
- namespace ge {
- const std::string peer_node_name_attr = "_peerNodeName";
- const std::string parent_node_name_attr = "_parentNodeName";
- const std::string alias_name_attr = "_aliasName";
- const std::string parent_node_attr = "parentNode";
- const std::string parent_node_anchor_index_attr = "_parentNodeAnchorIndex";
- const std::string tuning_subgraph_prefix = "/aicore_subgraph_";
- const std::string non_tuning_subgraph_prefix = "/subgraph_";
- const std::set<std::string> kPartitionOpTypes = {PLACEHOLDER, END};
- const std::set<std::string> kExeTypes = {DATA, NETOUTPUT};
- NodeNametoNodeNameMap TuningUtils::data_2_netoutput_;
- NodetoNodeNameMap TuningUtils::data_node_2_netoutput_;
- NodetoNodeMap TuningUtils::data_node_2_netoutput_node_;
- NodeSet TuningUtils::netoutput_nodes_;
- NodeSet TuningUtils::merged_graph_nodes_;
- SubgraphCreateOutNode TuningUtils::create_output_;
- std::mutex TuningUtils::mutex_;
-
- std::string TuningUtils::PrintCheckLog() {
- std::stringstream ss;
- ss << "d2n:{";
- for (const auto &pair : data_2_netoutput_) {
- ss << "data:" << pair.first << "-"
- << "netoutput:" << pair.second;
- ss << " | ";
- }
- ss << "}";
- ss << "netoutputs:{";
- for (const auto &node : netoutput_nodes_) {
- ss << "netoutput:" << node->GetName();
- ss << " | ";
- }
- ss << "}";
- return ss.str();
- }
-
- std::string TuningUtils::GetNodeNameByAnchor(const Anchor *anchor) {
- if (anchor == nullptr) {
- GELOGE(GRAPH_FAILED, "Anchor is nullptr");
- return "Null";
- }
- auto node = anchor->GetOwnerNode();
- return node == nullptr ? "Null" : node->GetName();
- }
-
- // part 1
- graphStatus TuningUtils::ConvertGraphToFile(std::vector<ComputeGraphPtr> tuning_subgraphs,
- std::vector<ComputeGraphPtr> non_tuning_subgraphs, bool exe_flag,
- const std::string &path, const std::string &user_path) {
- int64_t i = 0;
- int64_t j = 0;
- std::lock_guard<std::mutex> lock(mutex_);
- for (auto &subgraph : tuning_subgraphs) {
- create_output_.emplace(subgraph, nullptr);
- auto help_info = HelpInfo{i, exe_flag, true, path, user_path};
- if (MakeExeGraph(subgraph, help_info) != SUCCESS) {
- GELOGE(GRAPH_FAILED, "TUU:subgraph %zu generate exe graph failed", i);
- return GRAPH_FAILED;
- }
- i++;
- }
-
- for (auto &subgraph : non_tuning_subgraphs) {
- create_output_.emplace(subgraph, nullptr);
- auto help_info = HelpInfo{j, true, false, path, user_path};
- if (MakeExeGraph(subgraph, help_info) != SUCCESS) {
- GELOGE(GRAPH_FAILED, "TUU:non tuning_subgraph %zu generate exe graph failed", j);
- return GRAPH_FAILED;
- }
- j++;
- }
- create_output_.clear();
- return SUCCESS;
- }
-
- // +---------------+
- // | pld pld |
- // | \ / |
- // | relu relu |
- // | \ / |
- // | add |
- // | | |
- // | end |
- // +---------------+
- // |
- // |
- // V
- // +---------------+
- // | data data |
- // | \ / |
- // | relu relu |
- // | \ / |
- // | add |
- // | | |
- // | netoutput |
- // +---------------+
- graphStatus TuningUtils::MakeExeGraph(ComputeGraphPtr &exe_graph, const HelpInfo &help_info) {
- GE_CHECK_NOTNULL(exe_graph);
- // if not make exe, just dump and return
- if (!help_info.exe_flag) {
- DumpGraphToPath(exe_graph, help_info.index, help_info.is_tuning_graph, help_info.path);
- GELOGI("TUU:just return, dump original sub_graph[%s]index[%d]", exe_graph->GetName().c_str(), help_info.index);
- return SUCCESS;
- }
- // modify sub graph
- for (NodePtr &node : exe_graph->GetDirectNode()) {
- // 1.handle pld
- if (node->GetType() == PLACEHOLDER) {
- if (HandlePld(node) != SUCCESS) {
- GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(),
- exe_graph->GetName().c_str());
- return FAILED;
- }
- }
- // 2.handle end
- if (node->GetType() == END) {
- if (HandleEnd(node) != SUCCESS) {
- GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(),
- exe_graph->GetName().c_str());
- return FAILED;
- }
- }
- }
- graphStatus ret = exe_graph->TopologicalSorting();
- if (ret != SUCCESS) {
- GELOGE(ret, "Graph[%s] topological sort failed, ret:%d.", exe_graph->GetName().c_str(), ret);
- return ret;
- }
- // dump subgraphs which modified by us
- if (help_info.user_path.empty()) {
- DumpGraphToPath(exe_graph, help_info.index, help_info.is_tuning_graph, help_info.path);
- } else {
- GraphUtils::DumpGEGraph(exe_graph, "", true, help_info.user_path);
- }
- return SUCCESS;
- }
-
- void TuningUtils::DumpGraphToPath(ComputeGraphPtr &exe_graph, int64_t index, bool is_tuning_graph, std::string path) {
- if (!path.empty()) {
- if (is_tuning_graph) {
- GraphUtils::DumpGEGraph(exe_graph, "", true, path + tuning_subgraph_prefix + std::to_string(index) + ".txt");
- } else {
- GraphUtils::DumpGEGraph(exe_graph, "", true, path + non_tuning_subgraph_prefix + std::to_string(index) + ".txt");
- }
- } else {
- path = "./";
- if (is_tuning_graph) {
- GraphUtils::DumpGEGraph(exe_graph, "", true, path + tuning_subgraph_prefix + std::to_string(index) + ".txt");
- } else {
- GraphUtils::DumpGEGraph(exe_graph, "", true, path + non_tuning_subgraph_prefix + std::to_string(index) + ".txt");
- }
- }
- }
-
- graphStatus TuningUtils::CreateDataNode(NodePtr &node, NodePtr &data_node) {
- auto graph = node->GetOwnerComputeGraph();
- GE_CHECK_NOTNULL(graph);
- auto data_op_desc = ComGraphMakeShared<OpDesc>(node->GetName(), DATA);
- GE_CHECK_NOTNULL(data_op_desc);
- auto pld_op_desc = node->GetOpDesc();
- GE_CHECK_NOTNULL(pld_op_desc);
- auto output_desc = pld_op_desc->GetOutputDesc(0); // only one output for pld and data
- // data inputdesc & outputdesc set as same
- if (data_op_desc->AddInputDesc(output_desc) != SUCCESS) {
- GELOGE(FAILED, "TUU:data node %s AddOutputDesc failed", data_op_desc->GetName().c_str());
- return FAILED;
- }
- if (data_op_desc->AddOutputDesc(output_desc) != SUCCESS) {
- GELOGE(FAILED, "TUU:data node %s AddOutputDesc failed", data_op_desc->GetName().c_str());
- return FAILED;
- }
- data_node = graph->AddNode(data_op_desc);
- GE_CHECK_NOTNULL(data_node);
- if (data_node->SetOwnerComputeGraph(graph) != GRAPH_SUCCESS) {
- GELOGE(FAILED, "TUU:SetOwnerComputeGraph failed");
- return FAILED;
- }
- return SUCCESS;
- }
-
- graphStatus TuningUtils::AddAttrToDataNodeForMergeGraph(const NodePtr &pld, NodePtr &data_node) {
- auto op_desc = data_node->GetOpDesc();
- GE_CHECK_NOTNULL(op_desc);
-
- auto pld_desc = pld->GetOpDesc();
- GE_CHECK_NOTNULL(pld_desc);
- // inherit
- // a. set `end's input node type` as attr
- std::string parent_op_type;
- if (!AttrUtils::GetStr(pld_desc, "parentOpType", parent_op_type)) {
- GELOGE(FAILED, "TUU:pld %s get parentOpType failed", pld_desc->GetName().c_str());
- return FAILED;
- }
- (void)AttrUtils::SetStr(op_desc, "parentOpType", parent_op_type);
- // b. set `end's input node name` as attr
- std::string parent_op_name;
- if (!AttrUtils::GetStr(pld_desc, parent_node_name_attr, parent_op_name)) {
- GELOGE(FAILED, "TUU:pld %s get _parentNodeName failed", pld_desc->GetName().c_str());
- return FAILED;
- }
- (void)AttrUtils::SetStr(op_desc, parent_node_name_attr, parent_op_name);
- // c. set `end's input node's out anchor index` as attr
- int parent_node_anchor_index;
- if (!AttrUtils::GetInt(pld_desc, "anchorIndex", parent_node_anchor_index)) {
- GELOGE(FAILED, "TUU:pld %s get anchorIndex failed", pld_desc->GetName().c_str());
- return FAILED;
- }
- (void)AttrUtils::SetInt(op_desc, parent_node_anchor_index_attr, parent_node_anchor_index);
- GELOGD("TUU:from node %s(%s) to add attr to node %s(%s) success", pld->GetName().c_str(), pld->GetType().c_str(),
- data_node->GetName().c_str(), data_node->GetType().c_str());
- // d. set `end node name` as attr
- std::string peer_end_name;
- if (!AttrUtils::GetStr(pld_desc, peer_node_name_attr, peer_end_name)) {
- GELOGE(FAILED, "TUU:pld %s get _peerNodeName failed", pld_desc->GetName().c_str());
- return FAILED;
- }
- (void)AttrUtils::SetStr(op_desc, peer_node_name_attr, peer_end_name);
- GELOGD("TUU:from node %s(%s) to add attr to node %s(%s) success", pld->GetName().c_str(), pld->GetType().c_str(),
- data_node->GetName().c_str(), data_node->GetType().c_str());
- return SUCCESS;
- }
-
- graphStatus TuningUtils::ChangePld2Data(NodePtr &node, NodePtr &data_node) {
- auto type_pld = node->GetType();
- auto type_data = data_node->GetType();
- if (type_pld != PLACEHOLDER || type_data != DATA) {
- GELOGE(FAILED, "TUU:Failed to change node %s from type %s to type %s", node->GetName().c_str(), type_pld.c_str(),
- type_data.c_str());
- return FAILED;
- }
- auto graph = node->GetOwnerComputeGraph();
- GE_CHECK_NOTNULL(graph);
- std::vector<int> output_map(node->GetAllOutDataAnchorsSize());
- for (size_t i = 0; i < node->GetAllOutDataAnchorsSize(); ++i) {
- output_map[i] = static_cast<int>(i);
- }
-
- auto ret = GraphUtils::ReplaceNodeAnchors(data_node, node, {}, output_map);
- if (ret != GRAPH_SUCCESS) {
- GELOGE(FAILED, "TUU:Failed to replace node %s by node %s error node %u", node->GetName().c_str(),
- data_node->GetName().c_str(), ret);
- return FAILED;
- }
-
- NodeUtils::UnlinkAll(*node);
-
- ret = GraphUtils::RemoveNodeWithoutRelink(graph, node);
- if (ret != GRAPH_SUCCESS) {
- GELOGE(FAILED, "TUU:Failed to remove node %s from graph", node->GetName().c_str());
- return FAILED;
- }
-
- GELOGD("TUU:Remove node %s(%s) by the ChangePld2Data process, replace it with node %s(%s)", node->GetName().c_str(),
- node->GetType().c_str(), data_node->GetName().c_str(), data_node->GetType().c_str());
- return ret;
- }
-
- graphStatus TuningUtils::HandlePld(NodePtr &node) {
- GE_CHECK_NOTNULL(node);
- auto graph = node->GetOwnerComputeGraph();
- GE_CHECK_NOTNULL(graph);
- NodePtr data_node = nullptr;
-
- // 1. create data node
- if (CreateDataNode(node, data_node) != SUCCESS) {
- GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), graph->GetName().c_str());
- return FAILED;
- }
- // 2. add necessary info to data_node for recovery whole graph
- if (AddAttrToDataNodeForMergeGraph(node, data_node) != SUCCESS) {
- GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), graph->GetName().c_str());
- return FAILED;
- }
- // 3. replace pld node by data node created before
- if (ChangePld2Data(node, data_node) != SUCCESS) {
- GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), graph->GetName().c_str());
- return FAILED;
- }
- GELOGD("TUU:pld[%s] handle success", node->GetName().c_str());
- return SUCCESS;
- }
-
- graphStatus TuningUtils::CreateNetOutput(NodePtr &node, NodePtr &out_node) {
- GE_CHECK_NOTNULL(node);
- auto graph = node->GetOwnerComputeGraph();
- GE_CHECK_NOTNULL(graph);
- auto search = create_output_.find(graph);
- if (search == create_output_.end()) {
- GELOGE(FAILED, "TUU:node %s's owner sub graph %s not exist in create_output map", node->GetName().c_str(),
- graph->GetName().c_str());
- return FAILED;
- }
- if (search->second != nullptr) {
- out_node = search->second;
- GELOGD("TUU:sub graph %s has created output node, just return", graph->GetName().c_str());
- return SUCCESS;
- }
- auto out_op_desc = ComGraphMakeShared<OpDesc>(node->GetName(), NETOUTPUT);
- GE_CHECK_NOTNULL(out_op_desc);
- out_node = graph->AddNode(out_op_desc);
- GE_CHECK_NOTNULL(out_node);
- if (out_node->SetOwnerComputeGraph(graph) != GRAPH_SUCCESS) {
- GELOGE(FAILED, "TUU:SetOwnerComputeGraph failed");
- return FAILED;
- }
- create_output_[graph] = out_node;
- return SUCCESS;
- }
-
- graphStatus TuningUtils::AddAttrToNetOutputForMergeGraph(const NodePtr &end, NodePtr &out_node) {
- GE_CHECK_NOTNULL(end);
- GE_CHECK_NOTNULL(out_node);
- auto op_desc = out_node->GetOpDesc();
- GE_CHECK_NOTNULL(op_desc);
- std::vector<std::string> alias_names = {};
- (void)AttrUtils::GetListStr(op_desc, alias_name_attr, alias_names);
- alias_names.push_back(end->GetName());
- (void)AttrUtils::SetListStr(op_desc, alias_name_attr, alias_names);
- return SUCCESS;
- }
-
- graphStatus TuningUtils::LinkEnd2NetOutput(NodePtr &end_node, NodePtr &out_node) {
- GE_CHECK_NOTNULL(end_node);
- GE_CHECK_NOTNULL(out_node);
- // get end in node is control node or normal node
- AnchorPtr end_in_anchor = (end_node->GetInDataAnchor(0)->GetFirstPeerAnchor() == nullptr)
- ? Anchor::DynamicAnchorCast<Anchor>(end_node->GetInControlAnchor())
- : Anchor::DynamicAnchorCast<Anchor>(end_node->GetInDataAnchor(0));
- auto src_anchor = end_in_anchor->GetFirstPeerAnchor(); // src_anchor should be only 1
- if (GraphUtils::RemoveEdge(src_anchor, end_in_anchor) != GRAPH_SUCCESS) {
- GELOGE(FAILED, "TUU:remove end input edge from from %s(%d) to %s(%d) failed. node_name:%s, graph_name:%s",
- GetNodeNameByAnchor(src_anchor.get()).c_str(), src_anchor->GetIdx(),
- GetNodeNameByAnchor(end_in_anchor.get()).c_str(), end_in_anchor->GetIdx(), end_node->GetName().c_str(),
- end_node->GetOwnerComputeGraph()->GetName().c_str());
- return FAILED;
- }
- // add edge between `end in node` and `out_node`
- if (src_anchor->IsTypeOf<OutDataAnchor>()) {
- std::shared_ptr<InDataAnchor> anchor =
- ComGraphMakeShared<InDataAnchor>(out_node, out_node->GetAllInDataAnchors().size());
- GE_CHECK_NOTNULL(anchor);
- out_node->in_data_anchors_.push_back(anchor);
- if (GraphUtils::AddEdge(src_anchor, anchor) != GRAPH_SUCCESS) {
- GELOGE(FAILED, "TUU:add edge from %s(%d) to %s(%d) failed. node_name:%s, graph_name:%s",
- GetNodeNameByAnchor(src_anchor.get()).c_str(), src_anchor->GetIdx(),
- GetNodeNameByAnchor(anchor.get()).c_str(), anchor->GetIdx(), end_node->GetName().c_str(),
- end_node->GetOwnerComputeGraph()->GetName().c_str());
- return FAILED;
- }
- auto end_op_desc = end_node->GetOpDesc();
- GE_CHECK_NOTNULL(end_op_desc);
- auto out_node_op_desc = out_node->GetOpDesc();
- GE_CHECK_NOTNULL(out_node_op_desc);
- // end node always has one input
- if (out_node_op_desc->AddInputDesc(end_op_desc->GetInputDesc(0)) != GRAPH_SUCCESS) {
- GELOGE(FAILED, "TUU:node %s add input desc failed.", out_node_op_desc->GetName().c_str());
- return FAILED;
- }
- } else if (src_anchor->IsTypeOf<OutControlAnchor>()) {
- auto anchor = out_node->GetInControlAnchor();
- if (GraphUtils::AddEdge(src_anchor, anchor) != GRAPH_SUCCESS) {
- GELOGE(FAILED, "TUU:add edge from %s(%d) to %s(%d) failed. node_name:%s, graph_name:%s",
- GetNodeNameByAnchor(src_anchor.get()).c_str(), src_anchor->GetIdx(),
- GetNodeNameByAnchor(anchor.get()).c_str(), anchor->GetIdx(), end_node->GetName().c_str(),
- end_node->GetOwnerComputeGraph()->GetName().c_str());
- return FAILED;
- }
- } else {
- GELOGE(FAILED, "TUU: node_name:%s, graph_name:%s handled failed", end_node->GetName().c_str(),
- end_node->GetOwnerComputeGraph()->GetName().c_str());
- return FAILED;
- }
-
- return SUCCESS;
- }
-
- graphStatus TuningUtils::ChangeEnd2NetOutput(NodePtr &end_node, NodePtr &out_node) {
- GE_CHECK_NOTNULL(end_node);
- GE_CHECK_NOTNULL(out_node);
- auto type_end = end_node->GetType();
- auto type_out = out_node->GetType();
- if (type_end != END || type_out != NETOUTPUT) {
- GELOGE(FAILED, "TUU:Failed to change end_node %s from type %s to type %s", end_node->GetName().c_str(),
- type_end.c_str(), type_out.c_str());
- return FAILED;
- }
- // link all `end nodes's in node` to this out_node
- if (LinkEnd2NetOutput(end_node, out_node) != SUCCESS) {
- GELOGE(FAILED, "TUU:end_node [%s] LinkEnd2NetOutput failed.", end_node->GetName().c_str());
- return FAILED;
- }
- // remove `end node`
- NodeUtils::UnlinkAll(*end_node);
- auto graph = end_node->GetOwnerComputeGraph();
- GE_CHECK_NOTNULL(graph);
- if (GraphUtils::RemoveNodeWithoutRelink(graph, end_node) != SUCCESS) {
- GELOGE(FAILED, "TUU:end node [%s] RemoveNodeWithoutRelink failed.", end_node->GetName().c_str());
- return FAILED;
- }
- return SUCCESS;
- }
-
- graphStatus TuningUtils::HandleEnd(NodePtr &node) {
- GE_CHECK_NOTNULL(node);
- auto graph = node->GetOwnerComputeGraph();
- GE_CHECK_NOTNULL(graph);
- NodePtr out_node = nullptr;
-
- // 1. create net_output node , add only one NetOutput node to one subgraph
- if (CreateNetOutput(node, out_node) != SUCCESS) {
- GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), graph->GetName().c_str());
- return FAILED;
- }
- // 2. add necessary info to out_node for recovery whole graph
- if (AddAttrToNetOutputForMergeGraph(node, out_node) != SUCCESS) {
- GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), graph->GetName().c_str());
- return FAILED;
- }
- // 3. replace all end nodes by one output node created before
- if (ChangeEnd2NetOutput(node, out_node) != SUCCESS) {
- GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), graph->GetName().c_str());
- return FAILED;
- }
- GELOGD("TUU:end[%s] handle success", node->GetName().c_str());
- return SUCCESS;
- }
-
- // part 2
- graphStatus TuningUtils::ConvertFileToGraph(const map<int64_t, string> &options, ge::Graph &graph) {
- // 1. get all subgraph object
- std::vector<ComputeGraphPtr> graphs;
- // options format like {index:"subgraph_path"}
- for (const auto &pair : options) {
- ComputeGraphPtr compute_graph = ComGraphMakeShared<ComputeGraph>(std::to_string(pair.first));
- if (!ge::GraphUtils::LoadGEGraph(pair.second.c_str(), *compute_graph)) {
- GELOGE(FAILED, "TUU:load graph from file failed");
- }
- graphs.push_back(compute_graph);
- }
- // 2. merge graph
- ComputeGraphPtr merged_graph = ComGraphMakeShared<ComputeGraph>("whole_graph_after_tune");
- GE_CHECK_NOTNULL(merged_graph);
- if (MergeAllSubGraph(graphs, merged_graph) != SUCCESS) {
- GELOGE(FAILED, "TUU:MergeGraph failed");
- return FAILED;
- }
- // 3. set parent graph
- for (const auto &node : merged_graph->GetDirectNode()) {
- GE_CHECK_NOTNULL(node);
- if (node->SetOwnerComputeGraph(merged_graph) != GRAPH_SUCCESS) {
- GELOGE(FAILED, "TUU:node %s set owner graph failed", node->GetName().c_str());
- return FAILED;
- }
- }
- graph = GraphUtils::CreateGraphFromComputeGraph(merged_graph);
- return SUCCESS;
- }
-
- // +----------------------------------+
- // | const const |
- // | \ / |
- // | netoutput(end,end) |
- // +----------------------------------+
- // +
- // +----------------------------------+
- // | data(pld) data(pld) |
- // | \ / |
- // | relu relu |
- // | \ / |
- // | \ / |
- // | add |
- // | | |
- // | netoutput(end) |
- // +----------------------------------+
- // +
- // +----------------------------------+
- // | data(pld) |
- // | / |
- // | netoutput |
- // +----------------------------------+
- // |
- // |
- // V
- // +----------------------------------+
- // | const const |
- // | \ / |
- // | relu relu |
- // | \ / |
- // | \ / |
- // | add |
- // | | |
- // | netoutput |
- // +----------------------------------+
- graphStatus TuningUtils::MergeAllSubGraph(std::vector<ComputeGraphPtr> &subgraphs,
- ComputeGraphPtr &output_merged_compute_graph) {
- GE_CHECK_NOTNULL(output_merged_compute_graph);
- // 1. handle all subgraphs
- for (auto &subgraph : subgraphs) {
- Status ret_status = MergeSubGraph(subgraph);
- if (ret_status != SUCCESS) {
- GELOGE(ret_status, "TUU:subgraph %s merge failed", subgraph->GetName().c_str());
- return ret_status;
- }
- }
-
- for (const auto &node : merged_graph_nodes_) {
- (void)output_merged_compute_graph->AddNode(node);
- GELOGD("TUU:graph %s add node %s success", output_merged_compute_graph->GetName().c_str(), node->GetName().c_str());
- }
-
- // 2. remove data and output node added by us
- if (RemoveDataNetoutputEdge(output_merged_compute_graph) != SUCCESS) {
- GELOGE(FAILED, "TUU:Failed to merge graph %s", output_merged_compute_graph->GetName().c_str());
- return FAILED;
- }
- graphStatus ret = output_merged_compute_graph->TopologicalSorting();
- if (ret != SUCCESS) {
- GELOGE(ret, "Graph[%s] topological sort failed, ret:%d.", output_merged_compute_graph->GetName().c_str(), ret);
- return ret;
- }
- GELOGD("TUU:Print-%s", PrintCheckLog().c_str());
- GELOGI("TUU:output_merged_compute_graph %s success", output_merged_compute_graph->GetName().c_str());
- return SUCCESS;
- }
-
- graphStatus TuningUtils::MergeSubGraph(ComputeGraphPtr &subgraph) {
- for (auto &node : subgraph->GetDirectNode()) {
- if (kPartitionOpTypes.count(node->GetType()) > 0) {
- GELOGE(FAILED, "TUU:subgraph passed in should not contain nodes of end or pld type");
- return FAILED;
- }
- // handle data converted from pld node
- if (node->GetType() == DATA) {
- auto op_desc = node->GetOpDesc();
- GE_CHECK_NOTNULL(op_desc);
- std::string peer_out_name;
- bool has_valid_str = (AttrUtils::GetStr(op_desc, peer_node_name_attr, peer_out_name)) && (!peer_out_name.empty());
- if (has_valid_str) {
- std::lock_guard<std::mutex> lock(mutex_);
- data_2_netoutput_.emplace(op_desc->GetName(), peer_out_name);
- data_node_2_netoutput_.emplace(node, peer_out_name);
- continue;
- }
- }
- // handle netoutput converted from end node
- if (node->GetType() == NETOUTPUT) {
- auto op_desc = node->GetOpDesc();
- GE_CHECK_NOTNULL(op_desc);
- std::vector<string> out_alias_name;
- bool has_valid_str =
- (AttrUtils::GetListStr(op_desc, alias_name_attr, out_alias_name)) && (!out_alias_name.empty());
- if (has_valid_str) {
- std::lock_guard<std::mutex> lock(mutex_);
- netoutput_nodes_.insert(node);
- }
- }
- {
- std::lock_guard<std::mutex> lock(mutex_);
- merged_graph_nodes_.emplace(node);
- }
- GELOGD("TUU:subgraph %s add node %s success", subgraph->GetName().c_str(), node->GetName().c_str());
- }
- GELOGI("TUU:merge subgraph %s success", subgraph->GetName().c_str());
- return SUCCESS;
- }
-
- graphStatus TuningUtils::RemoveDataNetoutputEdge(ComputeGraphPtr &graph) {
- GE_CHECK_NOTNULL(graph);
- // 1. traverse
- for (auto &pair : data_node_2_netoutput_) {
- auto data_node = pair.first;
- GE_CHECK_NOTNULL(data_node);
- auto netoutput_name = pair.second;
- auto netoutput_node = graph->FindNode(netoutput_name);
- GE_CHECK_NOTNULL(netoutput_node);
- data_node_2_netoutput_node_.emplace(data_node, netoutput_node);
- // 2. get `data out anchor` and `net output in anchor` and `net output in node's out anchor`
- AnchorPtr data_out_anchor = (data_node->GetOutDataAnchor(0)->GetFirstPeerAnchor() == nullptr)
- ? Anchor::DynamicAnchorCast<Anchor>(data_node->GetOutControlAnchor())
- : Anchor::DynamicAnchorCast<Anchor>(data_node->GetOutDataAnchor(0));
- AnchorPtr net_output_in_anchor = nullptr;
- AnchorPtr src_out_anchor = nullptr;
- if (GetInAndOutAnchorPair(data_node, netoutput_node, net_output_in_anchor, src_out_anchor) != GRAPH_SUCCESS) {
- GELOGE(FAILED, "TUU:get out node:%s 's in anchor related with data node:%s failed",
- netoutput_node->GetName().c_str(), data_node->GetName().c_str());
- return FAILED;
- }
- // 3. relink
- if (GraphUtils::RemoveEdge(src_out_anchor, net_output_in_anchor) != GRAPH_SUCCESS) {
- GELOGE(FAILED, "TUU:remove edge from %s(%d) to %s(%d) failed. node_name:(data:%s;netoutput:%s), graph_name:%s",
- GetNodeNameByAnchor(src_out_anchor.get()).c_str(), src_out_anchor->GetIdx(),
- GetNodeNameByAnchor(net_output_in_anchor.get()).c_str(), net_output_in_anchor->GetIdx(),
- data_node->GetName().c_str(), netoutput_node->GetName().c_str(), graph->GetName().c_str());
- return FAILED;
- }
- GE_CHECK_NOTNULL(data_out_anchor);
- for (const auto &peer_in_anchor : data_out_anchor->GetPeerAnchors()) {
- if (GraphUtils::RemoveEdge(data_out_anchor, peer_in_anchor) != GRAPH_SUCCESS) {
- GELOGE(FAILED, "TUU:remove edge from %s(%d) to %s(%d) failed. node_name:(data:%s;netoutput:%s), graph_name:%s",
- GetNodeNameByAnchor(data_out_anchor.get()).c_str(), data_out_anchor->GetIdx(),
- GetNodeNameByAnchor(peer_in_anchor.get()).c_str(), peer_in_anchor->GetIdx(),
- data_node->GetName().c_str(), netoutput_node->GetName().c_str(), graph->GetName().c_str());
- return FAILED;
- }
- if (GraphUtils::AddEdge(src_out_anchor, peer_in_anchor) != GRAPH_SUCCESS) {
- GELOGE(FAILED, "TUU:add edge from %s(%d) to %s(%d) failed. node_name:(data:%s;netoutput:%s), graph_name:%s",
- GetNodeNameByAnchor(src_out_anchor.get()).c_str(), src_out_anchor->GetIdx(),
- GetNodeNameByAnchor(peer_in_anchor.get()).c_str(), peer_in_anchor->GetIdx(),
- data_node->GetName().c_str(), netoutput_node->GetName().c_str(), graph->GetName().c_str());
- return FAILED;
- }
- }
- }
- // 4. remove out nodes added by us
- for (auto &node : netoutput_nodes_) {
- NodeUtils::UnlinkAll(*node);
- if (GraphUtils::RemoveNodeWithoutRelink(graph, node) != GRAPH_SUCCESS) {
- GELOGE(FAILED, "TUU:Failed to remove node %s from graph", node->GetName().c_str());
- return FAILED;
- }
- GELOGD("TUU:Remove node %s by the RemoveDataNetoutputEdge process success", node->GetName().c_str());
- }
- return SUCCESS;
- }
-
- graphStatus TuningUtils::GetInAndOutAnchorPair(NodePtr &data_node, NodePtr &out_node, AnchorPtr &dest_in_anchor,
- AnchorPtr &src_out_anchor) {
- // 1. get `data parent node name`, i.e. `netoutput input node name`
- std::string netoutput_input_name;
- auto op_desc = data_node->GetOpDesc();
- GE_CHECK_NOTNULL(op_desc);
- if (!AttrUtils::GetStr(op_desc, parent_node_name_attr, netoutput_input_name)) {
- GELOGE(FAILED, "TUU:Failed to get parent node attr from node %s", op_desc->GetName().c_str());
- return FAILED;
- }
- // 2. find index
- int parent_node_anchor_index;
- if (!AttrUtils::GetInt(op_desc, parent_node_anchor_index_attr, parent_node_anchor_index)) {
- GELOGE(FAILED, "TUU:Failed to get parent node anchor index attr from node %s", op_desc->GetName().c_str());
- return FAILED;
- }
- // 3.find in data or ctrl anchor by 1&2 step
- for (auto &in_anchor : out_node->GetAllInAnchors()) {
- GE_CHECK_NOTNULL(in_anchor);
- for (auto &src_anchor : in_anchor->GetPeerAnchors()) { // get all peer anchors for ctrl
- GE_CHECK_NOTNULL(src_anchor);
- auto src_node = src_anchor->GetOwnerNode();
- GE_CHECK_NOTNULL(src_node);
- if (src_node->GetName() == netoutput_input_name && src_anchor->GetIdx() == parent_node_anchor_index) {
- dest_in_anchor = in_anchor;
- src_out_anchor = src_anchor;
- GELOGD("TUU:get out node:%s 's in anchor(%d) src_node:%s 's out anchor(%d) related with data node:%s",
- out_node->GetName().c_str(), dest_in_anchor->GetIdx(), netoutput_input_name.c_str(),
- parent_node_anchor_index, data_node->GetName().c_str());
- break;
- }
- }
- }
- GE_CHECK_NOTNULL(dest_in_anchor);
- GE_CHECK_NOTNULL(src_out_anchor);
- return SUCCESS;
- }
-
- } // namespace ge
|