Browse Source

Bugfix: fix null owner graph error

tags/v1.3.0
lichun 3 years ago
parent
commit
7a3dba72af
5 changed files with 5 additions and 5 deletions
  1. +1
    -0
      ge/hybrid/model/hybrid_model.h
  2. +2
    -2
      ge/hybrid/model/hybrid_model_builder.cc
  3. +0
    -1
      ge/hybrid/model/hybrid_model_builder.h
  4. +1
    -1
      metadef
  5. +1
    -1
      tests/ut/ge/hybrid/ge_hybrid_unittest.cc

+ 1
- 0
ge/hybrid/model/hybrid_model.h View File

@@ -135,6 +135,7 @@ class HybridModel {
std::string model_name_;
GeRootModelPtr ge_root_model_;
std::map<uint32_t, NodeItem *> input_nodes_;
ComputeGraphPtr root_graph_;
std::map<std::string, NodePtr> device_variable_nodes_; //lint !e148
std::map<std::string, NodePtr> host_variable_nodes_; //lint !e148
std::map<std::string, std::unique_ptr<TensorValue>> variable_tensors_;


+ 2
- 2
ge/hybrid/model/hybrid_model_builder.cc View File

@@ -764,7 +764,7 @@ Status HybridModelBuilder::LoadGraph() {
root_graph->GetAllNodesSize());
}

root_graph_ = root_graph;
hybrid_model_.root_graph_ = root_graph;
// Reset node id by topological order across all subgraphs
int64_t index = 0;
for (const auto &node : root_graph->GetAllNodes()) {
@@ -2058,7 +2058,7 @@ Status HybridModelBuilder::CollectParallelGroups(NodeItem *node_item) {
GELOGD("[%s] Start to get parallel group from subgraph: %s",
node_item->NodeName().c_str(),
subgraph_name.c_str());
auto subgraph = root_graph_->GetSubgraph(subgraph_name);
auto subgraph = hybrid_model_.root_graph_->GetSubgraph(subgraph_name);
GE_CHECK_NOTNULL(subgraph);
for (const auto &sub_node : subgraph->GetAllNodes()) {
std::string parallel_group;


+ 0
- 1
ge/hybrid/model/hybrid_model_builder.h View File

@@ -100,7 +100,6 @@ class HybridModelBuilder {
NodeItem *MutableNodeItem(const NodePtr &node);

GeRootModelPtr ge_root_model_;
ComputeGraphPtr root_graph_;
std::map<std::string, GeModelPtr> subgraph_models_;
std::map<std::string, NodePtr> constant_op_nodes_;
std::map<std::string, std::set<NodeItem *>> parallel_group_to_nodes_;


+ 1
- 1
metadef

@@ -1 +1 @@
Subproject commit 86781b7e8ce21d2b901406cc3619d6bea2aeb18e
Subproject commit 4ff5e3987f2e5d2980019defacaf0891861c84fc

+ 1
- 1
tests/ut/ge/hybrid/ge_hybrid_unittest.cc View File

@@ -276,9 +276,9 @@ TEST_F(UtestGeHybrid, test_parse_parallel_group) {
op_desc->SetOpKernelLibName("ops_kernel_info_hccl");
GeRootModelPtr root_model = MakeShared<ge::GeRootModel>(compute_graph);
HybridModel model(root_model);
model.root_graph_ = compute_graph;

HybridModelBuilder builder(model);
builder.root_graph_ = compute_graph;
ASSERT_EQ(builder.CollectParallelGroups(node_item.get()), SUCCESS);

ASSERT_EQ(builder.node_to_parallel_groups_.size(), 1);


Loading…
Cancel
Save