
!1353 Adding dependencies by parallel groups

From: @xchu42
Reviewed-by: @ji_chen, @wqtshg
Signed-off-by: @ji_chen
Tag: v1.3.0
Committed-by: mindspore-ci-bot
Commit: 689dab39c7
9 changed files with 337 additions and 90 deletions
  1. ge/hybrid/model/hybrid_model_builder.cc (+159 / -63)
  2. ge/hybrid/model/hybrid_model_builder.h (+7 / -2)
  3. ge/hybrid/model/node_item.cc (+4 / -0)
  4. ge/hybrid/model/node_item.h (+2 / -0)
  5. ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc (+23 / -20)
  6. ge/hybrid/node_executor/compiledsubgraph/known_node_executor.h (+5 / -3)
  7. tests/ut/ge/CMakeLists.txt (+1 / -0)
  8. tests/ut/ge/hybrid/ge_hybrid_unittest.cc (+74 / -2)
  9. tests/ut/ge/hybrid/known_node_executor_unittest.cc (+62 / -0)

ge/hybrid/model/hybrid_model_builder.cc (+159 / -63)

@@ -255,9 +255,7 @@ Status HybridModelBuilder::GetOrCreateNodeItem(const NodePtr &node, NodeItem **n
(void) AttrUtils::SetBool(new_node->op_desc, kIsFirstNode, false);
(void) AttrUtils::SetBool(new_node->op_desc, kIsLastNode, false);

new_node->node_id = node_index;
new_node->op_desc->SetId(node_index);
node_index += 1;
new_node->node_id = static_cast<int>(new_node->op_desc->GetId());
NodeExecutorManager::ExecutorType executor_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*node);
new_node->is_profiling_report = (executor_type == NodeExecutorManager::ExecutorType::AICORE) ||
(executor_type == NodeExecutorManager::ExecutorType::AICPU_TF) ||
@@ -279,10 +277,10 @@ Status HybridModelBuilder::ParseForceInfershapeNodes(const NodePtr &node, NodeIt
}

Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const std::vector<string> &dependencies) {
std::set<NodePtr> dependent_input_nodes;
std::set<NodePtr> dependent_for_shape_inference;
std::set<NodePtr> dependent_for_execution;
auto &ge_node = node_item.node;
bool is_hccl_op =
NodeExecutorManager::GetInstance().ResolveExecutorType(*ge_node) == NodeExecutorManager::ExecutorType::HCCL;
bool is_hccl_op = node_item.IsHcclOp();

// The input tensors become valid after computation is done for parent nodes of type DEPEND_COMPUTE.
// Wait for these parent nodes before execution.
@@ -297,29 +295,15 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s
auto src_node_item = MutableNodeItem(src_node);
GE_CHECK_NOTNULL(src_node_item);

if (is_hccl_op) {
GELOGD("[%s] Add input data dependent node [%s] due to engine type is HCCL",
node_item.NodeName().c_str(),
src_node_item->NodeName().c_str());
src_node_item->has_observer = true;
node_item.dependents_for_execution.emplace_back(src_node);
node_item.has_observer = true;
for (auto &dst_node : ge_node->GetOutNodes()) {
if (dst_node == nullptr) {
continue;
}

NodeItem *dst_node_item = nullptr;
GE_CHK_STATUS_RET_NOLOG(GetOrCreateNodeItem(dst_node, &dst_node_item));
dst_node_item->dependents_for_execution.emplace_back(ge_node);
}
} else if (src_node_item->shape_inference_type == DEPEND_COMPUTE) {
GELOGD("[%s] Add input data dependent node [%s] due to inference type = DEPEND_COMPUTE",
node_item.NodeName().c_str(),
src_node_item->NodeName().c_str());

if (src_node_item->shape_inference_type == DEPEND_COMPUTE || is_hccl_op || src_node_item->IsHcclOp()) {
GELOGD("[%s](%s) Add input data dependent node [%s](%s), shape inference type = %d",
ge_node->GetName().c_str(),
ge_node->GetType().c_str(),
src_node->GetName().c_str(),
src_node->GetType().c_str(),
static_cast<int>(src_node_item->shape_inference_type));
src_node_item->has_observer = true;
node_item.dependents_for_execution.emplace_back(src_node);
dependent_for_execution.emplace(src_node);
}

if (src_node_item->shape_inference_type == DEPEND_SHAPE_RANGE) {
@@ -327,22 +311,17 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s
node_item.NodeName().c_str(),
src_node_item->NodeName().c_str());
src_node_item->has_observer = true;
dependent_input_nodes.emplace(src_node);
dependent_for_shape_inference.emplace(src_node);
}
}

// cond or branch need to be prepared before the execution of IF or CASE
if (node_item.node_type == IF || node_item.node_type == STATELESSIF || node_item.node_type == CASE) {
const auto &in_anchor = ge_node->GetInDataAnchor(0);
GE_CHECK_NOTNULL(in_anchor);
const auto &peer_anchor = in_anchor->GetPeerOutAnchor();
GE_CHECK_NOTNULL(peer_anchor);
auto src_node = peer_anchor->GetOwnerNode();
auto src_node = NodeUtils::GetInDataNodeByIndex(*ge_node, 0); // cond input
GE_CHECK_NOTNULL(src_node);
auto src_node_item = MutableNodeItem(src_node);
GE_CHECK_NOTNULL(src_node_item);
src_node_item->has_observer = true;
node_item.dependents_for_execution.emplace_back(src_node);
dependent_for_execution.emplace(src_node);
GELOGD("[%s] Dependent added from %s for control op's cond/branch",
node_item.NodeName().c_str(),
src_node_item->NodeName().c_str());
@@ -366,24 +345,32 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s
GE_CHECK_NOTNULL(src_node);
auto src_node_item = MutableNodeItem(src_node);
src_node_item->to_const_output_id_list.emplace(peer_out_anchor->GetIdx());
src_node_item->has_observer = true;

dependent_input_nodes.emplace(src_node);
dependent_for_shape_inference.emplace(src_node);
GELOGD("[%s] Dependent added from output of [%s:%d]",
node_item.NodeName().c_str(),
src_node_item->NodeName().c_str(),
peer_out_anchor->GetIdx());
}

for (const auto &dep_node : dependent_input_nodes) {
GE_CHK_STATUS_RET(ParseDependentForFusedSubgraph(node_item, dependent_for_shape_inference));
for (const auto &dep_node : dependent_for_shape_inference) {
auto src_node_item = MutableNodeItem(dep_node);
GE_CHECK_NOTNULL(src_node_item);
src_node_item->has_observer = true;
node_item.dependents_for_shape_inference.emplace_back(dep_node);
}

GE_CHK_STATUS_RET(ParseDependentForFusedSubgraph(node_item));
for (const auto &dep_node : dependent_for_execution) {
auto src_node_item = MutableNodeItem(dep_node);
GE_CHECK_NOTNULL(src_node_item);
src_node_item->has_observer = true;
node_item.dependents_for_execution.emplace_back(dep_node);
}

return SUCCESS;
}

Status HybridModelBuilder::ParseDependentForFusedSubgraph(NodeItem &node_item) {
Status HybridModelBuilder::ParseDependentForFusedSubgraph(NodeItem &node_item, std::set<ge::NodePtr> &dependencies) {
if (node_item.fused_subgraph == nullptr) {
return SUCCESS;
}
@@ -413,17 +400,12 @@ Status HybridModelBuilder::ParseDependentForFusedSubgraph(NodeItem &node_item) {
node_item.NodeName().c_str(),
op_desc->GetName().c_str(),
src_node_item->NodeName().c_str());
src_node_item->has_observer = true;
src_node_item->to_const_output_id_list.emplace(peer_out_anchor->GetIdx());

auto &depends = node_item.dependents_for_shape_inference;
if (std::find(depends.begin(), depends.end(), src_node) == depends.end()) {
depends.emplace_back(src_node);
GELOGD("[%s] Dependent added from output of [%s:%d]",
node_item.NodeName().c_str(),
src_node_item->NodeName().c_str(),
peer_out_anchor->GetIdx());
}
dependencies.emplace(src_node);
GELOGD("[%s] Dependent added from output of [%s:%d]",
node_item.NodeName().c_str(),
src_node_item->NodeName().c_str(),
peer_out_anchor->GetIdx());
}

return SUCCESS;
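
Taken together, these hunks change the dependency bookkeeping: instead of pushing into node_item.dependents_for_* at each call site (and de-duplicating with std::find), dependencies are first gathered into two std::set<NodePtr> buckets, one for shape inference and one for execution, and converted into the ordered dependent lists in a single place, where has_observer is also set exactly once per unique node. A minimal standalone sketch of that collect-then-resolve pattern (Node/NodePtr here are simplified stand-ins, not GE's types):

    #include <memory>
    #include <set>
    #include <vector>

    struct Node { bool has_observer = false; };
    using NodePtr = std::shared_ptr<Node>;

    struct Deps {
      std::vector<NodePtr> for_shape_inference;
      std::vector<NodePtr> for_execution;
    };

    // Sets de-duplicate, so a node depended on for several reasons is
    // recorded once and its observer flag is set in exactly one place.
    Deps Resolve(const std::set<NodePtr> &shape_deps, const std::set<NodePtr> &exec_deps) {
      Deps out;
      for (const auto &dep : shape_deps) {
        dep->has_observer = true;
        out.for_shape_inference.emplace_back(dep);
      }
      for (const auto &dep : exec_deps) {
        dep->has_observer = true;
        out.for_execution.emplace_back(dep);
      }
      return out;
    }
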
@@ -770,9 +752,23 @@ Status HybridModelBuilder::LoadGraph() {
GELOGI("After merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu",
root_graph->GetDirectNodesSize(),
root_graph->GetAllNodesSize());
GE_DUMP(root_graph, "hybrid_merged_graph");
}

root_graph_ = root_graph;
// Reset node id by topological order across all subgraphs
int64_t index = 0;
for (const auto &node : root_graph->GetAllNodes()) {
GE_CHECK_NOTNULL(node);
auto parent_graph = node->GetOwnerComputeGraph();
// No need to update nodes in known subgraph
if (parent_graph != nullptr && !parent_graph->GetGraphUnknownFlag()) {
continue;
}
auto op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
op_desc->SetId(index++);
}
GE_DUMP(root_graph, "hybrid_merged_graph");
GE_CHK_STATUS_RET(LoadDynamicSubgraph(*root_graph, true), "Failed to load root graph.");
GELOGD("Done loading root graph successfully.");
GE_CHK_STATUS_RET(hybrid_model_.root_graph_item_->GroupNodes(), "Failed to group nodes for root graph");
@@ -810,6 +806,7 @@ Status HybridModelBuilder::LoadGraph() {
}
}

GE_CHK_STATUS_RET(ParseDependentByParallelGroup(), "Failed to establish dependencies for hccl ops");
GELOGI("Done loading all subgraphs successfully.");
return SUCCESS;
}
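
Two things happen above: node ids are rewritten in GetAllNodes() order, skipping nodes whose owning subgraph is known-shaped (their offline-built models keep their own ids), and a new ParseDependentByParallelGroup() pass runs after all subgraphs are loaded. The renumbering matters because that pass compares node_id values to find the nearest topological predecessor, so ids must be globally ordered across subgraph boundaries. A condensed sketch of the renumbering step, using a placeholder NodeLike type rather than GE's Node:

    #include <cstdint>
    #include <vector>

    struct NodeLike {
      bool in_known_shaped_subgraph;  // owner graph's unknown-shape flag is false
      int64_t id;
    };

    // `nodes` is assumed to already be in topological order, as
    // ComputeGraph::GetAllNodes() is after the merge step above.
    void ResetIds(std::vector<NodeLike> &nodes) {
      int64_t index = 0;
      for (auto &node : nodes) {
        if (node.in_known_shaped_subgraph) {
          continue;  // executed as a whole; leave its internal ids alone
        }
        node.id = index++;  // ids now reflect global topological order
      }
    }
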
@@ -1075,25 +1072,38 @@ Status HybridModelBuilder::InitWeights() {
return SUCCESS;
}

Status HybridModelBuilder::LoadTask(NodeItem &node_item) {
auto &node_ptr = node_item.node;
GELOGD("[%s] Start to build kernel task", node_ptr->GetName().c_str());
auto load_ret = node_item.node_executor->LoadTask(hybrid_model_,
node_ptr,
node_item.kernel_task);
if (load_ret != UNSUPPORTED && load_ret != SUCCESS) {
GELOGE(load_ret, "[%s] Failed to load task", node_ptr->GetName().c_str());
return load_ret;
}

GELOGD("[%s] Done loading task successfully.", node_ptr->GetName().c_str());
return SUCCESS;
}

Status HybridModelBuilder::LoadTasks() {
GE_CHK_STATUS_RET(CheckAicpuOpList(), "Check Aicpu op failed.");
std::map<int64_t, NodeItem *> ordered_partitioned_calls;
for (auto &it : hybrid_model_.node_items_) {
auto &node_item = it.second;
auto &node_ptr = node_item->node;
if (node_item->node_type == NETOUTPUT) {
continue;
}

GELOGD("[%s] Start to build kernel task", node_ptr->GetName().c_str());
auto load_ret = node_item->node_executor->LoadTask(hybrid_model_,
node_ptr,
node_item->kernel_task);
if (load_ret != UNSUPPORTED && load_ret != SUCCESS) {
GELOGE(load_ret, "[%s] Failed to load task", node_ptr->GetName().c_str());
return load_ret;
if (node_item->node_type == PARTITIONEDCALL) {
ordered_partitioned_calls.emplace(node_item->node_id, node_item.get());
}
GE_CHK_STATUS_RET_NOLOG(LoadTask(*node_item));
}

GELOGD("[%s] Done loading task successfully.", node_ptr->GetName().c_str());
// HCCL operators need to be loaded in the same order across different processes
for (auto &it : ordered_partitioned_calls) {
GE_CHK_STATUS_RET_NOLOG(LoadTask(*it.second));
}

return SUCCESS;
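
The factored-out LoadTask() lets LoadTasks() stage PARTITIONEDCALL nodes in a std::map keyed by node_id and load them in a second pass. Because std::map iterates its keys in ascending order, every process loads its HCCL-bearing known-shaped subgraphs in the same topological sequence, which is what collective-communication setup requires. A tiny demonstration of the ordering guarantee the code relies on (names hypothetical):

    #include <cstdint>
    #include <cstdio>
    #include <map>
    #include <string>

    int main() {
      // std::map orders its keys, so iteration visits entries by ascending
      // node_id regardless of insertion order.
      std::map<int64_t, std::string> ordered_partitioned_calls;
      ordered_partitioned_calls.emplace(7, "PartitionedCall_hccl_b");
      ordered_partitioned_calls.emplace(3, "PartitionedCall_hccl_a");
      for (const auto &it : ordered_partitioned_calls) {
        // Prints id 3 then id 7 -- identical on every participating process.
        std::printf("load %s (id %lld)\n", it.second.c_str(),
                    static_cast<long long>(it.first));
      }
      return 0;
    }
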
@@ -1905,6 +1915,7 @@ Status HybridModelBuilder::LoadDynamicSubgraph(ComputeGraph &graph, bool is_root
NodeItem *node_item = nullptr;
GE_CHK_STATUS_RET_NOLOG(GetOrCreateNodeItem(node, &node_item));
GE_CHK_STATUS_RET_NOLOG(BuildNodeItem(node, *node_item));
GE_CHK_STATUS_RET_NOLOG(CollectParallelGroups(node_item));
GE_CHK_STATUS_RET_NOLOG(UpdateAnchorStatus(node)); // needed by FE generate task

node_item->input_start = input_start;
@@ -2011,5 +2022,90 @@ Status HybridModelBuilder::CheckAicpuOpList() {
"Launch check aicpu op type failed.");
return SUCCESS;
}

Status HybridModelBuilder::CollectParallelGroups(NodeItem *node_item) {
const auto &node = node_item->node;
auto executor_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*node);
if (executor_type == NodeExecutorManager::ExecutorType::HCCL) {
std::string parallel_group;
if (AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, parallel_group)) {
GELOGD("[%s] Got parallel group = %s", node_item->NodeName().c_str(), parallel_group.c_str());
parallel_group_to_nodes_[parallel_group].emplace(node_item);
std::set<std::string> group{parallel_group};
node_to_parallel_groups_[node_item].emplace(parallel_group);
}
} else if (executor_type == NodeExecutorManager::ExecutorType::COMPILED_SUBGRAPH) {
std::set<std::string> parallel_groups;
GELOGD("[%s] Parse parallel group for known-shaped subgraph", node_item->NodeName().c_str());
for (const auto &subgraph_name : node->GetOpDesc()->GetSubgraphInstanceNames()) {
GELOGD("[%s] Start to get parallel group from subgraph: %s",
node_item->NodeName().c_str(),
subgraph_name.c_str());
auto subgraph = root_graph_->GetSubgraph(subgraph_name);
GE_CHECK_NOTNULL(subgraph);
for (const auto &sub_node : subgraph->GetAllNodes()) {
std::string parallel_group;
if (AttrUtils::GetStr(sub_node->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, parallel_group)) {
GELOGD("[%s::%s] Got parallel group = %s",
subgraph_name.c_str(),
sub_node->GetName().c_str(),
parallel_group.c_str());
parallel_groups.emplace(parallel_group);
}
}
}

if (!parallel_groups.empty()) {
for (const auto &parallel_group : parallel_groups) {
parallel_group_to_nodes_[parallel_group].emplace(node_item);
GELOGD("[%s] has parallel group: %s", node_item->NodeName().c_str(), parallel_group.c_str());
}
node_to_parallel_groups_.emplace(node_item, std::move(parallel_groups));
}
}

return SUCCESS;
}
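
CollectParallelGroups builds a two-way index over the parallel-group attribute: parallel_group_to_nodes_ answers "which nodes share this group", and node_to_parallel_groups_ answers "which groups does this node touch". An HCCL op contributes the single group from its own ATTR_NAME_PARALLEL_GROUP; a known-shaped PARTITIONEDCALL contributes every group found on nodes inside its subgraphs, so one call node can appear under several groups. A toy illustration of the resulting index (node names hypothetical):

    #include <map>
    #include <set>
    #include <string>

    int main() {
      std::map<std::string, std::set<std::string>> parallel_group_to_nodes;
      std::map<std::string, std::set<std::string>> node_to_parallel_groups;

      // HCCL op: one group, read from its own attribute.
      parallel_group_to_nodes["group_1"].emplace("AllReduce_0");
      node_to_parallel_groups["AllReduce_0"].emplace("group_1");

      // Known-shaped PartitionedCall: union of groups across its subgraphs.
      for (const char *group : {"group_1", "group_2"}) {
        parallel_group_to_nodes[group].emplace("PartitionedCall_0");
        node_to_parallel_groups["PartitionedCall_0"].emplace(group);
      }
      return 0;
    }
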

Status HybridModelBuilder::ParseDependentByParallelGroup() {
for (const auto &it : node_to_parallel_groups_) {
auto node_item = it.first;
auto dst_engine_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*node_item->node);
for (const auto &parallel_group : it.second) {
auto &dependent_nodes = parallel_group_to_nodes_[parallel_group];
NodeItem *nearest_dep_node = nullptr;
int max_id = -1;
for (auto &dep_node : dependent_nodes) {
if (node_item == dep_node) {
continue;
}
auto src_engine_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*dep_node->node);
if (src_engine_type == dst_engine_type) {
continue;
}

if (dep_node->node_id < node_item->node_id && dep_node->node_id > max_id) {
nearest_dep_node = dep_node;
max_id = dep_node->node_id;
}
}

if (nearest_dep_node != nullptr) {
GELOGD("Add dependency for nodes of same parallel group[%s], src = [%s], dst = [%s]",
parallel_group.c_str(),
nearest_dep_node->NodeName().c_str(),
node_item->NodeName().c_str());
auto &deps = node_item->dependents_for_execution;
if (std::find(deps.begin(), deps.end(), nearest_dep_node->node) != deps.end()) {
GELOGD("Already has dependency, skip it");
continue;
}
nearest_dep_node->has_observer = true;
deps.emplace_back(nearest_dep_node->node);
}
}
}
return SUCCESS;
}
} // namespace hybrid
} // namespace ge
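
ParseDependentByParallelGroup then adds at most one extra execution dependency per (node, group) pair: among the group's members handled by a different executor, it picks the member with the largest node_id that is still smaller than the node's own id, i.e. the nearest topological predecessor (members on the same executor are skipped, since a single engine already serializes them on its stream). The std::find guard keeps the pass idempotent, which the repeated-parse assertions in the new unit test exercise. A self-contained sketch of the selection step, with a simplified Member type:

    #include <cstdint>
    #include <vector>

    struct Member {
      int64_t node_id;
      int engine_type;
    };

    // Nearest predecessor of `self` in the group with a different engine,
    // or nullptr when nothing precedes it (then no edge is added).
    const Member *NearestPredecessor(const std::vector<Member> &group, const Member &self) {
      const Member *nearest = nullptr;
      int64_t max_id = -1;
      for (const auto &m : group) {
        if (&m == &self || m.engine_type == self.engine_type) {
          continue;  // same node, or same executor: already stream-ordered
        }
        if (m.node_id < self.node_id && m.node_id > max_id) {
          nearest = &m;
          max_id = m.node_id;
        }
      }
      return nearest;
    }
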

ge/hybrid/model/hybrid_model_builder.h (+7 / -2)

@@ -57,14 +57,17 @@ class HybridModelBuilder {
Status ValidateParams();
Status LoadGraph();
Status LoadGeModel(ComputeGraph &graph, const GeModelPtr &ge_model);
Status LoadTask(NodeItem &node_item);
Status LoadTasks();
Status IdentifyVariableOutputs(NodeItem &node_item);
Status IdentifySameInputs(NodeItem &node_item);
Status BuildNodeItem(const NodePtr &node, NodeItem &node_item);
Status GetOrCreateNodeItem(const NodePtr &node, NodeItem **node_item);
Status ParseForceInfershapeNodes(const NodePtr &node, NodeItem &node_item);
Status CollectParallelGroups(NodeItem *node_item);
Status ParseDependentInputNodes(NodeItem &node_item, const std::vector<string> &dependencies);
Status ParseDependentForFusedSubgraph(NodeItem &node_item);
Status ParseDependentForFusedSubgraph(NodeItem &node_item, std::set<ge::NodePtr> &dependencies);
Status ParseDependentByParallelGroup();
Status IndexTaskDefs();
Status IndexTaskDefs(const ComputeGraphPtr &sub_graph, const GeModelPtr &ge_model);
Status IndexSpecialNodes();
@@ -97,12 +100,14 @@ class HybridModelBuilder {
NodeItem *MutableNodeItem(const NodePtr &node);

GeRootModelPtr ge_root_model_;
ComputeGraphPtr root_graph_;
std::map<std::string, GeModelPtr> subgraph_models_;
std::map<std::string, NodePtr> constant_op_nodes_;
std::map<std::string, std::set<NodeItem *>> parallel_group_to_nodes_;
std::map<NodeItem *, std::set<std::string>> node_to_parallel_groups_;

HybridModel &hybrid_model_;
std::map<NodePtr, std::vector<std::pair<int, NodePtr>>> node_ref_inputs_;
int node_index = 0;

RuntimeParam &runtime_param_;
VarManager *var_manager_ = nullptr;


ge/hybrid/model/node_item.cc (+4 / -0)

@@ -251,6 +251,10 @@ bool NodeItem::IsControlOp() const {
return ge::hybrid::IsControlOp(op_desc->GetType());
}

bool NodeItem::IsHcclOp() const {
return NodeExecutorManager::GetInstance().ResolveExecutorType(*node) == NodeExecutorManager::ExecutorType::HCCL;
}

std::string NodeItem::DebugString() const {
std::stringstream ss;
ss << "Node: ";


ge/hybrid/model/node_item.h (+2 / -0)

@@ -67,6 +67,8 @@ struct NodeItem {

bool IsControlOp() const;

bool IsHcclOp() const;

void SetToDynamic();

std::string DebugString() const;


ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc (+23 / -20)

@@ -95,13 +95,6 @@ Status KnownNodeTask::UpdateArgs(TaskContext &context) {
Status KnownNodeTask::Init(TaskContext &context) {
// allocate output mem
GE_CHK_STATUS_RET(context.AllocateOutputs(), "known node task allocate output failed.");

// init davinicmodel
if (!load_flag_) {
davinci_model_->InitRuntimeParams();
GE_CHK_STATUS_RET(davinci_model_->InitVariableMem(), "init variable mem failed.");
}

// allocate mem base
void *buffer = nullptr;
if (davinci_model_->TotalMemSize() != 0) {
@@ -129,23 +122,31 @@ Status KnownNodeTask::Init(TaskContext &context) {
void *global_step = context.GetExecutionContext()->global_step;
davinci_model_->SetKnownShapeGlobalStep(global_step);
}
int32_t device_id = 0;
rtError_t rt_ret = rtGetDevice(&device_id);
if (rt_ret != RT_ERROR_NONE || device_id < 0) {
GELOGE(rt_ret, "Call rtGetDevice failed, ret = 0x%X, device_id = %d.", rt_ret, device_id);
return RT_ERROR_TO_GE_STATUS(rt_ret);
}
davinci_model_->SetDeviceId(device_id);
GE_CHK_STATUS_RET(davinci_model_->Init(), "KnownNodeExecutor::InitDavinciModel failed.");
load_flag_ = true;
} else {
GE_CHK_STATUS_RET(ModelManager::GetInstance()->DestroyAicpuKernel(davinci_model_->GetSessionId(),
davinci_model_->Id(), davinci_model_->SubModelId()), "KnownNodeTask::Init destroy aicpu kernel failed.");
}
GE_CHK_STATUS_RET(ModelManager::GetInstance()->DestroyAicpuKernel(davinci_model_->GetSessionId(),
davinci_model_->Id(), davinci_model_->SubModelId()),
"KnownNodeTask::Init destroy aicpu kernel failed.");
GELOGI("[%s] KnownNodeExecutor::Init success.", context.GetNodeName());
return SUCCESS;
}

Status KnownNodeTask::InitDavinciModel() {
GELOGD("[Init][Model] start");
davinci_model_->InitRuntimeParams();
GE_CHK_STATUS_RET(davinci_model_->InitVariableMem(), "init variable mem failed");
int32_t device_id = 0;
GE_CHK_RT_RET(rtGetDevice(&device_id));
davinci_model_->SetDeviceId(static_cast<uint32_t>(device_id));
GE_CHK_STATUS_RET(DoInitDavinciModel(), "[Init][Model] Failed to init davinci model.");
GELOGD("[Init][Model] success");
return SUCCESS;
}

Status KnownNodeTask::DoInitDavinciModel() {
return davinci_model_->Init();
}

Status KnownNodeExecutor::PrepareTask(NodeTask &task, TaskContext &context) const {
GELOGD("[%s] KnownNodeExecutor::PrepareTask in.", context.GetNodeName());
RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[KnownNodeExecutorPrepareTask] Start");
@@ -182,9 +183,11 @@ Status KnownNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node

GE_CHK_STATUS_RET(davinci_model->Assign(ge_model), "KnownNodeExecutor::LoadTask davincimodel assign failed.");

task = MakeShared<KnownNodeTask>(davinci_model);
GE_CHECK_NOTNULL(task);
auto known_node_task = MakeShared<KnownNodeTask>(davinci_model);
GE_CHECK_NOTNULL(known_node_task);
GE_CHK_STATUS_RET_NOLOG(known_node_task->InitDavinciModel());
GELOGI("[%s] KnownNodeExecutor::LoadTask success.", node->GetName().c_str());
task = std::move(known_node_task);
return SUCCESS;
}
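
Two behavioral changes land here: the one-time DavinciModel setup (runtime params, variable memory, device id) moves out of Init(TaskContext&) into InitDavinciModel(), which LoadTask now calls once at model-load time rather than on first execution, and the heavyweight davinci_model_->Init() call is isolated behind the protected virtual DoInitDavinciModel(), giving the new unit test below a seam to stub it. A minimal sketch of that template-method seam (class names are illustrative):

    // Fixed steps run in the base; only the expensive step is overridable.
    class KnownTaskLike {
     public:
      virtual ~KnownTaskLike() = default;

      int InitModel() {
        // ... init runtime params, variable memory, query device id ...
        return DoInitModel();  // the only step a test needs to replace
      }

     protected:
      virtual int DoInitModel() { return 0; }  // production: davinci_model_->Init()
    };

    // Test double: inherits the fixed steps, skips real device work.
    class StubTask : public KnownTaskLike {
     protected:
      int DoInitModel() override { return 0; }
    };
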



ge/hybrid/node_executor/compiledsubgraph/known_node_executor.h (+5 / -3)

@@ -31,11 +31,15 @@ class KnownNodeTask : public NodeTask {
: davinci_model_(davinci_model)
{}

~KnownNodeTask() {}
~KnownNodeTask() = default;

Status UpdateArgs(TaskContext &context) override;
Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override;
Status Init(TaskContext &context) override;
Status InitDavinciModel();

protected:
virtual Status DoInitDavinciModel();
private:
std::shared_ptr<DavinciModel> davinci_model_ = nullptr;
bool load_flag_ = false;
@@ -47,8 +51,6 @@ class KnownNodeExecutor : public NodeExecutor {
Status PrepareTask(NodeTask &task, TaskContext &context) const;
Status ExecuteTask(NodeTask &task, TaskContext &context, const std::function<void()> &callback) const;
~KnownNodeExecutor() {}
private:
std::shared_ptr<DavinciModel> davinci_model_ = nullptr;
};
} // namespace hybrid
} // namespace ge


tests/ut/ge/CMakeLists.txt (+1 / -0)

@@ -797,6 +797,7 @@ set(PROFILING_MNG_TEST_FILES

set(HYBRID_TEST_FILES
"hybrid/ge_hybrid_unittest.cc"
"hybrid/known_node_executor_unittest.cc"
)

set(OTHERS_TEST_FILES


tests/ut/ge/hybrid/ge_hybrid_unittest.cc (+74 / -2)

@@ -19,10 +19,12 @@
#include <vector>
#include "runtime/rt.h"

#include "graph/utils/node_utils.h"
#define protected public
#define private public
#include "hybrid/model/hybrid_model_builder.h"
#include "hybrid/model/hybrid_model.h"
#include "hybrid/node_executor/node_executor.h"
#include "model/ge_model.h"
#include "model/ge_root_model.h"
#include "hybrid/node_executor/aicore/aicore_op_task.h"
@@ -51,7 +53,9 @@ class UtestGeHybrid : public testing::Test {
protected:
void SetUp() {}

void TearDown() {}
void TearDown() {
NpuMemoryAllocator::allocators_.clear();
}
};

static ge::OpDescPtr CreateOpDesc(string name = "", string type = "") {
@@ -245,7 +249,7 @@ TEST_F(UtestGeHybrid, init_weight_success) {
ASSERT_EQ(ret,PARAM_INVALID);
}

TEST_F(UtestGeHybrid, hybrid_model_executor) {
TEST_F(UtestGeHybrid, hybrid_model_executor) {
ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("abc");
GeRootModelPtr root_model = MakeShared<ge::GeRootModel>(compute_graph);
HybridModel model(root_model);
@@ -256,3 +260,71 @@ TEST_F(UtestGeHybrid, init_weight_success) {
HybridModelExecutor executor(model_ptr, device_id, stream);
executor.Init();
}

TEST_F(UtestGeHybrid, test_parse_parallel_group) {
NodeExecutorManager::GetInstance().engine_mapping_.emplace("ops_kernel_info_hccl",
NodeExecutorManager::ExecutorType::HCCL);
ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("test");
OpDescPtr op_desc = CreateOpDesc("AllReduce", "AllReduce");
op_desc->SetId(0);
ge::AttrUtils::SetStr(op_desc, ATTR_NAME_PARALLEL_GROUP, "group_1");
auto node = compute_graph->AddNode(op_desc);
std::unique_ptr<NodeItem> node_item;
NodeItem::Create(node, node_item);
node_item->node_id = 0;

op_desc->SetOpKernelLibName("ops_kernel_info_hccl");
GeRootModelPtr root_model = MakeShared<ge::GeRootModel>(compute_graph);
HybridModel model(root_model);

HybridModelBuilder builder(model);
builder.root_graph_ = compute_graph;
ASSERT_EQ(builder.CollectParallelGroups(node_item.get()), SUCCESS);

ASSERT_EQ(builder.node_to_parallel_groups_.size(), 1);
ASSERT_EQ(builder.parallel_group_to_nodes_.size(), 1);

OpDescPtr op_desc_1 = CreateOpDesc("subgraph", "PartitionedCall");
op_desc_1->AddSubgraphName("subgraph");
auto node_1 = compute_graph->AddNode(op_desc_1);

ComputeGraphPtr subgraph = MakeShared<ComputeGraph>("subgraph");
ASSERT_EQ(NodeUtils::SetSubgraph(*node_1, 0, subgraph), GRAPH_SUCCESS);

std::unique_ptr<NodeItem> node_item_1;
NodeItem::Create(node_1, node_item_1);
node_item_1->node_id = 1;

ASSERT_EQ(builder.CollectParallelGroups(node_item_1.get()), SUCCESS);
ASSERT_EQ(builder.node_to_parallel_groups_.size(), 1);
ASSERT_EQ(builder.parallel_group_to_nodes_.size(), 1);

OpDescPtr op_desc_2 = CreateOpDesc("sub_node_1", "AllReduce");
ge::AttrUtils::SetStr(op_desc_2, ATTR_NAME_PARALLEL_GROUP, "group_1");
auto node_2 = subgraph->AddNode(op_desc_2);
ASSERT_TRUE(node_2 != nullptr);

OpDescPtr op_desc_3 = CreateOpDesc("sub_node_2", "AllReduce2");
ge::AttrUtils::SetStr(op_desc_3, ATTR_NAME_PARALLEL_GROUP, "group_2");
auto node_3 = subgraph->AddNode(op_desc_3);
ASSERT_TRUE(node_3 != nullptr);

ASSERT_EQ(builder.CollectParallelGroups(node_item_1.get()), SUCCESS);
ASSERT_EQ(builder.node_to_parallel_groups_.size(), 2);
ASSERT_EQ(builder.parallel_group_to_nodes_.size(), 2);
ASSERT_EQ(builder.parallel_group_to_nodes_["group_1"].size(), 2);
ASSERT_EQ(builder.parallel_group_to_nodes_["group_2"].size(), 1);

ASSERT_FALSE(node_item->has_observer);
ASSERT_TRUE(node_item_1->dependents_for_execution.empty());
ASSERT_EQ(builder.ParseDependentByParallelGroup(), SUCCESS);
ASSERT_TRUE(node_item->has_observer);
ASSERT_EQ(node_item_1->dependents_for_execution.size(), 1);
ASSERT_EQ(node_item_1->dependents_for_execution[0], node);

// repeat parse
ASSERT_EQ(builder.ParseDependentByParallelGroup(), SUCCESS);
ASSERT_TRUE(node_item->has_observer);
ASSERT_EQ(node_item_1->dependents_for_execution.size(), 1);
ASSERT_EQ(node_item_1->dependents_for_execution[0], node);
}

tests/ut/ge/hybrid/known_node_executor_unittest.cc (+62 / -0)

@@ -0,0 +1,62 @@
/**
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <vector>
#include <memory>

#define protected public
#define private public
#include "hybrid/node_executor/compiledsubgraph/known_node_executor.h"
#undef private
#undef protected
#include "graph/manager/graph_mem_allocator.h"

using namespace std;
using namespace testing;
using namespace ge;
using namespace hybrid;

class UnknownNodeExecutorTest : public testing::Test {
protected:
void SetUp() {}
void TearDown() {}
};

namespace {
class KnownNodeTaskMock : public KnownNodeTask {
public:
KnownNodeTaskMock(std::shared_ptr<DavinciModel> davinci_model): KnownNodeTask(davinci_model) {};
~KnownNodeTaskMock() override = default;
MOCK_METHOD0(DoInitDavinciModel, Status());
};
}

TEST_F(UnknownNodeExecutorTest, test_init_davinci_model) {
auto davinci_model = std::make_shared<DavinciModel>(0, nullptr);
davinci_model->SetDeviceId(0);
davinci_model->SetKnownNode(true);

auto ge_model = make_shared<GeModel>();
AttrUtils::SetInt(ge_model, ATTR_MODEL_VAR_SIZE, 0);
AttrUtils::SetInt(ge_model, ATTR_MODEL_MEMORY_SIZE, 1024);
davinci_model->Assign(ge_model);

KnownNodeTaskMock mock(davinci_model);
EXPECT_CALL(mock, DoInitDavinciModel).WillOnce(::testing::Return(SUCCESS));
ASSERT_EQ(mock.InitDavinciModel(), SUCCESS);
}
