@@ -60,7 +60,7 @@ bool CheckShape(Format format, const ShapeVector &shape) { | |||||
default: | default: | ||||
std::string error = "Trans format between " + FmtToStr(TypeUtils::FormatToSerialString(format)) + | std::string error = "Trans format between " + FmtToStr(TypeUtils::FormatToSerialString(format)) + | ||||
" and FORMAT_FRACTAL_NZ is not supported."; | " and FORMAT_FRACTAL_NZ is not supported."; | ||||
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||||
GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str()); | |||||
return false; | return false; | ||||
} | } | ||||
} | } | ||||
@@ -59,7 +59,7 @@ bool CheckShape(Format format, const ShapeVector &shape) { | |||||
default: | default: | ||||
std::string error = "Trans format between " + FmtToStr(TypeUtils::FormatToSerialString(format)) + | std::string error = "Trans format between " + FmtToStr(TypeUtils::FormatToSerialString(format)) + | ||||
" and FORMAT_FRACTAL_ZZ is not supported."; | " and FORMAT_FRACTAL_ZZ is not supported."; | ||||
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||||
GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str()); | |||||
return false; | return false; | ||||
} | } | ||||
} | } | ||||
@@ -92,7 +92,8 @@ Status CheckArgsForNhwcToNc1hwc0(const TransArgs &args) { | |||||
Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const int size, const int64_t total_size) { | Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const int size, const int64_t total_size) { | ||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); | ||||
if (dst == nullptr) { | if (dst == nullptr) { | ||||
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s", | |||||
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, | |||||
"Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | TypeUtils::FormatToSerialString(args.src_format).c_str(), | ||||
TypeUtils::FormatToSerialString(args.dst_format).c_str(), total_size, ShapeToString(args.dst_shape).c_str()); | TypeUtils::FormatToSerialString(args.dst_format).c_str(), total_size, ShapeToString(args.dst_shape).c_str()); | ||||
return ACL_ERROR_GE_MEMORY_ALLOCATION; | return ACL_ERROR_GE_MEMORY_ALLOCATION; | ||||
@@ -50,21 +50,21 @@ std::map<Format, std::map<Format, std::vector<int64_t>>> perm_args{ | |||||
bool IsShapeArgValid(const std::vector<int64_t> &src_shape, const std::vector<int64_t> &perm_arg) { | bool IsShapeArgValid(const std::vector<int64_t> &src_shape, const std::vector<int64_t> &perm_arg) { | ||||
if (src_shape.empty()) { | if (src_shape.empty()) { | ||||
std::string error = "Failed to transpose, empty src shape"; | std::string error = "Failed to transpose, empty src shape"; | ||||
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||||
GELOGE(PARAM_INVALID, "Failed to transpose, empty src shape"); | |||||
GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_SHAPE_INVALID, error.c_str()); | |||||
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to transpose, empty src shape"); | |||||
return false; | return false; | ||||
} | } | ||||
for (auto dim : src_shape) { | for (auto dim : src_shape) { | ||||
if (dim < 0) { | if (dim < 0) { | ||||
std::string error = "Failed to transpose, negative dim in src shape " + FmtToStr(ShapeToString(src_shape)); | std::string error = "Failed to transpose, negative dim in src shape " + FmtToStr(ShapeToString(src_shape)); | ||||
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||||
GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_SHAPE_INVALID, error.c_str()); | |||||
return false; | return false; | ||||
} | } | ||||
} | } | ||||
if (perm_arg.size() != src_shape.size()) { | if (perm_arg.size() != src_shape.size()) { | ||||
std::string error = "Failed to transpose, the size of src shape" + FmtToStr(src_shape.size()) + | std::string error = "Failed to transpose, the size of src shape" + FmtToStr(src_shape.size()) + | ||||
" and perm arg" + FmtToStr(perm_arg.size()) + " are different"; | " and perm arg" + FmtToStr(perm_arg.size()) + " are different"; | ||||
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||||
GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_SHAPE_INVALID, error.c_str()); | |||||
return false; | return false; | ||||
} | } | ||||
@@ -73,7 +73,7 @@ bool IsShapeArgValid(const std::vector<int64_t> &src_shape, const std::vector<in | |||||
if (perm < 0 || static_cast<size_t>(perm) >= perm_arg.size() || ++exists[perm] > 1) { | if (perm < 0 || static_cast<size_t>(perm) >= perm_arg.size() || ++exists[perm] > 1) { | ||||
std::string error = "Failed to transpose, duplicated perm arg " + FmtToStr(perm) + | std::string error = "Failed to transpose, duplicated perm arg " + FmtToStr(perm) + | ||||
", perm arg " + FmtToStr(JoinToString(perm_arg)); | ", perm arg " + FmtToStr(JoinToString(perm_arg)); | ||||
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||||
GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_PARAM_INVALID, error.c_str()); | |||||
return false; | return false; | ||||
} | } | ||||
} | } | ||||
@@ -82,11 +82,11 @@ bool IsShapeArgValid(const std::vector<int64_t> &src_shape, const std::vector<in | |||||
bool IsTransposeArgValid(const uint8_t *src, const std::vector<int64_t> &src_shape, DataType src_data_type, | bool IsTransposeArgValid(const uint8_t *src, const std::vector<int64_t> &src_shape, DataType src_data_type, | ||||
const std::vector<int64_t> &perm_arg) { | const std::vector<int64_t> &perm_arg) { | ||||
if (src == nullptr) { | if (src == nullptr) { | ||||
GELOGE(PARAM_INVALID, "Failed to transpose, the src is null"); | |||||
GELOGE(ACL_ERROR_GE_PARAM_INVALID, "Failed to transpose, the src is null"); | |||||
return false; | return false; | ||||
} | } | ||||
if (GetSizeByDataType(src_data_type) < 0) { | if (GetSizeByDataType(src_data_type) < 0) { | ||||
GELOGE(UNSUPPORTED, "Failed to transpose, the data type %s is not support", | |||||
GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "Failed to transpose, the data type %s is not support", | |||||
TypeUtils::DataTypeToSerialString(src_data_type).c_str()); | TypeUtils::DataTypeToSerialString(src_data_type).c_str()); | ||||
return false; | return false; | ||||
} | } | ||||
@@ -136,12 +136,12 @@ Status HybridModelBuilder::Build() { | |||||
GE_CHK_STATUS_RET(RecoverGraphUnknownFlag(), "[%s] Failed to RecoverGraphUnknownFlag", GetGraphName()); | GE_CHK_STATUS_RET(RecoverGraphUnknownFlag(), "[%s] Failed to RecoverGraphUnknownFlag", GetGraphName()); | ||||
GE_CHK_STATUS_RET(IndexSpecialNodes(), "[%s] Failed to index nodes", GetGraphName()); | GE_CHK_STATUS_RET(IndexSpecialNodes(), "[%s] Failed to index nodes", GetGraphName()); | ||||
GE_CHK_STATUS_RET(IndexTaskDefs(), "[%s] Failed to index task defs", GetGraphName()); | GE_CHK_STATUS_RET(IndexTaskDefs(), "[%s] Failed to index task defs", GetGraphName()); | ||||
GE_CHK_STATUS_RET(InitWeights(), "[%s] Failed to init weights", GetGraphName()); | |||||
GE_CHK_STATUS_RET(LoadGraph(), "[%s] Failed to load graph", GetGraphName()); | GE_CHK_STATUS_RET(LoadGraph(), "[%s] Failed to load graph", GetGraphName()); | ||||
GE_CHK_STATUS_RET(AssignUninitializedConstantOps(), "[%s] Failed to assign uninitialized constants", GetGraphName()); | GE_CHK_STATUS_RET(AssignUninitializedConstantOps(), "[%s] Failed to assign uninitialized constants", GetGraphName()); | ||||
GE_CHK_STATUS_RET(TransAllVarData(), "[%s] Failed to trans all var data", GetGraphName()); | GE_CHK_STATUS_RET(TransAllVarData(), "[%s] Failed to trans all var data", GetGraphName()); | ||||
GE_CHK_STATUS_RET(CopyVarData(), "[%s] Failed to copy var data", GetGraphName()); | GE_CHK_STATUS_RET(CopyVarData(), "[%s] Failed to copy var data", GetGraphName()); | ||||
GE_CHK_STATUS_RET(InitModelMem(), "[%s] Failed to init memory", GetGraphName()); | GE_CHK_STATUS_RET(InitModelMem(), "[%s] Failed to init memory", GetGraphName()); | ||||
GE_CHK_STATUS_RET(InitWeights(), "[%s] Failed to init weights", GetGraphName()); | |||||
GE_CHK_STATUS_RET(InitConstantOps(), "[%s] Failed to init constant op", GetGraphName()); | GE_CHK_STATUS_RET(InitConstantOps(), "[%s] Failed to init constant op", GetGraphName()); | ||||
GE_CHK_STATUS_RET(InitVariableTensors(), "[%s] Failed to init variables", GetGraphName()); | GE_CHK_STATUS_RET(InitVariableTensors(), "[%s] Failed to init variables", GetGraphName()); | ||||
GE_CHK_STATUS_RET(LoadTasks(), "[%s] Failed to load tasks", GetGraphName()); | GE_CHK_STATUS_RET(LoadTasks(), "[%s] Failed to load tasks", GetGraphName()); | ||||
@@ -599,9 +599,9 @@ Status HybridModelBuilder::MergeNetOutputNode(ComputeGraph &graph) { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGraphPtr &merged_graph) { | |||||
Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraphPtr &root_graph, ComputeGraphPtr &merged_graph) { | |||||
merged_graph = MakeShared<ComputeGraph>("MergedGraph"); | merged_graph = MakeShared<ComputeGraph>("MergedGraph"); | ||||
for (const auto &node : root_graph.GetDirectNode()) { | |||||
for (const auto &node : root_graph->GetDirectNode()) { | |||||
GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
auto op_desc = node->GetOpDesc(); | auto op_desc = node->GetOpDesc(); | ||||
GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
@@ -631,7 +631,7 @@ Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGrap | |||||
} | } | ||||
} | } | ||||
} | } | ||||
GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraph(root_graph, *merged_graph, *subgraph), | |||||
GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraph(root_graph, merged_graph, *subgraph), | |||||
"[%s] Failed to merge subgraph.", | "[%s] Failed to merge subgraph.", | ||||
subgraph->GetName().c_str()); | subgraph->GetName().c_str()); | ||||
} | } | ||||
@@ -647,18 +647,19 @@ Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGrap | |||||
return a_level < b_level; | return a_level < b_level; | ||||
}); | }); | ||||
for (auto &remained_subgraph : root_graph.GetAllSubgraphs()) { | |||||
for (auto &remained_subgraph : root_graph->GetAllSubgraphs()) { | |||||
GELOGD("Adding subgraph [%s] to merged-graph.", remained_subgraph->GetName().c_str()); | GELOGD("Adding subgraph [%s] to merged-graph.", remained_subgraph->GetName().c_str()); | ||||
GE_CHK_GRAPH_STATUS_RET(merged_graph->AddSubgraph(remained_subgraph), | GE_CHK_GRAPH_STATUS_RET(merged_graph->AddSubgraph(remained_subgraph), | ||||
"Failed to add subgraph [%s]", | "Failed to add subgraph [%s]", | ||||
remained_subgraph->GetName().c_str()); | remained_subgraph->GetName().c_str()); | ||||
remained_subgraph->SetParentGraph(merged_graph); | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status HybridModelBuilder::UnfoldSubgraph(ComputeGraph &root_graph, | |||||
ComputeGraph &parent_graph, | |||||
Status HybridModelBuilder::UnfoldSubgraph(ComputeGraphPtr &root_graph, | |||||
ComputeGraphPtr &parent_graph, | |||||
ComputeGraph &sub_graph) { | ComputeGraph &sub_graph) { | ||||
auto parent_node = sub_graph.GetParentNode(); | auto parent_node = sub_graph.GetParentNode(); | ||||
GE_CHECK_NOTNULL(parent_node); | GE_CHECK_NOTNULL(parent_node); | ||||
@@ -687,15 +688,23 @@ Status HybridModelBuilder::UnfoldSubgraph(ComputeGraph &root_graph, | |||||
} | } | ||||
} | } | ||||
parent_graph.AddNode(sub_node); | |||||
if (!sub_node->GetOpDesc()->GetSubgraphInstanceNames().empty()) { | |||||
for (size_t i = 0; i < sub_node->GetOpDesc()->GetSubgraphInstanceNames().size(); ++i) { | |||||
auto sub_sub_graph = NodeUtils::GetSubgraph(*sub_node, i); | |||||
GE_CHECK_NOTNULL(sub_sub_graph); | |||||
sub_sub_graph->SetParentGraph(parent_graph); | |||||
} | |||||
} | |||||
parent_graph->AddNode(sub_node); | |||||
GELOGD("[%s::%s] added to parent graph: [%s].", | GELOGD("[%s::%s] added to parent graph: [%s].", | ||||
sub_graph.GetName().c_str(), | sub_graph.GetName().c_str(), | ||||
sub_node->GetName().c_str(), | sub_node->GetName().c_str(), | ||||
parent_graph.GetName().c_str()); | |||||
parent_graph->GetName().c_str()); | |||||
sub_node->SetOwnerComputeGraph(parent_graph); | |||||
} | } | ||||
GELOGD("[%s] Done merging subgraph. remove it from root graph.", sub_graph.GetName().c_str()); | GELOGD("[%s] Done merging subgraph. remove it from root graph.", sub_graph.GetName().c_str()); | ||||
root_graph.RemoveSubgraph(sub_graph.GetName()); | |||||
root_graph->RemoveSubgraph(sub_graph.GetName()); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -747,7 +756,7 @@ Status HybridModelBuilder::LoadGraph() { | |||||
GELOGI("Before merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu", | GELOGI("Before merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu", | ||||
root_graph->GetDirectNodesSize(), | root_graph->GetDirectNodesSize(), | ||||
root_graph->GetAllNodesSize()); | root_graph->GetAllNodesSize()); | ||||
GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraphs(*root_graph, merged_graph), "Failed to unfold subgraphs."); | |||||
GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraphs(root_graph, merged_graph), "Failed to unfold subgraphs."); | |||||
root_graph = std::move(merged_graph); | root_graph = std::move(merged_graph); | ||||
GELOGI("After merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu", | GELOGI("After merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu", | ||||
root_graph->GetDirectNodesSize(), | root_graph->GetDirectNodesSize(), | ||||
@@ -1030,8 +1039,8 @@ Status HybridModelBuilder::InitWeights() { | |||||
GELOGI("Init weight mem successfully, weight base %p, weight size = %zu", | GELOGI("Init weight mem successfully, weight base %p, weight size = %zu", | ||||
weight_base, | weight_base, | ||||
sub_weight_buffer->GetSize()); | sub_weight_buffer->GetSize()); | ||||
auto root_graph = GraphUtils::GetComputeGraph(subgraph_model.second->GetGraph()); | |||||
hybrid_model_.weight_buffer_map_.emplace(root_graph->GetName(),std::move(sub_weight_buffer)); | |||||
auto root_graph = ge_root_model_->GetRootGraph()->GetSubgraph(subgraph_model.first); | |||||
hybrid_model_.weight_buffer_map_.emplace(root_graph->GetName(), std::move(sub_weight_buffer)); | |||||
for (auto &node : root_graph->GetDirectNode()) { | for (auto &node : root_graph->GetDirectNode()) { | ||||
if (node->GetType() != CONSTANT) { | if (node->GetType() != CONSTANT) { | ||||
continue; | continue; | ||||
@@ -47,8 +47,8 @@ class HybridModelBuilder { | |||||
static Status HandleDtString(const GeTensor &tensor, void *var_addr); | static Status HandleDtString(const GeTensor &tensor, void *var_addr); | ||||
static Status MergeInputNodes(ComputeGraph &compute_graph); | static Status MergeInputNodes(ComputeGraph &compute_graph); | ||||
static Status MergeNetOutputNode(ComputeGraph &compute_graph); | static Status MergeNetOutputNode(ComputeGraph &compute_graph); | ||||
static Status UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGraphPtr &merged_graph); | |||||
static Status UnfoldSubgraph(ComputeGraph &root_graph, ComputeGraph &parent_graph, ComputeGraph &sub_graph); | |||||
static Status UnfoldSubgraphs(ComputeGraphPtr &root_graph, ComputeGraphPtr &merged_graph); | |||||
static Status UnfoldSubgraph(ComputeGraphPtr &root_graph, ComputeGraphPtr &parent_graph, ComputeGraph &sub_graph); | |||||
static Status BuildInputMapping(GraphItem &graph_item, | static Status BuildInputMapping(GraphItem &graph_item, | ||||
std::vector<NodeItem *> &data_nodes, | std::vector<NodeItem *> &data_nodes, | ||||
bool is_root_graph); | bool is_root_graph); | ||||
@@ -1 +1 @@ | |||||
Subproject commit ccfccb4bb355425cc09594b8ea267fb8ca938138 | |||||
Subproject commit 7e90824d05f349c77b85c5d547b80f9f7e197e35 |
@@ -1 +1 @@ | |||||
Subproject commit 0d4703aa893e90f23ba8a2dbd8903e028680213f | |||||
Subproject commit 0b1cd5d98d1f80c119c4aa251216d837f9f7c359 |
@@ -4676,5 +4676,24 @@ TEST_F(UtestFormatTranspose, invalid_dst_format) { | |||||
EXPECT_EQ(transfer.TransShape(FORMAT_NCHW, src_shape, DT_FLOAT16, FORMAT_C1HWNC0, dst_shape), | EXPECT_EQ(transfer.TransShape(FORMAT_NCHW, src_shape, DT_FLOAT16, FORMAT_C1HWNC0, dst_shape), | ||||
ACL_ERROR_GE_FORMAT_INVALID); | ACL_ERROR_GE_FORMAT_INVALID); | ||||
} | } | ||||
TEST_F(UtestFormatTranspose, invalid_src_data) { | |||||
uint8_t *data = nullptr; | |||||
TransArgs args{data, FORMAT_NCHW, FORMAT_NHWC, std::vector<int64_t>({1, 3, 8, 8}), std::vector<int64_t>({1, 8, 8, 3}), DT_INT64}; | |||||
FormatTransferTranspose transpose; | |||||
TransResult result; | |||||
EXPECT_EQ(transpose.TransFormat(args, result), ACL_ERROR_GE_PARAM_INVALID); | |||||
uint16_t data1[3] = {14583, 12849, 14184}; | |||||
TransArgs args1{reinterpret_cast<uint8_t *>(data1), FORMAT_NCHW, FORMAT_NHWC, std::vector<int64_t>({-1, 3, 1, 1}), std::vector<int64_t>({1, 1, 1, 3}), DT_INT64}; | |||||
FormatTransferTranspose transpose1; | |||||
TransResult result1; | |||||
EXPECT_EQ(transpose1.TransFormat(args1, result1), ACL_ERROR_GE_SHAPE_INVALID); | |||||
TransArgs args2{reinterpret_cast<uint8_t *>(data1), FORMAT_NCHW, FORMAT_NHWC, std::vector<int64_t>({3, 1, 1}), std::vector<int64_t>({1, 1, 1, 3}), DT_INT64}; | |||||
FormatTransferTranspose transpose2; | |||||
TransResult result2; | |||||
EXPECT_EQ(transpose2.TransFormat(args2, result2), ACL_ERROR_GE_SHAPE_INVALID); | |||||
} | |||||
} // namespace formats | } // namespace formats | ||||
} // namespace ge | } // namespace ge |
@@ -332,4 +332,54 @@ TEST_F(UtestGeHybrid, test_parse_parallel_group) { | |||||
ASSERT_TRUE(model.node_items_[node]->has_observer); | ASSERT_TRUE(model.node_items_[node]->has_observer); | ||||
ASSERT_EQ(model.node_items_[node_1]->dependents_for_execution.size(), 1); | ASSERT_EQ(model.node_items_[node_1]->dependents_for_execution.size(), 1); | ||||
ASSERT_EQ(model.node_items_[node_1]->dependents_for_execution[0], node); | ASSERT_EQ(model.node_items_[node_1]->dependents_for_execution[0], node); | ||||
} | |||||
} | |||||
TEST_F(UtestGeHybrid, unfold_subgraphs_success) { | |||||
ComputeGraphPtr merged_graph = nullptr; | |||||
ComputeGraphPtr sub_sub_graph1 = std::make_shared<ComputeGraph>("while_cond"); | |||||
OpDescPtr sub_sub_graph_while_cond_data_op_desc = CreateOpDesc("cond_data", DATA); | |||||
NodePtr sub_sub_graph_while_cond_data_node = sub_sub_graph1->AddNode(sub_sub_graph_while_cond_data_op_desc); | |||||
ComputeGraphPtr sub_sub_graph2 = std::make_shared<ComputeGraph>("while_body"); | |||||
/*OpDescPtr sub_sub_graph_while_body_const_op_desc = CreateOpDesc("body_const", CONSTANT); | |||||
NodePtr sub_sub_graph_while_body_const_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_const_op_desc);*/ | |||||
OpDescPtr sub_sub_graph_while_body_data_op_desc = CreateOpDesc("body_data", DATA); | |||||
NodePtr sub_sub_graph_while_body_data_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_data_op_desc); | |||||
sub_sub_graph2->SetGraphUnknownFlag(true); | |||||
/*OpDescPtr sub_sub_graph_while_body_add_op_desc = CreateOpDesc("body_add", ADD); | |||||
NodePtr sub_sub_graph_while_body_add_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_add_node); | |||||
sub_sub_graph_while_body_add_node->AddLinkFrom(sub_sub_graph_while_body_data_node); | |||||
sub_sub_graph_while_body_add_node->AddLinkFrom(sub_sub_graph_while_body_const_node);*/ | |||||
ComputeGraphPtr sub_graph = std::make_shared<ComputeGraph>("sub_graph"); | |||||
OpDescPtr sub_graph_while_op_desc = CreateOpDesc("while", WHILE); | |||||
NodePtr sub_graph_while_node = sub_graph->AddNode(sub_graph_while_op_desc); | |||||
sub_graph->SetGraphUnknownFlag(true); | |||||
sub_graph_while_node->GetOpDesc()->AddSubgraphName("while_cond"); | |||||
sub_graph_while_node->GetOpDesc()->AddSubgraphName("while_body"); | |||||
sub_graph_while_node->GetOpDesc()->SetSubgraphInstanceName(0, "while_cond"); | |||||
sub_graph_while_node->GetOpDesc()->SetSubgraphInstanceName(1, "while_body"); | |||||
ComputeGraphPtr root_graph = std::make_shared<ComputeGraph>("root_graph"); | |||||
auto partitioned_call_op_desc = MakeShared<OpDesc>("partitioned_call", PARTITIONEDCALL); | |||||
auto partitioned_call_node = root_graph->AddNode(partitioned_call_op_desc); | |||||
partitioned_call_node->GetOpDesc()->AddSubgraphName("sub_graph"); | |||||
partitioned_call_node->GetOpDesc()->SetSubgraphInstanceName(0, "sub_graph"); | |||||
root_graph->AddSubGraph(sub_sub_graph1); | |||||
root_graph->AddSubGraph(sub_sub_graph2); | |||||
sub_sub_graph1->SetParentGraph(root_graph); | |||||
sub_sub_graph2->SetParentGraph(root_graph); | |||||
sub_sub_graph1->SetParentNode(sub_graph_while_node); | |||||
sub_sub_graph2->SetParentNode(sub_graph_while_node); | |||||
root_graph->AddSubGraph(sub_graph); | |||||
sub_graph->SetParentNode(partitioned_call_node); | |||||
sub_graph->SetParentGraph(root_graph); | |||||
GeRootModelPtr root_model = MakeShared<ge::GeRootModel>(root_graph); | |||||
HybridModel hybrid_model(root_model); | |||||
HybridModelBuilder hybrid_model_builder(hybrid_model); | |||||
EXPECT_EQ(hybrid_model_builder.UnfoldSubgraphs(root_graph, merged_graph), SUCCESS); | |||||
} |