| @@ -29,6 +29,8 @@ | |||||
| #define protected public | #define protected public | ||||
| #define private public | #define private public | ||||
| #include "init/gelib.h" | |||||
| #include "ge/opskernel_manager/ops_kernel_builder_manager.h" | |||||
| #include "graph/build/task_generator.h" | #include "graph/build/task_generator.h" | ||||
| #include "graph/manager/graph_mem_manager.h" | #include "graph/manager/graph_mem_manager.h" | ||||
| #include "graph/manager/graph_var_manager.h" | #include "graph/manager/graph_var_manager.h" | ||||
| @@ -41,9 +43,46 @@ using namespace ge; | |||||
| namespace { | namespace { | ||||
| const char *const kIsInputVar = "INPUT_IS_VAR"; | const char *const kIsInputVar = "INPUT_IS_VAR"; | ||||
| const char *const kIsOutputVar = "OUTPUT_IS_VAR"; | const char *const kIsOutputVar = "OUTPUT_IS_VAR"; | ||||
| } | |||||
| const char *const kKernelInfoNameHccl = "ops_kernel_info_hccl"; | |||||
| } // namespace | |||||
| class UtestTaskGeneratorTest : public testing::Test { | class UtestTaskGeneratorTest : public testing::Test { | ||||
| public: | public: | ||||
| struct FakeOpsKernelBuilder : OpsKernelBuilder { | |||||
| FakeOpsKernelBuilder(){}; | |||||
| private: | |||||
| Status Initialize(const map<std::string, std::string> &options) override { | |||||
| return SUCCESS; | |||||
| }; | |||||
| Status Finalize() override { | |||||
| return SUCCESS; | |||||
| }; | |||||
| Status CalcOpRunningParam(Node &node) override { | |||||
| return SUCCESS; | |||||
| }; | |||||
| Status GenerateTask(const Node &node, RunContext &context, std::vector<domi::TaskDef> &tasks) override { | |||||
| domi::TaskDef task_def; | |||||
| tasks.push_back(task_def); | |||||
| return SUCCESS; | |||||
| }; | |||||
| }; | |||||
| struct FakeOpsKernelInfoStore : OpsKernelInfoStore { | |||||
| FakeOpsKernelInfoStore() = default; | |||||
| private: | |||||
| Status Initialize(const std::map<std::string, std::string> &options) override { | |||||
| return SUCCESS; | |||||
| }; | |||||
| Status Finalize() override { | |||||
| return SUCCESS; | |||||
| }; | |||||
| bool CheckSupported(const OpDescPtr &op_desc, std::string &reason) const override { | |||||
| return true; | |||||
| }; | |||||
| void GetAllOpsKernelInfo(std::map<std::string, ge::OpInfo> &infos) const override{}; | |||||
| }; | |||||
| ge::ComputeGraphPtr BuildGraphFpProfiling() { | ge::ComputeGraphPtr BuildGraphFpProfiling() { | ||||
| ge::ut::GraphBuilder builder("graph"); | ge::ut::GraphBuilder builder("graph"); | ||||
| auto data = builder.AddNode("data", "phony", 1, 1); | auto data = builder.AddNode("data", "phony", 1, 1); | ||||
| @@ -95,6 +134,14 @@ class UtestTaskGeneratorTest : public testing::Test { | |||||
| return builder.GetGraph(); | return builder.GetGraph(); | ||||
| } | } | ||||
| ge::ComputeGraphPtr BuildHcclGraph() { | |||||
| ge::ut::GraphBuilder builder("graph"); | |||||
| auto hccl_node = builder.AddNode("hccl_phony_node", "HCCL_PHONY", 0, 0); | |||||
| auto op_desc = hccl_node->GetOpDesc(); | |||||
| op_desc->SetOpKernelLibName(kKernelInfoNameHccl); | |||||
| op_desc->SetStreamId(0); | |||||
| return builder.GetGraph(); | |||||
| } | |||||
| protected: | protected: | ||||
| void SetUp() {} | void SetUp() {} | ||||
| @@ -156,3 +203,31 @@ TEST_F(UtestTaskGeneratorTest, AutoFindBpOpIndex) { | |||||
| output_desc->SetName("hcom"); | output_desc->SetName("hcom"); | ||||
| EXPECT_EQ(task_generator.AutoFindBpOpIndex(graph, profiling_point, all_reduce_nodes), SUCCESS); | EXPECT_EQ(task_generator.AutoFindBpOpIndex(graph, profiling_point, all_reduce_nodes), SUCCESS); | ||||
| } | } | ||||
| TEST_F(UtestTaskGeneratorTest, GenerateTask) { | |||||
| map<string, string> options; | |||||
| Status ret = ge::GELib::Initialize(options); | |||||
| EXPECT_EQ(ret, SUCCESS); | |||||
| shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||||
| EXPECT_NE(instance_ptr, nullptr); | |||||
| OpsKernelInfoStorePtr ops_kernel_info_store_ptr = MakeShared<FakeOpsKernelInfoStore>(); | |||||
| instance_ptr->opsManager_.ops_kernel_store_.insert(make_pair(kKernelInfoNameHccl, ops_kernel_info_store_ptr)); | |||||
| OpsKernelBuilderManager &builder_manager_instance_ptr = ge::OpsKernelBuilderManager::Instance(); | |||||
| OpsKernelBuilderPtr fake_builder = MakeShared<FakeOpsKernelBuilder>(); | |||||
| builder_manager_instance_ptr.ops_kernel_builders_[kKernelInfoNameHccl] = fake_builder; | |||||
| auto graph = BuildHcclGraph(); | |||||
| TaskGenerator task_generator(nullptr, 0); | |||||
| RunContext run_context; | |||||
| run_context.graphStreamList.push_back(static_cast<void *>(ops_kernel_info_store_ptr.get())); | |||||
| vector<uint32_t> all_reduce_nodes; | |||||
| vector<domi::TaskDef> task_def_list; | |||||
| map<uint32_t, string> op_name_map; | |||||
| EXPECT_EQ(task_generator.GenerateTask(run_context, graph, task_def_list, op_name_map), SUCCESS); | |||||
| EXPECT_EQ(task_def_list.size(), 1); | |||||
| EXPECT_EQ(task_def_list[0].ops_kernel_store_ptr(), reinterpret_cast<uintptr_t>(ops_kernel_info_store_ptr.get())); | |||||
| } | |||||