diff --git a/tests/ut/ge/graph/passes/link_gen_mask_nodes_pass_unittest.cc b/tests/ut/ge/graph/passes/link_gen_mask_nodes_pass_unittest.cc index 511ddece..716cc91d 100644 --- a/tests/ut/ge/graph/passes/link_gen_mask_nodes_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/link_gen_mask_nodes_pass_unittest.cc @@ -43,7 +43,7 @@ ut::GraphBuilder Graph1Builder() { ut::GraphBuilder builder = ut::GraphBuilder("g1"); auto const1 = builder.AddNode("const1", "Const", 0, 1); auto const2 = builder.AddNode("const2", "Const", 0, 1); - auto gen_mask1 = builder.AddNode("gen_mask1", "DropOutGenMask", 2, 1); + auto gen_mask1 = builder.AddNode("gen_mask1_DropOutGenMask", "DropOutGenMask", 2, 1); auto gen_mask2 = builder.AddNode("gen_mask2", "DropOutGenMaskV3", 2, 1); auto gen_mask3 = builder.AddNode("gen_mask3", "DropOutGenMaskV3D", 2, 1); auto do_mask1 = builder.AddNode("do_mask1", "DropOutDoMask", 3, 1); @@ -106,6 +106,6 @@ TEST_F(UtestLinkGenMaskNodesPass, link_gen_mask_nodes_pass_success) { auto out_ctrl_nodes = gen_mask2->GetOutControlNodes(); EXPECT_EQ(out_ctrl_nodes.size(), 1); auto out_ctrl_node = out_ctrl_nodes.at(0); - EXPECT_EQ(out_ctrl_node->GetName(), "gen_mask1"); + EXPECT_EQ(out_ctrl_node->GetName(), "gen_mask1_DropOutGenMask"); } } // namespace ge