Browse Source

fix gen_mask control-edges bug

tags/v1.3.0
stormchasingg 3 years ago
parent
commit
2865bcff6c
1 changed files with 2 additions and 2 deletions
  1. +2
    -2
      tests/ut/ge/graph/passes/link_gen_mask_nodes_pass_unittest.cc

+ 2
- 2
tests/ut/ge/graph/passes/link_gen_mask_nodes_pass_unittest.cc View File

@@ -43,7 +43,7 @@ ut::GraphBuilder Graph1Builder() {
ut::GraphBuilder builder = ut::GraphBuilder("g1");
auto const1 = builder.AddNode("const1", "Const", 0, 1);
auto const2 = builder.AddNode("const2", "Const", 0, 1);
auto gen_mask1 = builder.AddNode("gen_mask1", "DropOutGenMask", 2, 1);
auto gen_mask1 = builder.AddNode("gen_mask1_DropOutGenMask", "DropOutGenMask", 2, 1);
auto gen_mask2 = builder.AddNode("gen_mask2", "DropOutGenMaskV3", 2, 1);
auto gen_mask3 = builder.AddNode("gen_mask3", "DropOutGenMaskV3D", 2, 1);
auto do_mask1 = builder.AddNode("do_mask1", "DropOutDoMask", 3, 1);
@@ -106,6 +106,6 @@ TEST_F(UtestLinkGenMaskNodesPass, link_gen_mask_nodes_pass_success) {
auto out_ctrl_nodes = gen_mask2->GetOutControlNodes();
EXPECT_EQ(out_ctrl_nodes.size(), 1);
auto out_ctrl_node = out_ctrl_nodes.at(0);
EXPECT_EQ(out_ctrl_node->GetName(), "gen_mask1");
EXPECT_EQ(out_ctrl_node->GetName(), "gen_mask1_DropOutGenMask");
}
} // namespace ge

Loading…
Cancel
Save