diff --git a/ge/graph/passes/atomic_addr_clean_pass.cc b/ge/graph/passes/atomic_addr_clean_pass.cc index 18cac856..b62f86c7 100755 --- a/ge/graph/passes/atomic_addr_clean_pass.cc +++ b/ge/graph/passes/atomic_addr_clean_pass.cc @@ -33,10 +33,12 @@ namespace ge { Status AtomicAddrCleanPass::Run(ComputeGraphPtr graph) { GE_CHECK_NOTNULL(graph); GELOGD("AtomicAddrCleanPass begin."); + bool is_unknown_graph = graph->GetGraphUnknownFlag(); + // 1.Recoginze atomic and loop mark vector atomic_node_vec; for (NodePtr &node : graph->GetDirectNode()) { - if (IsAtomicOp(node)) { + if (IsAtomicOp(node, is_unknown_graph)) { atomic_node_vec.push_back(node); } if (!is_loop_graph_ && node->GetType() == LOOPCOND) { @@ -50,7 +52,6 @@ Status AtomicAddrCleanPass::Run(ComputeGraphPtr graph) { return SUCCESS; } - bool is_unknown_graph = graph->GetGraphUnknownFlag(); if (is_unknown_graph) { GELOGD("Graph[%s] is unknown graph. It will call fe interface to compile op.", graph->GetName().c_str()); GE_CHK_STATUS_RET(CompileUnknownGraphOp(atomic_node_vec)); @@ -242,7 +243,7 @@ Status AtomicAddrCleanPass::LinkToAtomicNode(const NodePtr &atomic_node, NodePtr return SUCCESS; } -bool AtomicAddrCleanPass::IsAtomicOp(const NodePtr &node) { +bool AtomicAddrCleanPass::IsAtomicOp(const NodePtr &node, bool is_unknown_graph) { GE_IF_BOOL_EXEC(node == nullptr, GELOGE(FAILED, "node is null."); return false); OpDescPtr op_desc = node->GetOpDesc(); if (op_desc == nullptr) { @@ -262,19 +263,21 @@ bool AtomicAddrCleanPass::IsAtomicOp(const NodePtr &node) { return false; } - if (!has_atomic_input && has_atomic_output && node_workspace_offset.empty()) { - std::vector atomic_output_index; - (void) ge::AttrUtils::GetListInt(op_desc, ATOMIC_ATTR_OUTPUT_INDEX, atomic_output_index); - bool is_all_output_peer_also_atomic = true; - for (const auto &output_index : atomic_output_index) { - if (!IsOutputIndexPeerInputAtomic(node, output_index)) { - is_all_output_peer_also_atomic = false; - break; + if (!is_unknown_graph) { + if (!has_atomic_input && has_atomic_output && node_workspace_offset.empty()) { + std::vector atomic_output_index; + (void) ge::AttrUtils::GetListInt(op_desc, ATOMIC_ATTR_OUTPUT_INDEX, atomic_output_index); + bool is_all_output_peer_also_atomic = true; + for (const auto &output_index : atomic_output_index) { + if (!IsOutputIndexPeerInputAtomic(node, output_index)) { + is_all_output_peer_also_atomic = false; + break; + } + } + if (is_all_output_peer_also_atomic) { + GELOGI("all out peer node input atomic, skip this out atomic process, node name: %s", node->GetName().c_str()); + return false; } - } - if (is_all_output_peer_also_atomic) { - GELOGI("all out peer node input atomic, skip this out atomic process, node name: %s", node->GetName().c_str()); - return false; } } @@ -342,6 +345,7 @@ bool AtomicAddrCleanPass::IsOutputIndexPeerInputAtomic(const NodePtr &node, int6 Status AtomicAddrCleanPass::ClearStatus() { hcom_node_vec_.clear(); return SUCCESS; + } Status AtomicAddrCleanPass::CompileUnknownGraphOp(const vector &atomic_node_vec) { diff --git a/ge/graph/passes/atomic_addr_clean_pass.h b/ge/graph/passes/atomic_addr_clean_pass.h index 420ddd01..64bc604b 100755 --- a/ge/graph/passes/atomic_addr_clean_pass.h +++ b/ge/graph/passes/atomic_addr_clean_pass.h @@ -72,7 +72,7 @@ class AtomicAddrCleanPass : public GraphPass { * @param node * @return */ - bool IsAtomicOp(const NodePtr &node); + bool IsAtomicOp(const NodePtr &node, bool is_unknown_graph); /** * Handle atomic node in unknown graph