Browse Source

add netoupt alwways

tags/v1.3.0
gengchao4@huawei.com 3 years ago
parent
commit
dbafeb8531
4 changed files with 14 additions and 4 deletions
  1. +3
    -1
      ge/graph/build/task_generator.cc
  2. +9
    -3
      ge/graph/passes/net_output_pass.cc
  3. +1
    -0
      ge/host_cpu_engine/ops_kernel_store/op/host_op.cc
  4. +1
    -0
      ge/hybrid/node_executor/host_cpu/kernel/no_op_kernel.cc

+ 3
- 1
ge/graph/build/task_generator.cc View File

@@ -795,7 +795,9 @@ uint32_t TaskGenerator::FindLastBpFromBpNode(const ComputeGraphPtr &graph, const
GELOGI("bp_op_desc is %s, id is %ld", bp_op_desc->GetName().c_str(), bp_op_desc->GetId()); GELOGI("bp_op_desc is %s, id is %ld", bp_op_desc->GetName().c_str(), bp_op_desc->GetId());
} }


GE_CHECK_NOTNULL(bp_op_desc);
if (bp_op_desc == nullptr) {
return last_bp;
}
uint32_t current_idx = 0; uint32_t current_idx = 0;
for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) { for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) {
OpDescPtr op_desc = node->GetOpDesc(); OpDescPtr op_desc = node->GetOpDesc();


+ 9
- 3
ge/graph/passes/net_output_pass.cc View File

@@ -40,6 +40,7 @@ static std::map<std::string, ge::DataType> output_type_str_to_datatype = {


// the size of user defined output datatype or format string after split by ":". // the size of user defined output datatype or format string after split by ":".
const size_t kUserDefinedElementCount = 2; const size_t kUserDefinedElementCount = 2;
const size_t kNodesCount = 2;


Status NetOutputPass::GetRetvalOutputInfo(const ge::NodePtr &node, Status NetOutputPass::GetRetvalOutputInfo(const ge::NodePtr &node,
std::map<int32_t, RetvalInfo> &retval_node_index_map) { std::map<int32_t, RetvalInfo> &retval_node_index_map) {
@@ -424,11 +425,13 @@ Status NetOutputPass::AddCtrlEdgesBetweenLeafAndNetOutput(const ge::ComputeGraph
GELOGI("No need to add ctrl edge to netoutput because user out nodes have been set."); GELOGI("No need to add ctrl edge to netoutput because user out nodes have been set.");
return SUCCESS; return SUCCESS;
} }
bool graph_has_only_one_node_except_netoutput = (graph->GetDirectNodesSize() == kNodesCount);
for (const auto &node : graph->GetDirectNode()) { for (const auto &node : graph->GetDirectNode()) {
if (node == nullptr || node->GetOpDesc() == nullptr || node->GetOpDesc()->GetType() == NETOUTPUT) { if (node == nullptr || node->GetOpDesc() == nullptr || node->GetOpDesc()->GetType() == NETOUTPUT) {
continue; continue;
} }
if ((node->GetInControlNodes().size() != 0 || node->GetInDataNodes().size() != 0) &&
if ((node->GetInControlNodes().size() != 0 || node->GetInDataNodes().size() != 0 ||
graph_has_only_one_node_except_netoutput) &&
node->GetOutDataNodesSize() == 0 && node->GetOutControlNodes().size() == 0) { node->GetOutDataNodesSize() == 0 && node->GetOutControlNodes().size() == 0) {
GE_CHK_STATUS_RET(GraphUtils::AddEdge(node->GetOutControlAnchor(), net_out_node->GetInControlAnchor()), GE_CHK_STATUS_RET(GraphUtils::AddEdge(node->GetOutControlAnchor(), net_out_node->GetInControlAnchor()),
"add edge failed"); "add edge failed");
@@ -493,10 +496,13 @@ Status NetOutputPass::AddNetOutputNodeToGraph(const ge::ComputeGraphPtr &graph,
} }
GELOGI("[NETOUTPUT PASS] OutNodesInfo size:%zu, Targets Size:%zu, is_include_special_node_:%d", GELOGI("[NETOUTPUT PASS] OutNodesInfo size:%zu, Targets Size:%zu, is_include_special_node_:%d",
graph->GetGraphOutNodesInfo().size(), graph->GetGraphTargetNodesInfo().size(), is_include_special_node_); graph->GetGraphOutNodesInfo().size(), graph->GetGraphTargetNodesInfo().size(), is_include_special_node_);
// If user does not set out nodes and targets and no retval node, return false
// If user does not set out nodes and targets and no retval node, also add netoutput node
if ((graph->GetGraphOutNodesInfo().empty()) && (graph->GetGraphTargetNodesInfo().empty()) && if ((graph->GetGraphOutNodesInfo().empty()) && (graph->GetGraphTargetNodesInfo().empty()) &&
!is_include_special_node_) { !is_include_special_node_) {
GELOGI("[NETOUTPUT PASS] output_nodes and target_nodes and special nodes is empty!It means no need netoutput!");
GELOGI("[NETOUTPUT PASS] output_nodes and target_nodes and special nodes is empty!Add netoutput!");
output_node = graph->AddNode(net_output_desc);
GE_CHK_STATUS_RET(AddCtrlEdgesBetweenLeafAndNetOutput(graph, output_node),
"add ctrl edge between leaf and netoutput failed");
return SUCCESS; return SUCCESS;
} }
GELOGI("[NETOUTPUT PASS] Output node size:%lu.", output_nodes_info.size()); GELOGI("[NETOUTPUT PASS] Output node size:%lu.", output_nodes_info.size());


+ 1
- 0
ge/host_cpu_engine/ops_kernel_store/op/host_op.cc View File

@@ -35,5 +35,6 @@ REGISTER_OP_CREATOR(Mul, HostOp);
REGISTER_OP_CREATOR(ConcatV2, HostOp); REGISTER_OP_CREATOR(ConcatV2, HostOp);
REGISTER_OP_CREATOR(Data, HostOp); REGISTER_OP_CREATOR(Data, HostOp);
REGISTER_OP_CREATOR(Fill, HostOp); REGISTER_OP_CREATOR(Fill, HostOp);
REGISTER_OP_CREATOR(NetOutput, HostOp);
} // namespace host_cpu } // namespace host_cpu
} // namespace ge } // namespace ge

+ 1
- 0
ge/hybrid/node_executor/host_cpu/kernel/no_op_kernel.cc View File

@@ -28,6 +28,7 @@ Status NoOpKernel::Compute(TaskContext& context) {
} }


REGISTER_KERNEL_CREATOR(NoOp, NoOpKernel); REGISTER_KERNEL_CREATOR(NoOp, NoOpKernel);
REGISTER_KERNEL_CREATOR(NetOutput, NoOpKernel);
} // namespace host_cpu } // namespace host_cpu
} // namespace hybrid } // namespace hybrid
} // namespace ge } // namespace ge

Loading…
Cancel
Save