|
|
@@ -29,6 +29,8 @@ |
|
|
|
|
|
|
|
#define protected public |
|
|
|
#define private public |
|
|
|
#include "init/gelib.h" |
|
|
|
#include "ge/opskernel_manager/ops_kernel_builder_manager.h" |
|
|
|
#include "graph/build/task_generator.h" |
|
|
|
#include "graph/manager/graph_mem_manager.h" |
|
|
|
#include "graph/manager/graph_var_manager.h" |
|
|
@@ -41,9 +43,46 @@ using namespace ge; |
|
|
|
namespace { |
|
|
|
const char *const kIsInputVar = "INPUT_IS_VAR"; |
|
|
|
const char *const kIsOutputVar = "OUTPUT_IS_VAR"; |
|
|
|
} |
|
|
|
const char *const kKernelInfoNameHccl = "ops_kernel_info_hccl"; |
|
|
|
} // namespace |
|
|
|
class UtestTaskGeneratorTest : public testing::Test { |
|
|
|
public: |
|
|
|
struct FakeOpsKernelBuilder : OpsKernelBuilder { |
|
|
|
FakeOpsKernelBuilder(){}; |
|
|
|
|
|
|
|
private: |
|
|
|
Status Initialize(const map<std::string, std::string> &options) override { |
|
|
|
return SUCCESS; |
|
|
|
}; |
|
|
|
Status Finalize() override { |
|
|
|
return SUCCESS; |
|
|
|
}; |
|
|
|
Status CalcOpRunningParam(Node &node) override { |
|
|
|
return SUCCESS; |
|
|
|
}; |
|
|
|
Status GenerateTask(const Node &node, RunContext &context, std::vector<domi::TaskDef> &tasks) override { |
|
|
|
domi::TaskDef task_def; |
|
|
|
tasks.push_back(task_def); |
|
|
|
return SUCCESS; |
|
|
|
}; |
|
|
|
}; |
|
|
|
|
|
|
|
struct FakeOpsKernelInfoStore : OpsKernelInfoStore { |
|
|
|
FakeOpsKernelInfoStore() = default; |
|
|
|
|
|
|
|
private: |
|
|
|
Status Initialize(const std::map<std::string, std::string> &options) override { |
|
|
|
return SUCCESS; |
|
|
|
}; |
|
|
|
Status Finalize() override { |
|
|
|
return SUCCESS; |
|
|
|
}; |
|
|
|
bool CheckSupported(const OpDescPtr &op_desc, std::string &reason) const override { |
|
|
|
return true; |
|
|
|
}; |
|
|
|
void GetAllOpsKernelInfo(std::map<std::string, ge::OpInfo> &infos) const override{}; |
|
|
|
}; |
|
|
|
|
|
|
|
ge::ComputeGraphPtr BuildGraphFpProfiling() { |
|
|
|
ge::ut::GraphBuilder builder("graph"); |
|
|
|
auto data = builder.AddNode("data", "phony", 1, 1); |
|
|
@@ -95,6 +134,14 @@ class UtestTaskGeneratorTest : public testing::Test { |
|
|
|
|
|
|
|
return builder.GetGraph(); |
|
|
|
} |
|
|
|
ge::ComputeGraphPtr BuildHcclGraph() { |
|
|
|
ge::ut::GraphBuilder builder("graph"); |
|
|
|
auto hccl_node = builder.AddNode("hccl_phony_node", "HCCL_PHONY", 0, 0); |
|
|
|
auto op_desc = hccl_node->GetOpDesc(); |
|
|
|
op_desc->SetOpKernelLibName(kKernelInfoNameHccl); |
|
|
|
op_desc->SetStreamId(0); |
|
|
|
return builder.GetGraph(); |
|
|
|
} |
|
|
|
|
|
|
|
protected: |
|
|
|
void SetUp() {} |
|
|
@@ -156,3 +203,31 @@ TEST_F(UtestTaskGeneratorTest, AutoFindBpOpIndex) { |
|
|
|
output_desc->SetName("hcom"); |
|
|
|
EXPECT_EQ(task_generator.AutoFindBpOpIndex(graph, profiling_point, all_reduce_nodes), SUCCESS); |
|
|
|
} |
|
|
|
|
|
|
|
TEST_F(UtestTaskGeneratorTest, GenerateTask) { |
|
|
|
map<string, string> options; |
|
|
|
Status ret = ge::GELib::Initialize(options); |
|
|
|
EXPECT_EQ(ret, SUCCESS); |
|
|
|
|
|
|
|
shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); |
|
|
|
EXPECT_NE(instance_ptr, nullptr); |
|
|
|
|
|
|
|
OpsKernelInfoStorePtr ops_kernel_info_store_ptr = MakeShared<FakeOpsKernelInfoStore>(); |
|
|
|
instance_ptr->opsManager_.ops_kernel_store_.insert(make_pair(kKernelInfoNameHccl, ops_kernel_info_store_ptr)); |
|
|
|
|
|
|
|
OpsKernelBuilderManager &builder_manager_instance_ptr = ge::OpsKernelBuilderManager::Instance(); |
|
|
|
OpsKernelBuilderPtr fake_builder = MakeShared<FakeOpsKernelBuilder>(); |
|
|
|
builder_manager_instance_ptr.ops_kernel_builders_[kKernelInfoNameHccl] = fake_builder; |
|
|
|
|
|
|
|
auto graph = BuildHcclGraph(); |
|
|
|
TaskGenerator task_generator(nullptr, 0); |
|
|
|
RunContext run_context; |
|
|
|
run_context.graphStreamList.push_back(static_cast<void *>(ops_kernel_info_store_ptr.get())); |
|
|
|
vector<uint32_t> all_reduce_nodes; |
|
|
|
vector<domi::TaskDef> task_def_list; |
|
|
|
map<uint32_t, string> op_name_map; |
|
|
|
|
|
|
|
EXPECT_EQ(task_generator.GenerateTask(run_context, graph, task_def_list, op_name_map), SUCCESS); |
|
|
|
EXPECT_EQ(task_def_list.size(), 1); |
|
|
|
EXPECT_EQ(task_def_list[0].ops_kernel_store_ptr(), reinterpret_cast<uintptr_t>(ops_kernel_info_store_ptr.get())); |
|
|
|
} |