Browse Source

support unknown while subgraph

tags/v1.3.0
lichun 3 years ago
parent
commit
ba39c5375b
10 changed files with 107 additions and 28 deletions
  1. +1
    -1
      ge/common/formats/format_transfers/format_transfer_fractal_nz.cc
  2. +1
    -1
      ge/common/formats/format_transfers/format_transfer_fractal_zz.cc
  3. +2
    -1
      ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc
  4. +7
    -7
      ge/common/formats/format_transfers/format_transfer_transpose.cc
  5. +22
    -13
      ge/hybrid/model/hybrid_model_builder.cc
  6. +2
    -2
      ge/hybrid/model/hybrid_model_builder.h
  7. +1
    -1
      metadef
  8. +1
    -1
      parser
  9. +19
    -0
      tests/ut/ge/common/format_transfer_transpose_unittest.cc
  10. +51
    -1
      tests/ut/ge/hybrid/ge_hybrid_unittest.cc

+ 1
- 1
ge/common/formats/format_transfers/format_transfer_fractal_nz.cc View File

@@ -60,7 +60,7 @@ bool CheckShape(Format format, const ShapeVector &shape) {
default:
std::string error = "Trans format between " + FmtToStr(TypeUtils::FormatToSerialString(format)) +
" and FORMAT_FRACTAL_NZ is not supported.";
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str());
GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str());
return false;
}
}


+ 1
- 1
ge/common/formats/format_transfers/format_transfer_fractal_zz.cc View File

@@ -59,7 +59,7 @@ bool CheckShape(Format format, const ShapeVector &shape) {
default:
std::string error = "Trans format between " + FmtToStr(TypeUtils::FormatToSerialString(format)) +
" and FORMAT_FRACTAL_ZZ is not supported.";
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str());
GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str());
return false;
}
}


+ 2
- 1
ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc View File

@@ -92,7 +92,8 @@ Status CheckArgsForNhwcToNc1hwc0(const TransArgs &args) {
Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const int size, const int64_t total_size) {
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>());
if (dst == nullptr) {
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s",
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION,
"Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s",
TypeUtils::FormatToSerialString(args.src_format).c_str(),
TypeUtils::FormatToSerialString(args.dst_format).c_str(), total_size, ShapeToString(args.dst_shape).c_str());
return ACL_ERROR_GE_MEMORY_ALLOCATION;


+ 7
- 7
ge/common/formats/format_transfers/format_transfer_transpose.cc View File

@@ -50,21 +50,21 @@ std::map<Format, std::map<Format, std::vector<int64_t>>> perm_args{
bool IsShapeArgValid(const std::vector<int64_t> &src_shape, const std::vector<int64_t> &perm_arg) {
if (src_shape.empty()) {
std::string error = "Failed to transpose, empty src shape";
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str());
GELOGE(PARAM_INVALID, "Failed to transpose, empty src shape");
GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_SHAPE_INVALID, error.c_str());
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to transpose, empty src shape");
return false;
}
for (auto dim : src_shape) {
if (dim < 0) {
std::string error = "Failed to transpose, negative dim in src shape " + FmtToStr(ShapeToString(src_shape));
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str());
GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_SHAPE_INVALID, error.c_str());
return false;
}
}
if (perm_arg.size() != src_shape.size()) {
std::string error = "Failed to transpose, the size of src shape" + FmtToStr(src_shape.size()) +
" and perm arg" + FmtToStr(perm_arg.size()) + " are different";
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str());
GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_SHAPE_INVALID, error.c_str());
return false;
}

@@ -73,7 +73,7 @@ bool IsShapeArgValid(const std::vector<int64_t> &src_shape, const std::vector<in
if (perm < 0 || static_cast<size_t>(perm) >= perm_arg.size() || ++exists[perm] > 1) {
std::string error = "Failed to transpose, duplicated perm arg " + FmtToStr(perm) +
", perm arg " + FmtToStr(JoinToString(perm_arg));
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str());
GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_PARAM_INVALID, error.c_str());
return false;
}
}
@@ -82,11 +82,11 @@ bool IsShapeArgValid(const std::vector<int64_t> &src_shape, const std::vector<in
bool IsTransposeArgValid(const uint8_t *src, const std::vector<int64_t> &src_shape, DataType src_data_type,
const std::vector<int64_t> &perm_arg) {
if (src == nullptr) {
GELOGE(PARAM_INVALID, "Failed to transpose, the src is null");
GELOGE(ACL_ERROR_GE_PARAM_INVALID, "Failed to transpose, the src is null");
return false;
}
if (GetSizeByDataType(src_data_type) < 0) {
GELOGE(UNSUPPORTED, "Failed to transpose, the data type %s is not support",
GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "Failed to transpose, the data type %s is not support",
TypeUtils::DataTypeToSerialString(src_data_type).c_str());
return false;
}


+ 22
- 13
ge/hybrid/model/hybrid_model_builder.cc View File

@@ -136,12 +136,12 @@ Status HybridModelBuilder::Build() {
GE_CHK_STATUS_RET(RecoverGraphUnknownFlag(), "[%s] Failed to RecoverGraphUnknownFlag", GetGraphName());
GE_CHK_STATUS_RET(IndexSpecialNodes(), "[%s] Failed to index nodes", GetGraphName());
GE_CHK_STATUS_RET(IndexTaskDefs(), "[%s] Failed to index task defs", GetGraphName());
GE_CHK_STATUS_RET(InitWeights(), "[%s] Failed to init weights", GetGraphName());
GE_CHK_STATUS_RET(LoadGraph(), "[%s] Failed to load graph", GetGraphName());
GE_CHK_STATUS_RET(AssignUninitializedConstantOps(), "[%s] Failed to assign uninitialized constants", GetGraphName());
GE_CHK_STATUS_RET(TransAllVarData(), "[%s] Failed to trans all var data", GetGraphName());
GE_CHK_STATUS_RET(CopyVarData(), "[%s] Failed to copy var data", GetGraphName());
GE_CHK_STATUS_RET(InitModelMem(), "[%s] Failed to init memory", GetGraphName());
GE_CHK_STATUS_RET(InitWeights(), "[%s] Failed to init weights", GetGraphName());
GE_CHK_STATUS_RET(InitConstantOps(), "[%s] Failed to init constant op", GetGraphName());
GE_CHK_STATUS_RET(InitVariableTensors(), "[%s] Failed to init variables", GetGraphName());
GE_CHK_STATUS_RET(LoadTasks(), "[%s] Failed to load tasks", GetGraphName());
@@ -599,9 +599,9 @@ Status HybridModelBuilder::MergeNetOutputNode(ComputeGraph &graph) {
return SUCCESS;
}

Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGraphPtr &merged_graph) {
Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraphPtr &root_graph, ComputeGraphPtr &merged_graph) {
merged_graph = MakeShared<ComputeGraph>("MergedGraph");
for (const auto &node : root_graph.GetDirectNode()) {
for (const auto &node : root_graph->GetDirectNode()) {
GE_CHECK_NOTNULL(node);
auto op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
@@ -631,7 +631,7 @@ Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGrap
}
}
}
GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraph(root_graph, *merged_graph, *subgraph),
GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraph(root_graph, merged_graph, *subgraph),
"[%s] Failed to merge subgraph.",
subgraph->GetName().c_str());
}
@@ -647,18 +647,19 @@ Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGrap
return a_level < b_level;
});

for (auto &remained_subgraph : root_graph.GetAllSubgraphs()) {
for (auto &remained_subgraph : root_graph->GetAllSubgraphs()) {
GELOGD("Adding subgraph [%s] to merged-graph.", remained_subgraph->GetName().c_str());
GE_CHK_GRAPH_STATUS_RET(merged_graph->AddSubgraph(remained_subgraph),
"Failed to add subgraph [%s]",
remained_subgraph->GetName().c_str());
remained_subgraph->SetParentGraph(merged_graph);
}

return SUCCESS;
}

Status HybridModelBuilder::UnfoldSubgraph(ComputeGraph &root_graph,
ComputeGraph &parent_graph,
Status HybridModelBuilder::UnfoldSubgraph(ComputeGraphPtr &root_graph,
ComputeGraphPtr &parent_graph,
ComputeGraph &sub_graph) {
auto parent_node = sub_graph.GetParentNode();
GE_CHECK_NOTNULL(parent_node);
@@ -687,15 +688,23 @@ Status HybridModelBuilder::UnfoldSubgraph(ComputeGraph &root_graph,
}
}

parent_graph.AddNode(sub_node);
if (!sub_node->GetOpDesc()->GetSubgraphInstanceNames().empty()) {
for (size_t i = 0; i < sub_node->GetOpDesc()->GetSubgraphInstanceNames().size(); ++i) {
auto sub_sub_graph = NodeUtils::GetSubgraph(*sub_node, i);
GE_CHECK_NOTNULL(sub_sub_graph);
sub_sub_graph->SetParentGraph(parent_graph);
}
}
parent_graph->AddNode(sub_node);
GELOGD("[%s::%s] added to parent graph: [%s].",
sub_graph.GetName().c_str(),
sub_node->GetName().c_str(),
parent_graph.GetName().c_str());
parent_graph->GetName().c_str());
sub_node->SetOwnerComputeGraph(parent_graph);
}

GELOGD("[%s] Done merging subgraph. remove it from root graph.", sub_graph.GetName().c_str());
root_graph.RemoveSubgraph(sub_graph.GetName());
root_graph->RemoveSubgraph(sub_graph.GetName());
return SUCCESS;
}

@@ -747,7 +756,7 @@ Status HybridModelBuilder::LoadGraph() {
GELOGI("Before merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu",
root_graph->GetDirectNodesSize(),
root_graph->GetAllNodesSize());
GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraphs(*root_graph, merged_graph), "Failed to unfold subgraphs.");
GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraphs(root_graph, merged_graph), "Failed to unfold subgraphs.");
root_graph = std::move(merged_graph);
GELOGI("After merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu",
root_graph->GetDirectNodesSize(),
@@ -1030,8 +1039,8 @@ Status HybridModelBuilder::InitWeights() {
GELOGI("Init weight mem successfully, weight base %p, weight size = %zu",
weight_base,
sub_weight_buffer->GetSize());
auto root_graph = GraphUtils::GetComputeGraph(subgraph_model.second->GetGraph());
hybrid_model_.weight_buffer_map_.emplace(root_graph->GetName(),std::move(sub_weight_buffer));
auto root_graph = ge_root_model_->GetRootGraph()->GetSubgraph(subgraph_model.first);
hybrid_model_.weight_buffer_map_.emplace(root_graph->GetName(), std::move(sub_weight_buffer));
for (auto &node : root_graph->GetDirectNode()) {
if (node->GetType() != CONSTANT) {
continue;


+ 2
- 2
ge/hybrid/model/hybrid_model_builder.h View File

@@ -47,8 +47,8 @@ class HybridModelBuilder {
static Status HandleDtString(const GeTensor &tensor, void *var_addr);
static Status MergeInputNodes(ComputeGraph &compute_graph);
static Status MergeNetOutputNode(ComputeGraph &compute_graph);
static Status UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGraphPtr &merged_graph);
static Status UnfoldSubgraph(ComputeGraph &root_graph, ComputeGraph &parent_graph, ComputeGraph &sub_graph);
static Status UnfoldSubgraphs(ComputeGraphPtr &root_graph, ComputeGraphPtr &merged_graph);
static Status UnfoldSubgraph(ComputeGraphPtr &root_graph, ComputeGraphPtr &parent_graph, ComputeGraph &sub_graph);
static Status BuildInputMapping(GraphItem &graph_item,
std::vector<NodeItem *> &data_nodes,
bool is_root_graph);


+ 1
- 1
metadef

@@ -1 +1 @@
Subproject commit ccfccb4bb355425cc09594b8ea267fb8ca938138
Subproject commit 7e90824d05f349c77b85c5d547b80f9f7e197e35

+ 1
- 1
parser

@@ -1 +1 @@
Subproject commit 0d4703aa893e90f23ba8a2dbd8903e028680213f
Subproject commit 0b1cd5d98d1f80c119c4aa251216d837f9f7c359

+ 19
- 0
tests/ut/ge/common/format_transfer_transpose_unittest.cc View File

@@ -4676,5 +4676,24 @@ TEST_F(UtestFormatTranspose, invalid_dst_format) {
EXPECT_EQ(transfer.TransShape(FORMAT_NCHW, src_shape, DT_FLOAT16, FORMAT_C1HWNC0, dst_shape),
ACL_ERROR_GE_FORMAT_INVALID);
}

TEST_F(UtestFormatTranspose, invalid_src_data) {
uint8_t *data = nullptr;
TransArgs args{data, FORMAT_NCHW, FORMAT_NHWC, std::vector<int64_t>({1, 3, 8, 8}), std::vector<int64_t>({1, 8, 8, 3}), DT_INT64};
FormatTransferTranspose transpose;
TransResult result;
EXPECT_EQ(transpose.TransFormat(args, result), ACL_ERROR_GE_PARAM_INVALID);

uint16_t data1[3] = {14583, 12849, 14184};
TransArgs args1{reinterpret_cast<uint8_t *>(data1), FORMAT_NCHW, FORMAT_NHWC, std::vector<int64_t>({-1, 3, 1, 1}), std::vector<int64_t>({1, 1, 1, 3}), DT_INT64};
FormatTransferTranspose transpose1;
TransResult result1;
EXPECT_EQ(transpose1.TransFormat(args1, result1), ACL_ERROR_GE_SHAPE_INVALID);

TransArgs args2{reinterpret_cast<uint8_t *>(data1), FORMAT_NCHW, FORMAT_NHWC, std::vector<int64_t>({3, 1, 1}), std::vector<int64_t>({1, 1, 1, 3}), DT_INT64};
FormatTransferTranspose transpose2;
TransResult result2;
EXPECT_EQ(transpose2.TransFormat(args2, result2), ACL_ERROR_GE_SHAPE_INVALID);
}
} // namespace formats
} // namespace ge

+ 51
- 1
tests/ut/ge/hybrid/ge_hybrid_unittest.cc View File

@@ -332,4 +332,54 @@ TEST_F(UtestGeHybrid, test_parse_parallel_group) {
ASSERT_TRUE(model.node_items_[node]->has_observer);
ASSERT_EQ(model.node_items_[node_1]->dependents_for_execution.size(), 1);
ASSERT_EQ(model.node_items_[node_1]->dependents_for_execution[0], node);
}
}

TEST_F(UtestGeHybrid, unfold_subgraphs_success) {
ComputeGraphPtr merged_graph = nullptr;

ComputeGraphPtr sub_sub_graph1 = std::make_shared<ComputeGraph>("while_cond");
OpDescPtr sub_sub_graph_while_cond_data_op_desc = CreateOpDesc("cond_data", DATA);
NodePtr sub_sub_graph_while_cond_data_node = sub_sub_graph1->AddNode(sub_sub_graph_while_cond_data_op_desc);

ComputeGraphPtr sub_sub_graph2 = std::make_shared<ComputeGraph>("while_body");
/*OpDescPtr sub_sub_graph_while_body_const_op_desc = CreateOpDesc("body_const", CONSTANT);
NodePtr sub_sub_graph_while_body_const_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_const_op_desc);*/
OpDescPtr sub_sub_graph_while_body_data_op_desc = CreateOpDesc("body_data", DATA);
NodePtr sub_sub_graph_while_body_data_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_data_op_desc);
sub_sub_graph2->SetGraphUnknownFlag(true);
/*OpDescPtr sub_sub_graph_while_body_add_op_desc = CreateOpDesc("body_add", ADD);
NodePtr sub_sub_graph_while_body_add_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_add_node);
sub_sub_graph_while_body_add_node->AddLinkFrom(sub_sub_graph_while_body_data_node);
sub_sub_graph_while_body_add_node->AddLinkFrom(sub_sub_graph_while_body_const_node);*/

ComputeGraphPtr sub_graph = std::make_shared<ComputeGraph>("sub_graph");
OpDescPtr sub_graph_while_op_desc = CreateOpDesc("while", WHILE);
NodePtr sub_graph_while_node = sub_graph->AddNode(sub_graph_while_op_desc);
sub_graph->SetGraphUnknownFlag(true);
sub_graph_while_node->GetOpDesc()->AddSubgraphName("while_cond");
sub_graph_while_node->GetOpDesc()->AddSubgraphName("while_body");
sub_graph_while_node->GetOpDesc()->SetSubgraphInstanceName(0, "while_cond");
sub_graph_while_node->GetOpDesc()->SetSubgraphInstanceName(1, "while_body");

ComputeGraphPtr root_graph = std::make_shared<ComputeGraph>("root_graph");
auto partitioned_call_op_desc = MakeShared<OpDesc>("partitioned_call", PARTITIONEDCALL);
auto partitioned_call_node = root_graph->AddNode(partitioned_call_op_desc);
partitioned_call_node->GetOpDesc()->AddSubgraphName("sub_graph");
partitioned_call_node->GetOpDesc()->SetSubgraphInstanceName(0, "sub_graph");

root_graph->AddSubGraph(sub_sub_graph1);
root_graph->AddSubGraph(sub_sub_graph2);
sub_sub_graph1->SetParentGraph(root_graph);
sub_sub_graph2->SetParentGraph(root_graph);
sub_sub_graph1->SetParentNode(sub_graph_while_node);
sub_sub_graph2->SetParentNode(sub_graph_while_node);

root_graph->AddSubGraph(sub_graph);
sub_graph->SetParentNode(partitioned_call_node);
sub_graph->SetParentGraph(root_graph);

GeRootModelPtr root_model = MakeShared<ge::GeRootModel>(root_graph);
HybridModel hybrid_model(root_model);
HybridModelBuilder hybrid_model_builder(hybrid_model);
EXPECT_EQ(hybrid_model_builder.UnfoldSubgraphs(root_graph, merged_graph), SUCCESS);
}

Loading…
Cancel
Save