Browse Source

bugfix for taskdef's random variation in offline case

tags/v1.5.1
gengchao4@huawei.com 3 years ago
parent
commit
3dc9881cd6
1 changed files with 76 additions and 1 deletions
  1. +76
    -1
      tests/ut/ge/graph/build/task_generator_unittest.cc

+ 76
- 1
tests/ut/ge/graph/build/task_generator_unittest.cc View File

@@ -29,6 +29,8 @@


#define protected public #define protected public
#define private public #define private public
#include "init/gelib.h"
#include "ge/opskernel_manager/ops_kernel_builder_manager.h"
#include "graph/build/task_generator.h" #include "graph/build/task_generator.h"
#include "graph/manager/graph_mem_manager.h" #include "graph/manager/graph_mem_manager.h"
#include "graph/manager/graph_var_manager.h" #include "graph/manager/graph_var_manager.h"
@@ -41,9 +43,46 @@ using namespace ge;
namespace { namespace {
const char *const kIsInputVar = "INPUT_IS_VAR"; const char *const kIsInputVar = "INPUT_IS_VAR";
const char *const kIsOutputVar = "OUTPUT_IS_VAR"; const char *const kIsOutputVar = "OUTPUT_IS_VAR";
}
const char *const kKernelInfoNameHccl = "ops_kernel_info_hccl";
} // namespace
class UtestTaskGeneratorTest : public testing::Test { class UtestTaskGeneratorTest : public testing::Test {
public: public:
struct FakeOpsKernelBuilder : OpsKernelBuilder {
FakeOpsKernelBuilder(){};

private:
Status Initialize(const map<std::string, std::string> &options) override {
return SUCCESS;
};
Status Finalize() override {
return SUCCESS;
};
Status CalcOpRunningParam(Node &node) override {
return SUCCESS;
};
Status GenerateTask(const Node &node, RunContext &context, std::vector<domi::TaskDef> &tasks) override {
domi::TaskDef task_def;
tasks.push_back(task_def);
return SUCCESS;
};
};

struct FakeOpsKernelInfoStore : OpsKernelInfoStore {
FakeOpsKernelInfoStore() = default;

private:
Status Initialize(const std::map<std::string, std::string> &options) override {
return SUCCESS;
};
Status Finalize() override {
return SUCCESS;
};
bool CheckSupported(const OpDescPtr &op_desc, std::string &reason) const override {
return true;
};
void GetAllOpsKernelInfo(std::map<std::string, ge::OpInfo> &infos) const override{};
};

ge::ComputeGraphPtr BuildGraphFpProfiling() { ge::ComputeGraphPtr BuildGraphFpProfiling() {
ge::ut::GraphBuilder builder("graph"); ge::ut::GraphBuilder builder("graph");
auto data = builder.AddNode("data", "phony", 1, 1); auto data = builder.AddNode("data", "phony", 1, 1);
@@ -95,6 +134,14 @@ class UtestTaskGeneratorTest : public testing::Test {


return builder.GetGraph(); return builder.GetGraph();
} }
ge::ComputeGraphPtr BuildHcclGraph() {
ge::ut::GraphBuilder builder("graph");
auto hccl_node = builder.AddNode("hccl_phony_node", "HCCL_PHONY", 0, 0);
auto op_desc = hccl_node->GetOpDesc();
op_desc->SetOpKernelLibName(kKernelInfoNameHccl);
op_desc->SetStreamId(0);
return builder.GetGraph();
}


protected: protected:
void SetUp() {} void SetUp() {}
@@ -156,3 +203,31 @@ TEST_F(UtestTaskGeneratorTest, AutoFindBpOpIndex) {
output_desc->SetName("hcom"); output_desc->SetName("hcom");
EXPECT_EQ(task_generator.AutoFindBpOpIndex(graph, profiling_point, all_reduce_nodes), SUCCESS); EXPECT_EQ(task_generator.AutoFindBpOpIndex(graph, profiling_point, all_reduce_nodes), SUCCESS);
} }

TEST_F(UtestTaskGeneratorTest, GenerateTask) {
map<string, string> options;
Status ret = ge::GELib::Initialize(options);
EXPECT_EQ(ret, SUCCESS);

shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance();
EXPECT_NE(instance_ptr, nullptr);

OpsKernelInfoStorePtr ops_kernel_info_store_ptr = MakeShared<FakeOpsKernelInfoStore>();
instance_ptr->opsManager_.ops_kernel_store_.insert(make_pair(kKernelInfoNameHccl, ops_kernel_info_store_ptr));

OpsKernelBuilderManager &builder_manager_instance_ptr = ge::OpsKernelBuilderManager::Instance();
OpsKernelBuilderPtr fake_builder = MakeShared<FakeOpsKernelBuilder>();
builder_manager_instance_ptr.ops_kernel_builders_[kKernelInfoNameHccl] = fake_builder;

auto graph = BuildHcclGraph();
TaskGenerator task_generator(nullptr, 0);
RunContext run_context;
run_context.graphStreamList.push_back(static_cast<void *>(ops_kernel_info_store_ptr.get()));
vector<uint32_t> all_reduce_nodes;
vector<domi::TaskDef> task_def_list;
map<uint32_t, string> op_name_map;

EXPECT_EQ(task_generator.GenerateTask(run_context, graph, task_def_list, op_name_map), SUCCESS);
EXPECT_EQ(task_def_list.size(), 1);
EXPECT_EQ(task_def_list[0].ops_kernel_store_ptr(), reinterpret_cast<uintptr_t>(ops_kernel_info_store_ptr.get()));
}

Loading…
Cancel
Save