| @@ -28,10 +28,6 @@ using std::vector; | |||||
| namespace ge { | namespace ge { | ||||
| namespace { | namespace { | ||||
| <<<<<<< Updated upstream | |||||
| const size_t kGenMaskInputIndex = 1; | |||||
| ======= | |||||
| >>>>>>> Stashed changes | |||||
| const size_t kDefaultMaxParallelNum = 1; | const size_t kDefaultMaxParallelNum = 1; | ||||
| } // namespace | } // namespace | ||||
| @@ -108,10 +104,6 @@ void LinkGenMaskNodesPass::GetAllGenMaskNodes(ComputeGraphPtr graph, vector<Node | |||||
| } | } | ||||
| auto in_data_nodes = node->GetInDataNodes(); | auto in_data_nodes = node->GetInDataNodes(); | ||||
| <<<<<<< Updated upstream | |||||
| if (in_data_nodes.size() > kGenMaskInputIndex) { | |||||
| NodePtr &gen_mask = in_data_nodes.at(kGenMaskInputIndex); | |||||
| ======= | |||||
| for (auto &in_data_node : in_data_nodes) { | for (auto &in_data_node : in_data_nodes) { | ||||
| // node gen_mask is located at different place in the fused node | // node gen_mask is located at different place in the fused node | ||||
| if (in_data_node->GetName().find(DROPOUTGENMASK) == in_data_node->GetName().npos) { | if (in_data_node->GetName().find(DROPOUTGENMASK) == in_data_node->GetName().npos) { | ||||
| @@ -119,7 +111,6 @@ void LinkGenMaskNodesPass::GetAllGenMaskNodes(ComputeGraphPtr graph, vector<Node | |||||
| } | } | ||||
| NodePtr &gen_mask = in_data_node; | NodePtr &gen_mask = in_data_node; | ||||
| >>>>>>> Stashed changes | |||||
| if ((gen_mask->GetOpDesc() == nullptr) || (gen_mask->GetOpDesc()->HasAttr(ATTR_NAME_STREAM_LABEL))) { | if ((gen_mask->GetOpDesc() == nullptr) || (gen_mask->GetOpDesc()->HasAttr(ATTR_NAME_STREAM_LABEL))) { | ||||
| continue; | continue; | ||||
| } | } | ||||