From: @zhao_zhixuan Reviewed-by: Signed-off-by:tags/v1.3.0
| @@ -41,6 +41,8 @@ HybridModelExecutor::~HybridModelExecutor() { | |||||
| Status HybridModelExecutor::Init() { | Status HybridModelExecutor::Init() { | ||||
| GELOGD("Start to init HybridGraphEngine."); | GELOGD("Start to init HybridGraphEngine."); | ||||
| GE_CHK_STATUS_RET_NOLOG(InitExecutionContext()); | GE_CHK_STATUS_RET_NOLOG(InitExecutionContext()); | ||||
| root_graph_executor_.reset(new (std::nothrow) SubgraphExecutor(model_->GetRootGraphItem(), &context_)); | |||||
| GE_CHECK_NOTNULL(root_graph_executor_); | |||||
| GELOGD("HybridGraphEngine initialized successfully."); | GELOGD("HybridGraphEngine initialized successfully."); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -60,8 +62,7 @@ Status HybridModelExecutor::Execute(HybridModelExecutor::ExecuteArgs &args) { | |||||
| GE_CHK_RT_RET(rtMemcpyAsync(context_.global_step, sizeof(uint64_t), &context_.iteration, | GE_CHK_RT_RET(rtMemcpyAsync(context_.global_step, sizeof(uint64_t), &context_.iteration, | ||||
| sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE_EX, context_.stream)); | sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE_EX, context_.stream)); | ||||
| } | } | ||||
| SubgraphExecutor executor(model_->GetRootGraphItem(), &context_); | |||||
| auto ret = ExecuteGraphInternal(executor, args); | |||||
| auto ret = ExecuteGraphInternal(args); | |||||
| Cleanup(); | Cleanup(); | ||||
| RECORD_MODEL_EXECUTION_EVENT(&context_, "[Cleanup] End"); | RECORD_MODEL_EXECUTION_EVENT(&context_, "[Cleanup] End"); | ||||
| GELOGD("Model executed successfully."); | GELOGD("Model executed successfully."); | ||||
| @@ -69,6 +70,7 @@ Status HybridModelExecutor::Execute(HybridModelExecutor::ExecuteArgs &args) { | |||||
| context_.profiler->Dump(std::cout); | context_.profiler->Dump(std::cout); | ||||
| context_.profiler->Reset(); | context_.profiler->Reset(); | ||||
| } | } | ||||
| root_graph_executor_->ReleaseContext(); | |||||
| context_.iteration += 1; | context_.iteration += 1; | ||||
| if (ret == END_OF_SEQUENCE) { | if (ret == END_OF_SEQUENCE) { | ||||
| @@ -79,8 +81,7 @@ Status HybridModelExecutor::Execute(HybridModelExecutor::ExecuteArgs &args) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor, | |||||
| HybridModelExecutor::ExecuteArgs &args) { | |||||
| Status HybridModelExecutor::ExecuteGraphInternal(HybridModelExecutor::ExecuteArgs &args) { | |||||
| RECORD_MODEL_EXECUTION_EVENT(&context_, "[InitContext] Start"); | RECORD_MODEL_EXECUTION_EVENT(&context_, "[InitContext] Start"); | ||||
| GE_CHK_STATUS_RET_NOLOG(ResetExecutionContext(context_)); | GE_CHK_STATUS_RET_NOLOG(ResetExecutionContext(context_)); | ||||
| RECORD_MODEL_EXECUTION_EVENT(&context_, "[InitContext] End"); | RECORD_MODEL_EXECUTION_EVENT(&context_, "[InitContext] End"); | ||||
| @@ -94,7 +95,7 @@ Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor, | |||||
| GE_CHK_STATUS_RET_NOLOG(prof_mgr.ProfileStepInfo(index_id, model_id, 0, stream_, device_id)); | GE_CHK_STATUS_RET_NOLOG(prof_mgr.ProfileStepInfo(index_id, model_id, 0, stream_, device_id)); | ||||
| } | } | ||||
| HYBRID_CHK_STATUS_RET(executor.ExecuteAsync(args.inputs, args.input_desc, args.outputs), | |||||
| HYBRID_CHK_STATUS_RET(root_graph_executor_->ExecuteAsync(args.inputs, args.input_desc, args.outputs), | |||||
| "Failed to execute partitioned call."); | "Failed to execute partitioned call."); | ||||
| RECORD_MODEL_EXECUTION_EVENT(&context_, "[ExecuteAsync] End"); | RECORD_MODEL_EXECUTION_EVENT(&context_, "[ExecuteAsync] End"); | ||||
| @@ -103,7 +104,7 @@ Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor, | |||||
| } | } | ||||
| if (!model_->IsSingleOp()) { | if (!model_->IsSingleOp()) { | ||||
| Status ret = executor.Synchronize(); | |||||
| Status ret = root_graph_executor_->Synchronize(); | |||||
| if (ret != ge::SUCCESS) { | if (ret != ge::SUCCESS) { | ||||
| auto model_manager = ModelManager::GetInstance(); | auto model_manager = ModelManager::GetInstance(); | ||||
| GE_CHECK_NOTNULL(model_manager); | GE_CHECK_NOTNULL(model_manager); | ||||
| @@ -123,7 +124,7 @@ Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor, | |||||
| } | } | ||||
| args.outputs.clear(); | args.outputs.clear(); | ||||
| HYBRID_CHK_STATUS_RET(executor.GetOutputs(args.outputs, args.output_desc), "Failed to get outputs"); | |||||
| HYBRID_CHK_STATUS_RET(root_graph_executor_->GetOutputs(args.outputs, args.output_desc), "Failed to get outputs"); | |||||
| RECORD_MODEL_EXECUTION_EVENT(&context_, "[GetOutput] End"); | RECORD_MODEL_EXECUTION_EVENT(&context_, "[GetOutput] End"); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -48,7 +48,7 @@ class HybridModelExecutor { | |||||
| Status Execute(ExecuteArgs &args); | Status Execute(ExecuteArgs &args); | ||||
| private: | private: | ||||
| Status ExecuteGraphInternal(SubgraphExecutor &executor, ExecuteArgs &args); | |||||
| Status ExecuteGraphInternal(ExecuteArgs &args); | |||||
| Status Cleanup(); | Status Cleanup(); | ||||
| Status InitExecutionContext(); | Status InitExecutionContext(); | ||||
| static Status ResetExecutionContext(GraphExecutionContext &context); | static Status ResetExecutionContext(GraphExecutionContext &context); | ||||
| @@ -58,6 +58,7 @@ class HybridModelExecutor { | |||||
| uint32_t device_id_; | uint32_t device_id_; | ||||
| rtStream_t stream_; | rtStream_t stream_; | ||||
| GraphExecutionContext context_; | GraphExecutionContext context_; | ||||
| std::unique_ptr<SubgraphExecutor> root_graph_executor_; | |||||
| }; | }; | ||||
| } // namespace hybrid | } // namespace hybrid | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -177,6 +177,10 @@ struct NodeState { | |||||
| void SetTaskContext(std::shared_ptr<TaskContext> &task_context); | void SetTaskContext(std::shared_ptr<TaskContext> &task_context); | ||||
| std::shared_ptr<TaskContext> GetTaskContext(); | std::shared_ptr<TaskContext> GetTaskContext(); | ||||
| void SetSkipInferShape(bool skip_infershape) { skip_infershape_ = skip_infershape; } | |||||
| bool MaySkipShapeInference() const { return skip_infershape_; } | |||||
| private: | private: | ||||
| bool IsScheduleReady() const; | bool IsScheduleReady() const; | ||||
| void SetDataSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready); | void SetDataSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready); | ||||
| @@ -204,6 +208,7 @@ struct NodeState { | |||||
| int merge_index_ = -1; // Use for Execute (Reset after Executed). | int merge_index_ = -1; // Use for Execute (Reset after Executed). | ||||
| int switch_index_ = -1; // Use for Schedule (Reset after Prepared). | int switch_index_ = -1; // Use for Schedule (Reset after Prepared). | ||||
| int group_ = -1; | int group_ = -1; | ||||
| bool skip_infershape_ = false; | |||||
| }; | }; | ||||
| } // namespace hybrid | } // namespace hybrid | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -103,6 +103,14 @@ Status SubgraphExecutor::InitInputsForUnknownShape(const std::vector<TensorValue | |||||
| auto node_state = subgraph_context_->GetOrCreateNodeState(input_node); | auto node_state = subgraph_context_->GetOrCreateNodeState(input_node); | ||||
| GE_CHECK_NOTNULL(node_state); | GE_CHECK_NOTNULL(node_state); | ||||
| node_state->GetShapeInferenceState().UpdateInputShape(0, *tensor_desc); | node_state->GetShapeInferenceState().UpdateInputShape(0, *tensor_desc); | ||||
| auto op_desc = input_node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| auto output_desc = op_desc->MutableOutputDesc(kDataInputIndex); | |||||
| GE_CHECK_NOTNULL(output_desc); | |||||
| output_desc->SetShape(tensor_desc->GetShape()); | |||||
| output_desc->SetOriginShape(tensor_desc->GetOriginShape()); | |||||
| output_desc->SetDataType(tensor_desc->GetDataType()); | |||||
| node_state->SetSkipInferShape(true); | |||||
| } | } | ||||
| } | } | ||||
| @@ -41,6 +41,8 @@ class SubgraphExecutor { | |||||
| Status PartialExecuteAsync(int task_group); | Status PartialExecuteAsync(int task_group); | ||||
| void ReleaseContext() { subgraph_context_.reset(nullptr); } | |||||
| /** | /** | ||||
| * Execute subgraph async, output tensor address(not data) and output tensor descriptions are | * Execute subgraph async, output tensor address(not data) and output tensor descriptions are | ||||
| * valid after this method returned | * valid after this method returned | ||||
| @@ -68,8 +68,9 @@ Status ShapeInferenceEngine::InferShape(NodeState &node_state) { | |||||
| } | } | ||||
| // Do shape inference | // Do shape inference | ||||
| // Skipping infer shape of input node. | |||||
| GELOGD("[%s] Start to invoke InferShapeAndType", node_item.NodeName().c_str()); | GELOGD("[%s] Start to invoke InferShapeAndType", node_item.NodeName().c_str()); | ||||
| { | |||||
| if (!node_state.MaySkipShapeInference()) { | |||||
| RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[InferShapeAndType] Start"); | RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[InferShapeAndType] Start"); | ||||
| GE_CHK_STATUS_RET(ShapeRefiner::InferShapeAndTypeForRunning(node_item.node, true), | GE_CHK_STATUS_RET(ShapeRefiner::InferShapeAndTypeForRunning(node_item.node, true), | ||||
| "[Invoke][InferShapeAndType] for %s failed.", node_item.NodeName().c_str()); | "[Invoke][InferShapeAndType] for %s failed.", node_item.NodeName().c_str()); | ||||
| @@ -44,29 +44,63 @@ using std::vector; | |||||
| namespace ge { | namespace ge { | ||||
| namespace { | namespace { | ||||
| const size_t kDataOutputNum = 1; | const size_t kDataOutputNum = 1; | ||||
| const uint32_t kInputIndexOfData = 0; | |||||
| const uint32_t kOutputIndexOfData = 0; | const uint32_t kOutputIndexOfData = 0; | ||||
| constexpr char const *kAttrSupportDynamicShape = "support_dynamicshape"; | constexpr char const *kAttrSupportDynamicShape = "support_dynamicshape"; | ||||
| Status IfInferDepend(GeModelPtr &ge_model, bool &flag) { | |||||
| Status CheckHostMem(const std::vector<string> &dependencies, const NodePtr &node, bool &is_host_mem) { | |||||
| auto op_desc = node->GetOpDesc(); | |||||
| for (const auto &input_name : dependencies) { | |||||
| int input_index = op_desc->GetInputIndexByName(input_name); | |||||
| if (input_index < 0) { | |||||
| GELOGE(INTERNAL_ERROR, "[Get][InputIndex]failed, node:[%s] inputname: %s.", | |||||
| node->GetName().c_str(), input_name.c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "GetInputIndexByName failed, node:[%s] inputname: %s.", | |||||
| node->GetName().c_str(), input_name.c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| const auto &src_node = NodeUtils::GetInDataNodeByIndex(*node, input_index); | |||||
| GE_CHECK_NOTNULL(src_node); | |||||
| auto src_op_desc = src_node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(src_op_desc); | |||||
| if (src_op_desc->GetType() == DATA) { | |||||
| auto tensor = src_op_desc->MutableInputDesc(kInputIndexOfData); | |||||
| if (AttrUtils::HasAttr(tensor, ATTR_NAME_VALUE)) { | |||||
| GELOGD("Get hostmem from node %s, inputname: %s.", src_node->GetName().c_str(), input_name.c_str()); | |||||
| continue; | |||||
| } | |||||
| } | |||||
| is_host_mem = false; | |||||
| return SUCCESS; | |||||
| } | |||||
| is_host_mem = true; | |||||
| return SUCCESS; | |||||
| } | |||||
| Status CheckInferDepend(GeModelPtr &ge_model, bool &is_infer_depend, bool &is_host_mem) { | |||||
| auto comp_graph = GraphUtils::GetComputeGraph(ge_model->GetGraph()); | auto comp_graph = GraphUtils::GetComputeGraph(ge_model->GetGraph()); | ||||
| GE_CHECK_NOTNULL(comp_graph); | GE_CHECK_NOTNULL(comp_graph); | ||||
| for (const auto &node : comp_graph->GetAllNodes()) { | for (const auto &node : comp_graph->GetAllNodes()) { | ||||
| GE_CHECK_NOTNULL(node); | |||||
| auto op_desc = node->GetOpDesc(); | auto op_desc = node->GetOpDesc(); | ||||
| GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
| const auto &depends = op_desc->GetOpInferDepends(); | const auto &depends = op_desc->GetOpInferDepends(); | ||||
| bool support_dynamic_shape = false; | bool support_dynamic_shape = false; | ||||
| (void)AttrUtils::GetBool(op_desc, kAttrSupportDynamicShape, support_dynamic_shape); | (void)AttrUtils::GetBool(op_desc, kAttrSupportDynamicShape, support_dynamic_shape); | ||||
| if (!depends.empty() && support_dynamic_shape) { | if (!depends.empty() && support_dynamic_shape) { | ||||
| flag = true; | |||||
| return SUCCESS; | |||||
| is_infer_depend = true; | |||||
| return CheckHostMem(depends, node, is_host_mem); | |||||
| } | } | ||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status NeedHybridModel(GeModelPtr &ge_model, bool &flag) { | Status NeedHybridModel(GeModelPtr &ge_model, bool &flag) { | ||||
| bool infer_depend_flag = false; | |||||
| GE_CHK_STATUS_RET(IfInferDepend(ge_model, infer_depend_flag), "[Check][InferDepend] failed."); | |||||
| bool is_infer_depend = false; | |||||
| bool is_host_mem = false; | |||||
| GE_CHK_STATUS_RET(CheckInferDepend(ge_model, is_infer_depend, is_host_mem), "[Check][InferDepend] failed."); | |||||
| bool need_d2h_cpy = is_infer_depend && !is_host_mem; | |||||
| auto tasks = ge_model->GetModelTaskDefPtr()->task(); | auto tasks = ge_model->GetModelTaskDefPtr()->task(); | ||||
| int32_t kernel_task_num = 0; | int32_t kernel_task_num = 0; | ||||
| for (int i = 0; i < tasks.size(); ++i) { | for (int i = 0; i < tasks.size(); ++i) { | ||||
| @@ -76,7 +110,7 @@ Status NeedHybridModel(GeModelPtr &ge_model, bool &flag) { | |||||
| tasks[i].kernel_with_handle().context(); | tasks[i].kernel_with_handle().context(); | ||||
| auto kernel_type = static_cast<ccKernelType>(context.kernel_type()); | auto kernel_type = static_cast<ccKernelType>(context.kernel_type()); | ||||
| if (kernel_type == ccKernelType::TE) { | if (kernel_type == ccKernelType::TE) { | ||||
| if (infer_depend_flag) { | |||||
| if (need_d2h_cpy) { | |||||
| flag = true; | flag = true; | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -517,7 +551,8 @@ Status SingleOpModel::BuildOp(StreamResource &resource, SingleOp &single_op) { | |||||
| auto ge_model = model_helper_.GetGeModel(); | auto ge_model = model_helper_.GetGeModel(); | ||||
| GE_CHECK_NOTNULL(ge_model); | GE_CHECK_NOTNULL(ge_model); | ||||
| bool infer_depend_flag = false; | bool infer_depend_flag = false; | ||||
| GE_CHK_STATUS_RET(IfInferDepend(ge_model, infer_depend_flag), "[Check][InferDepend] failed."); | |||||
| bool is_host_mem = false; | |||||
| GE_CHK_STATUS_RET(CheckInferDepend(ge_model, infer_depend_flag, is_host_mem), "[Check][InferDepend] failed."); | |||||
| if (infer_depend_flag) { | if (infer_depend_flag) { | ||||
| // construct single_op, do single op with HybridModelExecutor | // construct single_op, do single op with HybridModelExecutor | ||||
| GELOGD("Init hybrid model params of single op, and will do execute with hybrid model executor."); | GELOGD("Init hybrid model params of single op, and will do execute with hybrid model executor."); | ||||
| @@ -87,21 +87,20 @@ TEST_F(UtestHybridModelAsyncExecutor, BuildDeviceTensor) { | |||||
| ASSERT_EQ(size, 100); | ASSERT_EQ(size, 100); | ||||
| } | } | ||||
| TEST_F(UtestHybridModelAsyncExecutor, Test_execute_internal) { | |||||
| TEST_F(UtestHybridModelAsyncExecutor, Test_execute) { | |||||
| ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | ||||
| GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph); | GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph); | ||||
| ge_root_model->SetModelName("test_name"); | ge_root_model->SetModelName("test_name"); | ||||
| HybridModel hybrid_model(ge_root_model); | HybridModel hybrid_model(ge_root_model); | ||||
| hybrid_model.root_graph_item_.reset(new GraphItem); | |||||
| HybridModelExecutor executor(&hybrid_model, 0, nullptr); | HybridModelExecutor executor(&hybrid_model, 0, nullptr); | ||||
| ASSERT_EQ(executor.Init(), SUCCESS); | ASSERT_EQ(executor.Init(), SUCCESS); | ||||
| auto &context = executor.context_; | auto &context = executor.context_; | ||||
| GraphItem graph_item; | |||||
| SubgraphExecutor subgraph_executor(&graph_item, &context); | |||||
| HybridModelExecutor::ExecuteArgs args; | HybridModelExecutor::ExecuteArgs args; | ||||
| std::pair<rtEvent_t, std::pair<rtCallback_t, void *>> eof_entry; | std::pair<rtEvent_t, std::pair<rtCallback_t, void *>> eof_entry; | ||||
| eof_entry.first = nullptr; | eof_entry.first = nullptr; | ||||
| context.callback_manager->callback_queue_.Push(eof_entry); | context.callback_manager->callback_queue_.Push(eof_entry); | ||||
| ASSERT_EQ(executor.ExecuteGraphInternal(subgraph_executor, args), SUCCESS); | |||||
| ASSERT_EQ(executor.Execute(args), SUCCESS); | |||||
| } | } | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -329,6 +329,7 @@ TEST_F(UtestGeHybrid, hybrid_model_executor) { | |||||
| ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("abc"); | ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("abc"); | ||||
| GeRootModelPtr root_model = MakeShared<ge::GeRootModel>(compute_graph); | GeRootModelPtr root_model = MakeShared<ge::GeRootModel>(compute_graph); | ||||
| HybridModel model(root_model); | HybridModel model(root_model); | ||||
| model.root_graph_item_.reset(new GraphItem); | |||||
| HybridModel *model_ptr = &model; | HybridModel *model_ptr = &model; | ||||
| uint32_t device_id = 0; | uint32_t device_id = 0; | ||||
| @@ -17,12 +17,11 @@ | |||||
| #include <gtest/gtest.h> | #include <gtest/gtest.h> | ||||
| #include <vector> | #include <vector> | ||||
| #define protected public | |||||
| #define private public | |||||
| #include "graph/load/model_manager/model_utils.h" | #include "graph/load/model_manager/model_utils.h" | ||||
| #include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
| #include "runtime/rt.h" | #include "runtime/rt.h" | ||||
| #define protected public | |||||
| #define private public | |||||
| #include "single_op/single_op_model.h" | #include "single_op/single_op_model.h" | ||||
| #include "single_op/task/tbe_task_builder.h" | #include "single_op/task/tbe_task_builder.h" | ||||
| #include "single_op/task/rts_kernel_task_builder.h" | #include "single_op/task/rts_kernel_task_builder.h" | ||||
| @@ -30,14 +29,19 @@ | |||||
| #include "framework/common/helper/model_helper.h" | #include "framework/common/helper/model_helper.h" | ||||
| #include "single_op/single_op.h" | #include "single_op/single_op.h" | ||||
| #include "single_op/stream_resource.h" | #include "single_op/stream_resource.h" | ||||
| #include "graph/passes/graph_builder_utils.h" | |||||
| #include "graph/op_desc_impl.h" | |||||
| #undef private | #undef private | ||||
| #undef protected | #undef protected | ||||
| #include "graph/passes/graph_builder_utils.h" | |||||
| using namespace std; | using namespace std; | ||||
| using namespace testing; | using namespace testing; | ||||
| using namespace ge; | using namespace ge; | ||||
| namespace { | |||||
| constexpr char const *kAttrSupportDynamicShape = "support_dynamicshape"; | |||||
| } // namespace | |||||
| class UtestSingleOpModel : public testing::Test { | class UtestSingleOpModel : public testing::Test { | ||||
| protected: | protected: | ||||
| void SetUp() {} | void SetUp() {} | ||||
| @@ -208,12 +212,22 @@ TEST_F(UtestSingleOpModel, test_build_dynamic_op) { | |||||
| model.model_helper_.model_ = ge::MakeShared<ge::GeModel>(); | model.model_helper_.model_ = ge::MakeShared<ge::GeModel>(); | ||||
| // make graph | // make graph | ||||
| auto compute_graph = make_shared<ComputeGraph>("graph"); | |||||
| auto data_op = make_shared<OpDesc>("Data", DATA); | |||||
| auto data_node = compute_graph->AddNode(data_op); | |||||
| ut::GraphBuilder builder = ut::GraphBuilder("graph"); | |||||
| auto data = builder.AddNode("Data", "Data", 1, 1); | |||||
| auto transdata = builder.AddNode("Transdata", "Transdata", 1, 1); | |||||
| auto netoutput = builder.AddNode("Netoutput", "NetOutput", 1, 0); | |||||
| builder.AddDataEdge(data, 0, transdata, 0); | |||||
| builder.AddDataEdge(transdata, 0, netoutput, 0); | |||||
| auto compute_graph = builder.GetGraph(); | |||||
| auto graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph); | auto graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph); | ||||
| model.model_helper_.model_->SetGraph(graph); | model.model_helper_.model_->SetGraph(graph); | ||||
| auto op_desc = transdata->GetOpDesc(); | |||||
| const vector<string> depend_names = { "Data" }; | |||||
| op_desc->SetOpInferDepends(depend_names); | |||||
| (void)AttrUtils::SetBool(op_desc, kAttrSupportDynamicShape, true); | |||||
| // set task_def | // set task_def | ||||
| auto model_task_def = make_shared<domi::ModelTaskDef>(); | auto model_task_def = make_shared<domi::ModelTaskDef>(); | ||||
| domi::TaskDef *task_def = model_task_def->add_task(); | domi::TaskDef *task_def = model_task_def->add_task(); | ||||
| @@ -227,6 +241,15 @@ TEST_F(UtestSingleOpModel, test_build_dynamic_op) { | |||||
| DynamicSingleOp dynamic_single_op(0, &stream_mu_, nullptr); | DynamicSingleOp dynamic_single_op(0, &stream_mu_, nullptr); | ||||
| StreamResource res((uintptr_t)1); | StreamResource res((uintptr_t)1); | ||||
| model.BuildDynamicOp(res, dynamic_single_op); | model.BuildDynamicOp(res, dynamic_single_op); | ||||
| op_desc->impl_->input_name_idx_["Data"] = 0; | |||||
| model.BuildDynamicOp(res, dynamic_single_op); | |||||
| auto tensor = std::make_shared<GeTensor>(); | |||||
| auto data_desc = data->GetOpDesc(); | |||||
| auto tensor_desc = data_desc->MutableInputDesc(0); | |||||
| AttrUtils::SetTensor(tensor_desc, "_value", tensor); | |||||
| model.BuildDynamicOp(res, dynamic_single_op); | |||||
| } | } | ||||
| TEST_F(UtestSingleOpModel, test_host_mem) { | TEST_F(UtestSingleOpModel, test_host_mem) { | ||||