From ba39c5375b91b6a882f920310340a619a652f667 Mon Sep 17 00:00:00 2001 From: lichun Date: Sun, 28 Mar 2021 17:37:01 +0800 Subject: [PATCH] support unknown while subgraph --- .../format_transfer_fractal_nz.cc | 2 +- .../format_transfer_fractal_zz.cc | 2 +- .../format_transfer_nhwc_nc1hwc0.cc | 3 +- .../format_transfer_transpose.cc | 14 ++--- ge/hybrid/model/hybrid_model_builder.cc | 35 ++++++++----- ge/hybrid/model/hybrid_model_builder.h | 4 +- metadef | 2 +- parser | 2 +- .../format_transfer_transpose_unittest.cc | 19 +++++++ tests/ut/ge/hybrid/ge_hybrid_unittest.cc | 52 ++++++++++++++++++- 10 files changed, 107 insertions(+), 28 deletions(-) diff --git a/ge/common/formats/format_transfers/format_transfer_fractal_nz.cc b/ge/common/formats/format_transfers/format_transfer_fractal_nz.cc index fccdb57b..01c7de95 100755 --- a/ge/common/formats/format_transfers/format_transfer_fractal_nz.cc +++ b/ge/common/formats/format_transfers/format_transfer_fractal_nz.cc @@ -60,7 +60,7 @@ bool CheckShape(Format format, const ShapeVector &shape) { default: std::string error = "Trans format between " + FmtToStr(TypeUtils::FormatToSerialString(format)) + " and FORMAT_FRACTAL_NZ is not supported."; - GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); + GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str()); return false; } } diff --git a/ge/common/formats/format_transfers/format_transfer_fractal_zz.cc b/ge/common/formats/format_transfers/format_transfer_fractal_zz.cc index c36bffb5..36bea872 100755 --- a/ge/common/formats/format_transfers/format_transfer_fractal_zz.cc +++ b/ge/common/formats/format_transfers/format_transfer_fractal_zz.cc @@ -59,7 +59,7 @@ bool CheckShape(Format format, const ShapeVector &shape) { default: std::string error = "Trans format between " + FmtToStr(TypeUtils::FormatToSerialString(format)) + " and FORMAT_FRACTAL_ZZ is not supported."; - GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); + GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str()); return false; } } diff --git a/ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc b/ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc index b09fd168..6817713a 100755 --- a/ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc +++ b/ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc @@ -92,7 +92,8 @@ Status CheckArgsForNhwcToNc1hwc0(const TransArgs &args) { Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const int size, const int64_t total_size) { std::shared_ptr dst(new (std::nothrow) uint8_t[total_size], std::default_delete()); if (dst == nullptr) { - GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s", + GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, + "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s", TypeUtils::FormatToSerialString(args.src_format).c_str(), TypeUtils::FormatToSerialString(args.dst_format).c_str(), total_size, ShapeToString(args.dst_shape).c_str()); return ACL_ERROR_GE_MEMORY_ALLOCATION; diff --git a/ge/common/formats/format_transfers/format_transfer_transpose.cc b/ge/common/formats/format_transfers/format_transfer_transpose.cc index 694777f3..49bb5cd6 100755 --- a/ge/common/formats/format_transfers/format_transfer_transpose.cc +++ b/ge/common/formats/format_transfers/format_transfer_transpose.cc @@ -50,21 +50,21 @@ std::map>> perm_args{ bool IsShapeArgValid(const std::vector &src_shape, const std::vector &perm_arg) { if (src_shape.empty()) { std::string error = "Failed to transpose, empty src shape"; - GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); - GELOGE(PARAM_INVALID, "Failed to transpose, empty src shape"); + GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_SHAPE_INVALID, error.c_str()); + GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to transpose, empty src shape"); return false; } for (auto dim : src_shape) { if (dim < 0) { std::string error = "Failed to transpose, negative dim in src shape " + FmtToStr(ShapeToString(src_shape)); - GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); + GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_SHAPE_INVALID, error.c_str()); return false; } } if (perm_arg.size() != src_shape.size()) { std::string error = "Failed to transpose, the size of src shape" + FmtToStr(src_shape.size()) + " and perm arg" + FmtToStr(perm_arg.size()) + " are different"; - GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); + GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_SHAPE_INVALID, error.c_str()); return false; } @@ -73,7 +73,7 @@ bool IsShapeArgValid(const std::vector &src_shape, const std::vector(perm) >= perm_arg.size() || ++exists[perm] > 1) { std::string error = "Failed to transpose, duplicated perm arg " + FmtToStr(perm) + ", perm arg " + FmtToStr(JoinToString(perm_arg)); - GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); + GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_PARAM_INVALID, error.c_str()); return false; } } @@ -82,11 +82,11 @@ bool IsShapeArgValid(const std::vector &src_shape, const std::vector &src_shape, DataType src_data_type, const std::vector &perm_arg) { if (src == nullptr) { - GELOGE(PARAM_INVALID, "Failed to transpose, the src is null"); + GELOGE(ACL_ERROR_GE_PARAM_INVALID, "Failed to transpose, the src is null"); return false; } if (GetSizeByDataType(src_data_type) < 0) { - GELOGE(UNSUPPORTED, "Failed to transpose, the data type %s is not support", + GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "Failed to transpose, the data type %s is not support", TypeUtils::DataTypeToSerialString(src_data_type).c_str()); return false; } diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc index 25dabd78..34224fe5 100755 --- a/ge/hybrid/model/hybrid_model_builder.cc +++ b/ge/hybrid/model/hybrid_model_builder.cc @@ -136,12 +136,12 @@ Status HybridModelBuilder::Build() { GE_CHK_STATUS_RET(RecoverGraphUnknownFlag(), "[%s] Failed to RecoverGraphUnknownFlag", GetGraphName()); GE_CHK_STATUS_RET(IndexSpecialNodes(), "[%s] Failed to index nodes", GetGraphName()); GE_CHK_STATUS_RET(IndexTaskDefs(), "[%s] Failed to index task defs", GetGraphName()); + GE_CHK_STATUS_RET(InitWeights(), "[%s] Failed to init weights", GetGraphName()); GE_CHK_STATUS_RET(LoadGraph(), "[%s] Failed to load graph", GetGraphName()); GE_CHK_STATUS_RET(AssignUninitializedConstantOps(), "[%s] Failed to assign uninitialized constants", GetGraphName()); GE_CHK_STATUS_RET(TransAllVarData(), "[%s] Failed to trans all var data", GetGraphName()); GE_CHK_STATUS_RET(CopyVarData(), "[%s] Failed to copy var data", GetGraphName()); GE_CHK_STATUS_RET(InitModelMem(), "[%s] Failed to init memory", GetGraphName()); - GE_CHK_STATUS_RET(InitWeights(), "[%s] Failed to init weights", GetGraphName()); GE_CHK_STATUS_RET(InitConstantOps(), "[%s] Failed to init constant op", GetGraphName()); GE_CHK_STATUS_RET(InitVariableTensors(), "[%s] Failed to init variables", GetGraphName()); GE_CHK_STATUS_RET(LoadTasks(), "[%s] Failed to load tasks", GetGraphName()); @@ -599,9 +599,9 @@ Status HybridModelBuilder::MergeNetOutputNode(ComputeGraph &graph) { return SUCCESS; } -Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGraphPtr &merged_graph) { +Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraphPtr &root_graph, ComputeGraphPtr &merged_graph) { merged_graph = MakeShared("MergedGraph"); - for (const auto &node : root_graph.GetDirectNode()) { + for (const auto &node : root_graph->GetDirectNode()) { GE_CHECK_NOTNULL(node); auto op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); @@ -631,7 +631,7 @@ Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGrap } } } - GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraph(root_graph, *merged_graph, *subgraph), + GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraph(root_graph, merged_graph, *subgraph), "[%s] Failed to merge subgraph.", subgraph->GetName().c_str()); } @@ -647,18 +647,19 @@ Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGrap return a_level < b_level; }); - for (auto &remained_subgraph : root_graph.GetAllSubgraphs()) { + for (auto &remained_subgraph : root_graph->GetAllSubgraphs()) { GELOGD("Adding subgraph [%s] to merged-graph.", remained_subgraph->GetName().c_str()); GE_CHK_GRAPH_STATUS_RET(merged_graph->AddSubgraph(remained_subgraph), "Failed to add subgraph [%s]", remained_subgraph->GetName().c_str()); + remained_subgraph->SetParentGraph(merged_graph); } return SUCCESS; } -Status HybridModelBuilder::UnfoldSubgraph(ComputeGraph &root_graph, - ComputeGraph &parent_graph, +Status HybridModelBuilder::UnfoldSubgraph(ComputeGraphPtr &root_graph, + ComputeGraphPtr &parent_graph, ComputeGraph &sub_graph) { auto parent_node = sub_graph.GetParentNode(); GE_CHECK_NOTNULL(parent_node); @@ -687,15 +688,23 @@ Status HybridModelBuilder::UnfoldSubgraph(ComputeGraph &root_graph, } } - parent_graph.AddNode(sub_node); + if (!sub_node->GetOpDesc()->GetSubgraphInstanceNames().empty()) { + for (size_t i = 0; i < sub_node->GetOpDesc()->GetSubgraphInstanceNames().size(); ++i) { + auto sub_sub_graph = NodeUtils::GetSubgraph(*sub_node, i); + GE_CHECK_NOTNULL(sub_sub_graph); + sub_sub_graph->SetParentGraph(parent_graph); + } + } + parent_graph->AddNode(sub_node); GELOGD("[%s::%s] added to parent graph: [%s].", sub_graph.GetName().c_str(), sub_node->GetName().c_str(), - parent_graph.GetName().c_str()); + parent_graph->GetName().c_str()); + sub_node->SetOwnerComputeGraph(parent_graph); } GELOGD("[%s] Done merging subgraph. remove it from root graph.", sub_graph.GetName().c_str()); - root_graph.RemoveSubgraph(sub_graph.GetName()); + root_graph->RemoveSubgraph(sub_graph.GetName()); return SUCCESS; } @@ -747,7 +756,7 @@ Status HybridModelBuilder::LoadGraph() { GELOGI("Before merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu", root_graph->GetDirectNodesSize(), root_graph->GetAllNodesSize()); - GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraphs(*root_graph, merged_graph), "Failed to unfold subgraphs."); + GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraphs(root_graph, merged_graph), "Failed to unfold subgraphs."); root_graph = std::move(merged_graph); GELOGI("After merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu", root_graph->GetDirectNodesSize(), @@ -1030,8 +1039,8 @@ Status HybridModelBuilder::InitWeights() { GELOGI("Init weight mem successfully, weight base %p, weight size = %zu", weight_base, sub_weight_buffer->GetSize()); - auto root_graph = GraphUtils::GetComputeGraph(subgraph_model.second->GetGraph()); - hybrid_model_.weight_buffer_map_.emplace(root_graph->GetName(),std::move(sub_weight_buffer)); + auto root_graph = ge_root_model_->GetRootGraph()->GetSubgraph(subgraph_model.first); + hybrid_model_.weight_buffer_map_.emplace(root_graph->GetName(), std::move(sub_weight_buffer)); for (auto &node : root_graph->GetDirectNode()) { if (node->GetType() != CONSTANT) { continue; diff --git a/ge/hybrid/model/hybrid_model_builder.h b/ge/hybrid/model/hybrid_model_builder.h index a59a282a..30241003 100644 --- a/ge/hybrid/model/hybrid_model_builder.h +++ b/ge/hybrid/model/hybrid_model_builder.h @@ -47,8 +47,8 @@ class HybridModelBuilder { static Status HandleDtString(const GeTensor &tensor, void *var_addr); static Status MergeInputNodes(ComputeGraph &compute_graph); static Status MergeNetOutputNode(ComputeGraph &compute_graph); - static Status UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGraphPtr &merged_graph); - static Status UnfoldSubgraph(ComputeGraph &root_graph, ComputeGraph &parent_graph, ComputeGraph &sub_graph); + static Status UnfoldSubgraphs(ComputeGraphPtr &root_graph, ComputeGraphPtr &merged_graph); + static Status UnfoldSubgraph(ComputeGraphPtr &root_graph, ComputeGraphPtr &parent_graph, ComputeGraph &sub_graph); static Status BuildInputMapping(GraphItem &graph_item, std::vector &data_nodes, bool is_root_graph); diff --git a/metadef b/metadef index ccfccb4b..7e90824d 160000 --- a/metadef +++ b/metadef @@ -1 +1 @@ -Subproject commit ccfccb4bb355425cc09594b8ea267fb8ca938138 +Subproject commit 7e90824d05f349c77b85c5d547b80f9f7e197e35 diff --git a/parser b/parser index 0d4703aa..0b1cd5d9 160000 --- a/parser +++ b/parser @@ -1 +1 @@ -Subproject commit 0d4703aa893e90f23ba8a2dbd8903e028680213f +Subproject commit 0b1cd5d98d1f80c119c4aa251216d837f9f7c359 diff --git a/tests/ut/ge/common/format_transfer_transpose_unittest.cc b/tests/ut/ge/common/format_transfer_transpose_unittest.cc index 04f2a557..b710acde 100644 --- a/tests/ut/ge/common/format_transfer_transpose_unittest.cc +++ b/tests/ut/ge/common/format_transfer_transpose_unittest.cc @@ -4676,5 +4676,24 @@ TEST_F(UtestFormatTranspose, invalid_dst_format) { EXPECT_EQ(transfer.TransShape(FORMAT_NCHW, src_shape, DT_FLOAT16, FORMAT_C1HWNC0, dst_shape), ACL_ERROR_GE_FORMAT_INVALID); } + +TEST_F(UtestFormatTranspose, invalid_src_data) { + uint8_t *data = nullptr; + TransArgs args{data, FORMAT_NCHW, FORMAT_NHWC, std::vector({1, 3, 8, 8}), std::vector({1, 8, 8, 3}), DT_INT64}; + FormatTransferTranspose transpose; + TransResult result; + EXPECT_EQ(transpose.TransFormat(args, result), ACL_ERROR_GE_PARAM_INVALID); + + uint16_t data1[3] = {14583, 12849, 14184}; + TransArgs args1{reinterpret_cast(data1), FORMAT_NCHW, FORMAT_NHWC, std::vector({-1, 3, 1, 1}), std::vector({1, 1, 1, 3}), DT_INT64}; + FormatTransferTranspose transpose1; + TransResult result1; + EXPECT_EQ(transpose1.TransFormat(args1, result1), ACL_ERROR_GE_SHAPE_INVALID); + + TransArgs args2{reinterpret_cast(data1), FORMAT_NCHW, FORMAT_NHWC, std::vector({3, 1, 1}), std::vector({1, 1, 1, 3}), DT_INT64}; + FormatTransferTranspose transpose2; + TransResult result2; + EXPECT_EQ(transpose2.TransFormat(args2, result2), ACL_ERROR_GE_SHAPE_INVALID); +} } // namespace formats } // namespace ge diff --git a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc index f38037a0..8c4517c7 100644 --- a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc +++ b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc @@ -332,4 +332,54 @@ TEST_F(UtestGeHybrid, test_parse_parallel_group) { ASSERT_TRUE(model.node_items_[node]->has_observer); ASSERT_EQ(model.node_items_[node_1]->dependents_for_execution.size(), 1); ASSERT_EQ(model.node_items_[node_1]->dependents_for_execution[0], node); -} \ No newline at end of file +} + +TEST_F(UtestGeHybrid, unfold_subgraphs_success) { + ComputeGraphPtr merged_graph = nullptr; + + ComputeGraphPtr sub_sub_graph1 = std::make_shared("while_cond"); + OpDescPtr sub_sub_graph_while_cond_data_op_desc = CreateOpDesc("cond_data", DATA); + NodePtr sub_sub_graph_while_cond_data_node = sub_sub_graph1->AddNode(sub_sub_graph_while_cond_data_op_desc); + + ComputeGraphPtr sub_sub_graph2 = std::make_shared("while_body"); + /*OpDescPtr sub_sub_graph_while_body_const_op_desc = CreateOpDesc("body_const", CONSTANT); + NodePtr sub_sub_graph_while_body_const_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_const_op_desc);*/ + OpDescPtr sub_sub_graph_while_body_data_op_desc = CreateOpDesc("body_data", DATA); + NodePtr sub_sub_graph_while_body_data_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_data_op_desc); + sub_sub_graph2->SetGraphUnknownFlag(true); + /*OpDescPtr sub_sub_graph_while_body_add_op_desc = CreateOpDesc("body_add", ADD); + NodePtr sub_sub_graph_while_body_add_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_add_node); + sub_sub_graph_while_body_add_node->AddLinkFrom(sub_sub_graph_while_body_data_node); + sub_sub_graph_while_body_add_node->AddLinkFrom(sub_sub_graph_while_body_const_node);*/ + + ComputeGraphPtr sub_graph = std::make_shared("sub_graph"); + OpDescPtr sub_graph_while_op_desc = CreateOpDesc("while", WHILE); + NodePtr sub_graph_while_node = sub_graph->AddNode(sub_graph_while_op_desc); + sub_graph->SetGraphUnknownFlag(true); + sub_graph_while_node->GetOpDesc()->AddSubgraphName("while_cond"); + sub_graph_while_node->GetOpDesc()->AddSubgraphName("while_body"); + sub_graph_while_node->GetOpDesc()->SetSubgraphInstanceName(0, "while_cond"); + sub_graph_while_node->GetOpDesc()->SetSubgraphInstanceName(1, "while_body"); + + ComputeGraphPtr root_graph = std::make_shared("root_graph"); + auto partitioned_call_op_desc = MakeShared("partitioned_call", PARTITIONEDCALL); + auto partitioned_call_node = root_graph->AddNode(partitioned_call_op_desc); + partitioned_call_node->GetOpDesc()->AddSubgraphName("sub_graph"); + partitioned_call_node->GetOpDesc()->SetSubgraphInstanceName(0, "sub_graph"); + + root_graph->AddSubGraph(sub_sub_graph1); + root_graph->AddSubGraph(sub_sub_graph2); + sub_sub_graph1->SetParentGraph(root_graph); + sub_sub_graph2->SetParentGraph(root_graph); + sub_sub_graph1->SetParentNode(sub_graph_while_node); + sub_sub_graph2->SetParentNode(sub_graph_while_node); + + root_graph->AddSubGraph(sub_graph); + sub_graph->SetParentNode(partitioned_call_node); + sub_graph->SetParentGraph(root_graph); + + GeRootModelPtr root_model = MakeShared(root_graph); + HybridModel hybrid_model(root_model); + HybridModelBuilder hybrid_model_builder(hybrid_model); + EXPECT_EQ(hybrid_model_builder.UnfoldSubgraphs(root_graph, merged_graph), SUCCESS); +}