Browse Source

update UT

tags/v1.3.0
chuxing 4 years ago
parent
commit
deebe05906
3 changed files with 88 additions and 15 deletions
  1. +13
    -10
      ge/hybrid/model/hybrid_model_builder.cc
  2. +4
    -4
      ge/hybrid/model/hybrid_model_builder.h
  3. +71
    -1
      tests/ut/ge/hybrid/ge_hybrid_unittest.cc

+ 13
- 10
ge/hybrid/model/hybrid_model_builder.cc View File

@@ -806,7 +806,7 @@ Status HybridModelBuilder::LoadGraph() {
}
}

GE_CHK_STATUS_RET(ParseDependentForHcclNodes(), "Failed to establish dependencies for hccl ops");
GE_CHK_STATUS_RET(ParseDependentByParallelGroup(), "Failed to establish dependencies for hccl ops");
GELOGI("Done loading all subgraphs successfully.");
return SUCCESS;
}
@@ -1907,7 +1907,7 @@ Status HybridModelBuilder::LoadDynamicSubgraph(ComputeGraph &graph, bool is_root
NodeItem *node_item = nullptr;
GE_CHK_STATUS_RET_NOLOG(GetOrCreateNodeItem(node, &node_item));
GE_CHK_STATUS_RET_NOLOG(BuildNodeItem(node, *node_item));
GE_CHK_STATUS_RET_NOLOG(ParseParallelGroups(node_item));
GE_CHK_STATUS_RET_NOLOG(CollectParallelGroups(node_item));
GE_CHK_STATUS_RET_NOLOG(UpdateAnchorStatus(node)); // needed by FE generate task

node_item->input_start = input_start;
@@ -2015,16 +2015,16 @@ Status HybridModelBuilder::CheckAicpuOpList() {
return SUCCESS;
}

Status HybridModelBuilder::ParseParallelGroups(NodeItem *node_item) {
Status HybridModelBuilder::CollectParallelGroups(NodeItem *node_item) {
const auto &node = node_item->node;
auto executor_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*node);
if (executor_type == NodeExecutorManager::ExecutorType::HCCL) {
std::string parallel_group;
if (AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, parallel_group)) {
GELOGD("[%s] Got parallel group = %s", node_item->NodeName().c_str(), parallel_group.c_str());
group_to_nodes_[parallel_group].emplace(node_item);
parallel_group_to_nodes_[parallel_group].emplace(node_item);
std::set<std::string> group{parallel_group};
node_to_groups_[node_item].emplace(parallel_group);
node_to_parallel_groups_[node_item].emplace(parallel_group);
}
} else if (executor_type == NodeExecutorManager::ExecutorType::COMPILED_SUBGRAPH) {
std::set<std::string> parallel_groups;
@@ -2049,25 +2049,28 @@ Status HybridModelBuilder::ParseParallelGroups(NodeItem *node_item) {

if (!parallel_groups.empty()) {
for (const auto &parallel_group : parallel_groups) {
group_to_nodes_[parallel_group].emplace(node_item);
parallel_group_to_nodes_[parallel_group].emplace(node_item);
GELOGD("[%s] has parallel group: %s", node_item->NodeName().c_str(), parallel_group.c_str());
}
node_to_groups_.emplace(node_item, std::move(parallel_groups));
node_to_parallel_groups_.emplace(node_item, std::move(parallel_groups));
}
}

return SUCCESS;
}

Status HybridModelBuilder::ParseDependentForHcclNodes() {
for (const auto &it : node_to_groups_) {
Status HybridModelBuilder::ParseDependentByParallelGroup() {
for (const auto &it : node_to_parallel_groups_) {
auto node_item = it.first;
auto dst_engine_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*node_item->node);
for (const auto &parallel_group : it.second) {
auto &dependent_nodes = group_to_nodes_[parallel_group];
auto &dependent_nodes = parallel_group_to_nodes_[parallel_group];
NodeItem *nearest_dep_node = nullptr;
int max_id = -1;
for (auto &dep_node : dependent_nodes) {
if (node_item == dep_node) {
continue;
}
auto src_engine_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*dep_node->node);
if (src_engine_type == dst_engine_type) {
continue;


+ 4
- 4
ge/hybrid/model/hybrid_model_builder.h View File

@@ -63,10 +63,10 @@ class HybridModelBuilder {
Status BuildNodeItem(const NodePtr &node, NodeItem &node_item);
Status GetOrCreateNodeItem(const NodePtr &node, NodeItem **node_item);
Status ParseForceInfershapeNodes(const NodePtr &node, NodeItem &node_item);
Status ParseParallelGroups(NodeItem *node_item);
Status CollectParallelGroups(NodeItem *node_item);
Status ParseDependentInputNodes(NodeItem &node_item, const std::vector<string> &dependencies);
Status ParseDependentForFusedSubgraph(NodeItem &node_item, std::set<ge::NodePtr> &dependencies);
Status ParseDependentForHcclNodes();
Status ParseDependentByParallelGroup();
Status IndexTaskDefs();
Status IndexTaskDefs(const ComputeGraphPtr &sub_graph, const GeModelPtr &ge_model);
Status IndexSpecialNodes();
@@ -102,8 +102,8 @@ class HybridModelBuilder {
ComputeGraphPtr root_graph_;
std::map<std::string, GeModelPtr> subgraph_models_;
std::map<std::string, NodePtr> constant_op_nodes_;
std::map<std::string, std::set<NodeItem *>> group_to_nodes_;
std::map<NodeItem *, std::set<std::string>> node_to_groups_;
std::map<std::string, std::set<NodeItem *>> parallel_group_to_nodes_;
std::map<NodeItem *, std::set<std::string>> node_to_parallel_groups_;

HybridModel &hybrid_model_;
std::map<NodePtr, std::vector<std::pair<int, NodePtr>>> node_ref_inputs_;


+ 71
- 1
tests/ut/ge/hybrid/ge_hybrid_unittest.cc View File

@@ -19,10 +19,12 @@
#include <vector>
#include "runtime/rt.h"

#include "graph/utils/node_utils.h"
#define protected public
#define private public
#include "hybrid/model/hybrid_model_builder.h"
#include "hybrid/model/hybrid_model.h"
#include "hybrid/node_executor/node_executor.h"
#include "model/ge_model.h"
#include "model/ge_root_model.h"
#include "hybrid/node_executor/aicore/aicore_op_task.h"
@@ -247,7 +249,7 @@ TEST_F(UtestGeHybrid, init_weight_success) {
ASSERT_EQ(ret,PARAM_INVALID);
}

TEST_F(UtestGeHybrid, hybrid_model_executor) {
TEST_F(UtestGeHybrid, hybrid_model_executor) {
ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("abc");
GeRootModelPtr root_model = MakeShared<ge::GeRootModel>(compute_graph);
HybridModel model(root_model);
@@ -258,3 +260,71 @@ TEST_F(UtestGeHybrid, init_weight_success) {
HybridModelExecutor executor(model_ptr, device_id, stream);
executor.Init();
}

TEST_F(UtestGeHybrid, test_parse_parallel_group) {
NodeExecutorManager::GetInstance().engine_mapping_.emplace("ops_kernel_info_hccl",
NodeExecutorManager::ExecutorType::HCCL);
ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("test");
OpDescPtr op_desc = CreateOpDesc("AllReduce", "AllReduce");
op_desc->SetId(0);
ge::AttrUtils::SetStr(op_desc, ATTR_NAME_PARALLEL_GROUP, "group_1");
auto node = compute_graph->AddNode(op_desc);
std::unique_ptr<NodeItem> node_item;
NodeItem::Create(node, node_item);
node_item->node_id = 0;

op_desc->SetOpKernelLibName("ops_kernel_info_hccl");
GeRootModelPtr root_model = MakeShared<ge::GeRootModel>(compute_graph);
HybridModel model(root_model);

HybridModelBuilder builder(model);
builder.root_graph_ = compute_graph;
ASSERT_EQ(builder.CollectParallelGroups(node_item.get()), SUCCESS);

ASSERT_EQ(builder.node_to_parallel_groups_.size(), 1);
ASSERT_EQ(builder.parallel_group_to_nodes_.size(), 1);

OpDescPtr op_desc_1 = CreateOpDesc("subgraph", "PartitionedCall");
op_desc_1->AddSubgraphName("subgraph");
auto node_1 = compute_graph->AddNode(op_desc_1);

ComputeGraphPtr subgraph = MakeShared<ComputeGraph>("subgraph");
ASSERT_EQ(NodeUtils::SetSubgraph(*node_1, 0, subgraph), GRAPH_SUCCESS);

std::unique_ptr<NodeItem> node_item_1;
NodeItem::Create(node_1, node_item_1);
node_item_1->node_id = 1;

ASSERT_EQ(builder.CollectParallelGroups(node_item_1.get()), SUCCESS);
ASSERT_EQ(builder.node_to_parallel_groups_.size(), 1);
ASSERT_EQ(builder.parallel_group_to_nodes_.size(), 1);

OpDescPtr op_desc_2 = CreateOpDesc("sub_node_1", "AllReduce");
ge::AttrUtils::SetStr(op_desc_2, ATTR_NAME_PARALLEL_GROUP, "group_1");
auto node_2 = subgraph->AddNode(op_desc_2);
ASSERT_TRUE(node_2 != nullptr);

OpDescPtr op_desc_3 = CreateOpDesc("sub_node_2", "AllReduce2");
ge::AttrUtils::SetStr(op_desc_3, ATTR_NAME_PARALLEL_GROUP, "group_2");
auto node_3 = subgraph->AddNode(op_desc_3);
ASSERT_TRUE(node_3 != nullptr);

ASSERT_EQ(builder.CollectParallelGroups(node_item_1.get()), SUCCESS);
ASSERT_EQ(builder.node_to_parallel_groups_.size(), 2);
ASSERT_EQ(builder.parallel_group_to_nodes_.size(), 2);
ASSERT_EQ(builder.parallel_group_to_nodes_["group_1"].size(), 2);
ASSERT_EQ(builder.parallel_group_to_nodes_["group_2"].size(), 1);

ASSERT_FALSE(node_item->has_observer);
ASSERT_TRUE(node_item_1->dependents_for_execution.empty());
ASSERT_EQ(builder.ParseDependentByParallelGroup(), SUCCESS);
ASSERT_TRUE(node_item->has_observer);
ASSERT_EQ(node_item_1->dependents_for_execution.size(), 1);
ASSERT_EQ(node_item_1->dependents_for_execution[0], node);

// repeat parse
ASSERT_EQ(builder.ParseDependentByParallelGroup(), SUCCESS);
ASSERT_TRUE(node_item->has_observer);
ASSERT_EQ(node_item_1->dependents_for_execution.size(), 1);
ASSERT_EQ(node_item_1->dependents_for_execution[0], node);
}

Loading…
Cancel
Save