|
|
@@ -43,7 +43,7 @@ ut::GraphBuilder Graph1Builder() { |
|
|
|
ut::GraphBuilder builder = ut::GraphBuilder("g1"); |
|
|
|
auto const1 = builder.AddNode("const1", "Const", 0, 1); |
|
|
|
auto const2 = builder.AddNode("const2", "Const", 0, 1); |
|
|
|
auto gen_mask1 = builder.AddNode("gen_mask1", "DropOutGenMask", 2, 1); |
|
|
|
auto gen_mask1 = builder.AddNode("gen_mask1_DropOutGenMask", "DropOutGenMask", 2, 1); |
|
|
|
auto gen_mask2 = builder.AddNode("gen_mask2", "DropOutGenMaskV3", 2, 1); |
|
|
|
auto gen_mask3 = builder.AddNode("gen_mask3", "DropOutGenMaskV3D", 2, 1); |
|
|
|
auto do_mask1 = builder.AddNode("do_mask1", "DropOutDoMask", 3, 1); |
|
|
@@ -106,6 +106,6 @@ TEST_F(UtestLinkGenMaskNodesPass, link_gen_mask_nodes_pass_success) { |
|
|
|
auto out_ctrl_nodes = gen_mask2->GetOutControlNodes(); |
|
|
|
EXPECT_EQ(out_ctrl_nodes.size(), 1); |
|
|
|
auto out_ctrl_node = out_ctrl_nodes.at(0); |
|
|
|
EXPECT_EQ(out_ctrl_node->GetName(), "gen_mask1"); |
|
|
|
EXPECT_EQ(out_ctrl_node->GetName(), "gen_mask1_DropOutGenMask"); |
|
|
|
} |
|
|
|
} // namespace ge |