Browse Source

!1541 [cube] fix gen_mask control-edges bug

From: @ding-shihao
Reviewed-by: 
Signed-off-by:
tags/v1.3.0
mindspore-ci-bot Gitee 3 years ago
parent
commit
3dcbde009c
2 changed files with 12 additions and 2 deletions
  1. +10
    -0
      ge/graph/passes/link_gen_mask_nodes_pass.cc
  2. +2
    -2
      tests/ut/ge/graph/passes/link_gen_mask_nodes_pass_unittest.cc

+ 10
- 0
ge/graph/passes/link_gen_mask_nodes_pass.cc View File

@@ -107,6 +107,16 @@ void LinkGenMaskNodesPass::GetAllGenMaskNodes(ComputeGraphPtr graph, vector<Node
auto in_data_nodes = node->GetInDataNodes();
if (in_data_nodes.size() > kGenMaskInputIndex) {
NodePtr &gen_mask = in_data_nodes.at(kGenMaskInputIndex);
for (auto &in_data_node : in_data_nodes) {
// node gen_mask is located at different place in the fused node
if (in_data_node->GetName().find(DROPOUTGENMASK) != in_data_node->GetName().npos) {
gen_mask = in_data_node;
GELOGD("The fused node type [%s], paired with the input node name [%s].",
node->GetType().c_str(), gen_mask->GetName().c_str());
break;
}
}

if ((gen_mask->GetOpDesc() == nullptr) || (gen_mask->GetOpDesc()->HasAttr(ATTR_NAME_STREAM_LABEL))) {
continue;
}


+ 2
- 2
tests/ut/ge/graph/passes/link_gen_mask_nodes_pass_unittest.cc View File

@@ -43,7 +43,7 @@ ut::GraphBuilder Graph1Builder() {
ut::GraphBuilder builder = ut::GraphBuilder("g1");
auto const1 = builder.AddNode("const1", "Const", 0, 1);
auto const2 = builder.AddNode("const2", "Const", 0, 1);
auto gen_mask1 = builder.AddNode("gen_mask1", "DropOutGenMask", 2, 1);
auto gen_mask1 = builder.AddNode("gen_mask1_DropOutGenMask", "DropOutGenMask", 2, 1);
auto gen_mask2 = builder.AddNode("gen_mask2", "DropOutGenMaskV3", 2, 1);
auto gen_mask3 = builder.AddNode("gen_mask3", "DropOutGenMaskV3D", 2, 1);
auto do_mask1 = builder.AddNode("do_mask1", "DropOutDoMask", 3, 1);
@@ -106,6 +106,6 @@ TEST_F(UtestLinkGenMaskNodesPass, link_gen_mask_nodes_pass_success) {
auto out_ctrl_nodes = gen_mask2->GetOutControlNodes();
EXPECT_EQ(out_ctrl_nodes.size(), 1);
auto out_ctrl_node = out_ctrl_nodes.at(0);
EXPECT_EQ(out_ctrl_node->GetName(), "gen_mask1");
EXPECT_EQ(out_ctrl_node->GetName(), "gen_mask1_DropOutGenMask");
}
} // namespace ge

Loading…
Cancel
Save