diff --git a/tests/ut/ge/graph/build/task_generator_unittest.cc b/tests/ut/ge/graph/build/task_generator_unittest.cc index 1e865050..7be20fa1 100644 --- a/tests/ut/ge/graph/build/task_generator_unittest.cc +++ b/tests/ut/ge/graph/build/task_generator_unittest.cc @@ -29,6 +29,8 @@ #define protected public #define private public +#include "init/gelib.h" +#include "ge/opskernel_manager/ops_kernel_builder_manager.h" #include "graph/build/task_generator.h" #include "graph/manager/graph_mem_manager.h" #include "graph/manager/graph_var_manager.h" @@ -41,9 +43,46 @@ using namespace ge; namespace { const char *const kIsInputVar = "INPUT_IS_VAR"; const char *const kIsOutputVar = "OUTPUT_IS_VAR"; -} +const char *const kKernelInfoNameHccl = "ops_kernel_info_hccl"; +} // namespace class UtestTaskGeneratorTest : public testing::Test { public: + struct FakeOpsKernelBuilder : OpsKernelBuilder { + FakeOpsKernelBuilder(){}; + + private: + Status Initialize(const map &options) override { + return SUCCESS; + }; + Status Finalize() override { + return SUCCESS; + }; + Status CalcOpRunningParam(Node &node) override { + return SUCCESS; + }; + Status GenerateTask(const Node &node, RunContext &context, std::vector &tasks) override { + domi::TaskDef task_def; + tasks.push_back(task_def); + return SUCCESS; + }; + }; + + struct FakeOpsKernelInfoStore : OpsKernelInfoStore { + FakeOpsKernelInfoStore() = default; + + private: + Status Initialize(const std::map &options) override { + return SUCCESS; + }; + Status Finalize() override { + return SUCCESS; + }; + bool CheckSupported(const OpDescPtr &op_desc, std::string &reason) const override { + return true; + }; + void GetAllOpsKernelInfo(std::map &infos) const override{}; + }; + ge::ComputeGraphPtr BuildGraphFpProfiling() { ge::ut::GraphBuilder builder("graph"); auto data = builder.AddNode("data", "phony", 1, 1); @@ -95,6 +134,14 @@ class UtestTaskGeneratorTest : public testing::Test { return builder.GetGraph(); } + ge::ComputeGraphPtr BuildHcclGraph() { + ge::ut::GraphBuilder builder("graph"); + auto hccl_node = builder.AddNode("hccl_phony_node", "HCCL_PHONY", 0, 0); + auto op_desc = hccl_node->GetOpDesc(); + op_desc->SetOpKernelLibName(kKernelInfoNameHccl); + op_desc->SetStreamId(0); + return builder.GetGraph(); + } protected: void SetUp() {} @@ -156,3 +203,31 @@ TEST_F(UtestTaskGeneratorTest, AutoFindBpOpIndex) { output_desc->SetName("hcom"); EXPECT_EQ(task_generator.AutoFindBpOpIndex(graph, profiling_point, all_reduce_nodes), SUCCESS); } + +TEST_F(UtestTaskGeneratorTest, GenerateTask) { + map options; + Status ret = ge::GELib::Initialize(options); + EXPECT_EQ(ret, SUCCESS); + + shared_ptr instance_ptr = ge::GELib::GetInstance(); + EXPECT_NE(instance_ptr, nullptr); + + OpsKernelInfoStorePtr ops_kernel_info_store_ptr = MakeShared(); + instance_ptr->opsManager_.ops_kernel_store_.insert(make_pair(kKernelInfoNameHccl, ops_kernel_info_store_ptr)); + + OpsKernelBuilderManager &builder_manager_instance_ptr = ge::OpsKernelBuilderManager::Instance(); + OpsKernelBuilderPtr fake_builder = MakeShared(); + builder_manager_instance_ptr.ops_kernel_builders_[kKernelInfoNameHccl] = fake_builder; + + auto graph = BuildHcclGraph(); + TaskGenerator task_generator(nullptr, 0); + RunContext run_context; + run_context.graphStreamList.push_back(static_cast(ops_kernel_info_store_ptr.get())); + vector all_reduce_nodes; + vector task_def_list; + map op_name_map; + + EXPECT_EQ(task_generator.GenerateTask(run_context, graph, task_def_list, op_name_map), SUCCESS); + EXPECT_EQ(task_def_list.size(), 1); + EXPECT_EQ(task_def_list[0].ops_kernel_store_ptr(), reinterpret_cast(ops_kernel_info_store_ptr.get())); +} \ No newline at end of file