diff --git a/ge/graph/passes/link_gen_mask_nodes_pass.cc b/ge/graph/passes/link_gen_mask_nodes_pass.cc index 2788bc43..8dfd447d 100755 --- a/ge/graph/passes/link_gen_mask_nodes_pass.cc +++ b/ge/graph/passes/link_gen_mask_nodes_pass.cc @@ -28,6 +28,7 @@ using std::vector; namespace ge { namespace { +const size_t kGenMaskInputIndex = 1; const size_t kDefaultMaxParallelNum = 1; } // namespace @@ -104,12 +105,14 @@ void LinkGenMaskNodesPass::GetAllGenMaskNodes(ComputeGraphPtr graph, vectorGetInDataNodes(); - for (auto &in_data_node : in_data_nodes) { - // node gen_mask is located at different place in the fused node - if (in_data_node->GetName().find(DROPOUTGENMASK) == in_data_node->GetName().npos) { - continue; + if (in_data_nodes.size() > kGenMaskInputIndex) { + NodePtr &gen_mask = in_data_nodes.at(kGenMaskInputIndex); + for (auto &in_data_node : in_data_nodes) { + // node gen_mask is located at different place in the fused node + if (in_data_node->GetName().find(DROPOUTGENMASK) != in_data_node->GetName().npos) { + gen_mask = in_data_node; + } } - NodePtr &gen_mask = in_data_node; if ((gen_mask->GetOpDesc() == nullptr) || (gen_mask->GetOpDesc()->HasAttr(ATTR_NAME_STREAM_LABEL))) { continue;