| @@ -28,6 +28,7 @@ using std::vector; | |||||
| namespace ge { | namespace ge { | ||||
| namespace { | namespace { | ||||
| const size_t kGenMaskInputIndex = 1; | |||||
| const size_t kDefaultMaxParallelNum = 1; | const size_t kDefaultMaxParallelNum = 1; | ||||
| } // namespace | } // namespace | ||||
| @@ -104,12 +105,14 @@ void LinkGenMaskNodesPass::GetAllGenMaskNodes(ComputeGraphPtr graph, vector<Node | |||||
| } | } | ||||
| auto in_data_nodes = node->GetInDataNodes(); | auto in_data_nodes = node->GetInDataNodes(); | ||||
| for (auto &in_data_node : in_data_nodes) { | |||||
| // node gen_mask is located at different place in the fused node | |||||
| if (in_data_node->GetName().find(DROPOUTGENMASK) == in_data_node->GetName().npos) { | |||||
| continue; | |||||
| if (in_data_nodes.size() > kGenMaskInputIndex) { | |||||
| NodePtr &gen_mask = in_data_nodes.at(kGenMaskInputIndex); | |||||
| for (auto &in_data_node : in_data_nodes) { | |||||
| // node gen_mask is located at different place in the fused node | |||||
| if (in_data_node->GetName().find(DROPOUTGENMASK) != in_data_node->GetName().npos) { | |||||
| gen_mask = in_data_node; | |||||
| } | |||||
| } | } | ||||
| NodePtr &gen_mask = in_data_node; | |||||
| if ((gen_mask->GetOpDesc() == nullptr) || (gen_mask->GetOpDesc()->HasAttr(ATTR_NAME_STREAM_LABEL))) { | if ((gen_mask->GetOpDesc() == nullptr) || (gen_mask->GetOpDesc()->HasAttr(ATTR_NAME_STREAM_LABEL))) { | ||||
| continue; | continue; | ||||