|
@@ -28,6 +28,7 @@ using std::vector; |
|
|
|
|
|
|
|
|
namespace ge { |
|
|
namespace ge { |
|
|
namespace { |
|
|
namespace { |
|
|
|
|
|
const size_t kGenMaskInputIndex = 1; |
|
|
const size_t kDefaultMaxParallelNum = 1; |
|
|
const size_t kDefaultMaxParallelNum = 1; |
|
|
} // namespace |
|
|
} // namespace |
|
|
|
|
|
|
|
@@ -104,12 +105,14 @@ void LinkGenMaskNodesPass::GetAllGenMaskNodes(ComputeGraphPtr graph, vector<Node |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
auto in_data_nodes = node->GetInDataNodes(); |
|
|
auto in_data_nodes = node->GetInDataNodes(); |
|
|
for (auto &in_data_node : in_data_nodes) { |
|
|
|
|
|
// node gen_mask is located at different place in the fused node |
|
|
|
|
|
if (in_data_node->GetName().find(DROPOUTGENMASK) == in_data_node->GetName().npos) { |
|
|
|
|
|
continue; |
|
|
|
|
|
|
|
|
if (in_data_nodes.size() > kGenMaskInputIndex) { |
|
|
|
|
|
NodePtr &gen_mask = in_data_nodes.at(kGenMaskInputIndex); |
|
|
|
|
|
for (auto &in_data_node : in_data_nodes) { |
|
|
|
|
|
// node gen_mask is located at different place in the fused node |
|
|
|
|
|
if (in_data_node->GetName().find(DROPOUTGENMASK) != in_data_node->GetName().npos) { |
|
|
|
|
|
gen_mask = in_data_node; |
|
|
|
|
|
} |
|
|
} |
|
|
} |
|
|
NodePtr &gen_mask = in_data_node; |
|
|
|
|
|
|
|
|
|
|
|
if ((gen_mask->GetOpDesc() == nullptr) || (gen_mask->GetOpDesc()->HasAttr(ATTR_NAME_STREAM_LABEL))) { |
|
|
if ((gen_mask->GetOpDesc() == nullptr) || (gen_mask->GetOpDesc()->HasAttr(ATTR_NAME_STREAM_LABEL))) { |
|
|
continue; |
|
|
continue; |
|
|