diff --git a/tests/ut/ge/hybrid/known_node_executor_unittest.cc b/tests/ut/ge/hybrid/known_node_executor_unittest.cc index 98e985f7..7d32f712 100644 --- a/tests/ut/ge/hybrid/known_node_executor_unittest.cc +++ b/tests/ut/ge/hybrid/known_node_executor_unittest.cc @@ -88,4 +88,33 @@ TEST_F(UnknownNodeExecutorTest, TestParseAttrForAllocatingOutputs) { ASSERT_EQ(node_item.ref_outputs[1], const_node); ASSERT_EQ(node_item.reuse_inputs.size(), 1); ASSERT_EQ(node_item.reuse_inputs[0], 0); +} + +TEST_F(UnknownNodeExecutorTest, TestSetGlobalStepInLoadTask) { +OpDescPtr op_desc = CreateOpDesc("PartitionedCall", "PartitionedCall"); +auto node = root_graph->AddNode(op_desc); +node->SetOwnerComputeGraph(root_graph); +auto sub_graph = BuildDataDirectConnectGraph(); +sub_graph->SetParentGraph(root_graph); +sub_graph->SetParentNode(node); +node->GetOpDesc()->AddSubgraphName("subgraph"); +node->GetOpDesc()->SetSubgraphInstanceName(0, "subgraph"); +root_graph->AddSubgraph("subgraph", sub_graph); + +ComputeGraphPtr root_graph = std::make_shared("root_graph"); +GeRootModelPtr ge_root_model = make_shared(root_graph); +HybridModel hybrid_model(ge_root_model); +hybrid_model.model_id_ = 0; +hybrid_model.model_name_ = "root_model"; +hybrid_model.om_name_ = "temp_om"; +auto *step_id = new int64_t[1]; +step_id[0] = 520; +TensorValue tensor_value((void *)step_id, sizeof(step_id)); +hybrid_model.variable_tensors_.insert({"ge_global_step", make_unique(tensor_value)}); + +KnownNodeExecutor known_node_executor; +shared_ptr task = nullptr; +Status ret = known_node_executor.LoadTask(hybrid_model, node, task); +EXPECT_EQ(*(task->davinci_model_->global_step_addr_), 520); +EXPECT_EQ(known_node_executor.LoadTask(hybrid_model, node, task), SUCCESS); } \ No newline at end of file