Browse Source

fix gen_mask control-edges bug

tags/v1.3.0
dingshihao2 3 years ago
parent
commit
dfd18d3465
1 changed files with 8 additions and 5 deletions
  1. +8
    -5
      ge/graph/passes/link_gen_mask_nodes_pass.cc

+ 8
- 5
ge/graph/passes/link_gen_mask_nodes_pass.cc View File

@@ -28,6 +28,7 @@ using std::vector;


namespace ge { namespace ge {
namespace { namespace {
const size_t kGenMaskInputIndex = 1;
const size_t kDefaultMaxParallelNum = 1; const size_t kDefaultMaxParallelNum = 1;
} // namespace } // namespace


@@ -104,12 +105,14 @@ void LinkGenMaskNodesPass::GetAllGenMaskNodes(ComputeGraphPtr graph, vector<Node
} }


auto in_data_nodes = node->GetInDataNodes(); auto in_data_nodes = node->GetInDataNodes();
for (auto &in_data_node : in_data_nodes) {
// node gen_mask is located at different place in the fused node
if (in_data_node->GetName().find(DROPOUTGENMASK) == in_data_node->GetName().npos) {
continue;
if (in_data_nodes.size() > kGenMaskInputIndex) {
NodePtr &gen_mask = in_data_nodes.at(kGenMaskInputIndex);
for (auto &in_data_node : in_data_nodes) {
// node gen_mask is located at different place in the fused node
if (in_data_node->GetName().find(DROPOUTGENMASK) != in_data_node->GetName().npos) {
gen_mask = in_data_node;
}
} }
NodePtr &gen_mask = in_data_node;


if ((gen_mask->GetOpDesc() == nullptr) || (gen_mask->GetOpDesc()->HasAttr(ATTR_NAME_STREAM_LABEL))) { if ((gen_mask->GetOpDesc() == nullptr) || (gen_mask->GetOpDesc()->HasAttr(ATTR_NAME_STREAM_LABEL))) {
continue; continue;


Loading…
Cancel
Save