| @@ -145,8 +145,10 @@ Status InferShapePass::RePassLoopNode(const NodePtr &node) { | |||||
| if (node->GetType() == MERGE || node->GetType() == REFMERGE) { | if (node->GetType() == MERGE || node->GetType() == REFMERGE) { | ||||
| if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) { | if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) { | ||||
| node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN); | |||||
| return RePassNode({SWITCH, REFSWITCH}); // Re-Pass Switch | return RePassNode({SWITCH, REFSWITCH}); // Re-Pass Switch | ||||
| } | } | ||||
| return SUCCESS; | |||||
| } | } | ||||
| if (node->GetType() == SWITCH || node->GetType() == REFSWITCH) { | if (node->GetType() == SWITCH || node->GetType() == REFSWITCH) { | ||||
| @@ -41,8 +41,9 @@ bool IsControlFlowV2Op(const std::string &op_type); | |||||
| class OptionalMutexGuard { | class OptionalMutexGuard { | ||||
| public: | public: | ||||
| OptionalMutexGuard(std::mutex *mutex, const string &name); | |||||
| OptionalMutexGuard(std::mutex *mutex, const std::string &name); | |||||
| ~OptionalMutexGuard(); | ~OptionalMutexGuard(); | ||||
| private: | private: | ||||
| std::mutex *mu_{nullptr}; | std::mutex *mu_{nullptr}; | ||||
| std::string name_; | std::string name_; | ||||
| @@ -816,6 +816,7 @@ set(PROFILING_MNG_TEST_FILES | |||||
| set(HYBRID_TEST_FILES | set(HYBRID_TEST_FILES | ||||
| "hybrid/ge_hybrid_unittest.cc" | "hybrid/ge_hybrid_unittest.cc" | ||||
| "hybrid/known_node_executor_unittest.cc" | "hybrid/known_node_executor_unittest.cc" | ||||
| "hybrid/executor/subgraph_executor_unittest.cc" | |||||
| "hybrid/executor/worker/execution_engine_unittest.cc" | "hybrid/executor/worker/execution_engine_unittest.cc" | ||||
| "hybrid/model/hybrid_model_builder_unittest.cc" | "hybrid/model/hybrid_model_builder_unittest.cc" | ||||
| "hybrid/node_executor/rts/rts_node_task_unittest.cc" | "hybrid/node_executor/rts/rts_node_task_unittest.cc" | ||||
| @@ -834,6 +835,8 @@ list(APPEND COMMON_SHARED_LIBRARIES | |||||
| mmpa_stub | mmpa_stub | ||||
| hccl_stub | hccl_stub | ||||
| error_manager_stub | error_manager_stub | ||||
| ascend_protobuf | |||||
| json | |||||
| ) | ) | ||||
| # build graph | # build graph | ||||
| @@ -879,7 +882,7 @@ target_link_libraries(ge_ut_common PRIVATE | |||||
| ) | ) | ||||
| # build common format | # build common format | ||||
| add_library(ge_ut_common_format STATIC ${COMMON_SRC_FILES} ${COMMON_FORMAT_SRC_FILES} ${PROTO_HDRS}) | |||||
| add_library(ge_ut_common_format STATIC ${COMMON_FORMAT_SRC_FILES} ${PROTO_HDRS}) | |||||
| target_compile_definitions(ge_ut_common_format PRIVATE | target_compile_definitions(ge_ut_common_format PRIVATE | ||||
| google=ascend_private | google=ascend_private | ||||
| @@ -1056,7 +1059,6 @@ target_link_libraries(ge_single_op PRIVATE | |||||
| # libge_mutiparts_utest | # libge_mutiparts_utest | ||||
| add_executable(ut_libge_multiparts_utest | add_executable(ut_libge_multiparts_utest | ||||
| ${COMMON_TEST_FILES} | ${COMMON_TEST_FILES} | ||||
| ${COMMON_FORMAT_SRC_FILES} | |||||
| ${MULTI_PARTS_TEST_FILES} | ${MULTI_PARTS_TEST_FILES} | ||||
| ) | ) | ||||
| @@ -1071,14 +1073,14 @@ target_compile_definitions(ut_libge_multiparts_utest PRIVATE | |||||
| target_link_libraries(ut_libge_multiparts_utest | target_link_libraries(ut_libge_multiparts_utest | ||||
| $<BUILD_INTERFACE:intf_pub> | $<BUILD_INTERFACE:intf_pub> | ||||
| ge_build_common ge_load_common ge_execute_common ge_optimize_common ge_partition_common ge_prepare_common ge_single_op ge_ut_common | |||||
| gtest gtest_main gmock gmock_main ascend_protobuf ${COMMON_SHARED_LIBRARIES} json -lrt -ldl -lgcov | |||||
| ge_build_common ge_load_common ge_execute_common ge_optimize_common ge_partition_common ge_prepare_common | |||||
| ge_single_op ge_ut_common_format ge_ut_common | |||||
| gtest gtest_main gmock gmock_main ${COMMON_SHARED_LIBRARIES} -lrt -ldl -lgcov | |||||
| ) | ) | ||||
| # libge_others_utest | # libge_others_utest | ||||
| add_executable(ut_libge_others_utest | add_executable(ut_libge_others_utest | ||||
| ${COMMON_TEST_FILES} | ${COMMON_TEST_FILES} | ||||
| ${COMMON_FORMAT_SRC_FILES} | |||||
| ${PASS_TEST_FILES} | ${PASS_TEST_FILES} | ||||
| ${EXECUTE_TEST_FILES} | ${EXECUTE_TEST_FILES} | ||||
| ${OTHERS_TEST_FILES} | ${OTHERS_TEST_FILES} | ||||
| @@ -1091,16 +1093,15 @@ target_compile_options(ut_libge_others_utest PRIVATE | |||||
| target_link_libraries(ut_libge_others_utest | target_link_libraries(ut_libge_others_utest | ||||
| $<BUILD_INTERFACE:intf_pub> | $<BUILD_INTERFACE:intf_pub> | ||||
| ge_load_common ge_execute_common ge_ut_common | |||||
| gtest gtest_main gmock gmock_main ascend_protobuf ${COMMON_SHARED_LIBRARIES} json -lrt -ldl -lgcov | |||||
| ge_load_common ge_execute_common ge_ut_common ge_ut_common_format | |||||
| gtest gtest_main gmock gmock_main ${COMMON_SHARED_LIBRARIES} -lrt -ldl -lgcov | |||||
| ) | ) | ||||
| # libge_kernel_utest | # libge_kernel_utest | ||||
| add_executable(ut_libge_kernel_utest | add_executable(ut_libge_kernel_utest | ||||
| ${COMMON_TEST_FILES} | |||||
| ${COMMON_FORMAT_SRC_FILES} | |||||
| ${KERNEL_TEST_FILES} | |||||
| ${KERNEL_SRC_FILES} | |||||
| ${COMMON_TEST_FILES} | |||||
| ${KERNEL_TEST_FILES} | |||||
| ${KERNEL_SRC_FILES} | |||||
| ) | ) | ||||
| target_compile_options(ut_libge_kernel_utest PRIVATE | target_compile_options(ut_libge_kernel_utest PRIVATE | ||||
| @@ -1110,8 +1111,8 @@ target_compile_options(ut_libge_kernel_utest PRIVATE | |||||
| target_link_libraries(ut_libge_kernel_utest | target_link_libraries(ut_libge_kernel_utest | ||||
| $<BUILD_INTERFACE:intf_pub> | $<BUILD_INTERFACE:intf_pub> | ||||
| ge_load_common ge_ut_common | |||||
| gtest gtest_main gmock gmock_main ascend_protobuf ${COMMON_SHARED_LIBRARIES} json -lrt -ldl -lgcov | |||||
| ge_load_common ge_ut_common ge_ut_common_format | |||||
| gtest gtest_main gmock gmock_main ${COMMON_SHARED_LIBRARIES} -lrt -ldl -lgcov | |||||
| ) | ) | ||||
| # libge_distinct_load_utest | # libge_distinct_load_utest | ||||
| @@ -1137,10 +1138,11 @@ target_compile_definitions(ut_libge_distinct_load_utest PRIVATE | |||||
| ) | ) | ||||
| target_link_libraries(ut_libge_distinct_load_utest | target_link_libraries(ut_libge_distinct_load_utest | ||||
| ${COMMON_SHARED_LIBRARIES} | |||||
| $<BUILD_INTERFACE:intf_pub> | $<BUILD_INTERFACE:intf_pub> | ||||
| ge_execute_common ge_ut_common_format ge_load_common | |||||
| ge_single_op ge_prepare_common | |||||
| ge_optimize_common ge_build_common ge_partition_common ge_ut_common | |||||
| gtest gtest_main gmock gmock_main ascend_protobuf json c_sec -lrt -ldl -lpthread -lgcov | |||||
| -Wl,--whole-archive | |||||
| ge_single_op | |||||
| -Wl,--no-whole-archive | |||||
| ge_execute_common ge_load_common | |||||
| ge_prepare_common ge_optimize_common ge_build_common ge_partition_common ge_ut_common ge_ut_common_format | |||||
| gtest gtest_main gmock gmock_main ${COMMON_SHARED_LIBRARIES} -lrt -ldl -lpthread -lgcov | |||||
| ) | ) | ||||
| @@ -115,7 +115,7 @@ TEST_F(UtestGeExecutor, load_data_from_file) { | |||||
| string test_smap = "/tmp/" + std::to_string(getpid()) + "_maps"; | string test_smap = "/tmp/" + std::to_string(getpid()) + "_maps"; | ||||
| string self_smap = "/proc/" + std::to_string(getpid()) + "/maps"; | string self_smap = "/proc/" + std::to_string(getpid()) + "/maps"; | ||||
| string copy_smap = "cp " + self_smap + " " + test_smap; | |||||
| string copy_smap = "cp -f " + self_smap + " " + test_smap; | |||||
| EXPECT_EQ(system(copy_smap.c_str()), 0); | EXPECT_EQ(system(copy_smap.c_str()), 0); | ||||
| ModelData model_data; | ModelData model_data; | ||||
| @@ -91,8 +91,8 @@ TEST_F(UtestKernelExTaskInfo, success_kernel_ex_task_release) { | |||||
| // test kernel_ex_task_Release | // test kernel_ex_task_Release | ||||
| TEST_F(UtestKernelExTaskInfo, success_kernel_ex_task_info_copy) { | TEST_F(UtestKernelExTaskInfo, success_kernel_ex_task_info_copy) { | ||||
| DavinciModel model(0, nullptr); | DavinciModel model(0, nullptr); | ||||
| model.runtime_param_.mem_base = (uint8_t *)0x12345; | |||||
| model.runtime_param_.mem_size = 100332000; | |||||
| model.runtime_param_.mem_size = 10240; | |||||
| model.runtime_param_.mem_base = new uint8_t[model.runtime_param_.mem_size]; | |||||
| rtStream_t stream = nullptr; | rtStream_t stream = nullptr; | ||||
| rtStreamCreate(&stream, 0); | rtStreamCreate(&stream, 0); | ||||
| @@ -108,19 +108,20 @@ TEST_F(UtestKernelExTaskInfo, success_kernel_ex_task_info_copy) { | |||||
| EXPECT_EQ(kernel_ex_task_info.Init(task_def, &model), FAILED); // workspace empty. | EXPECT_EQ(kernel_ex_task_info.Init(task_def, &model), FAILED); // workspace empty. | ||||
| model.op_list_[0]->SetWorkspace({100331008}); // offset | |||||
| model.op_list_[0]->SetWorkspace({1008}); // offset | |||||
| model.op_list_[0]->SetWorkspaceBytes({0}); // length | model.op_list_[0]->SetWorkspaceBytes({0}); // length | ||||
| EXPECT_EQ(kernel_ex_task_info.Init(task_def, &model), FAILED); // workspace addr is null. | EXPECT_EQ(kernel_ex_task_info.Init(task_def, &model), FAILED); // workspace addr is null. | ||||
| model.op_list_[0]->SetWorkspace({100331008}); // offset | |||||
| model.op_list_[0]->SetWorkspace({1208}); // offset | |||||
| model.op_list_[0]->SetWorkspaceBytes({10}); // length | model.op_list_[0]->SetWorkspaceBytes({10}); // length | ||||
| EXPECT_EQ(kernel_ex_task_info.Init(task_def, &model), FAILED); // workspace addr is small. | EXPECT_EQ(kernel_ex_task_info.Init(task_def, &model), FAILED); // workspace addr is small. | ||||
| model.op_list_[0]->SetWorkspace({100331008}); // offset | |||||
| model.op_list_[0]->SetWorkspace({1308}); // offset | |||||
| model.op_list_[0]->SetWorkspaceBytes({150}); // length | model.op_list_[0]->SetWorkspaceBytes({150}); // length | ||||
| EXPECT_EQ(kernel_ex_task_info.Init(task_def, &model), SUCCESS); | EXPECT_EQ(kernel_ex_task_info.Init(task_def, &model), SUCCESS); | ||||
| task_def.clear_kernel_ex(); | task_def.clear_kernel_ex(); | ||||
| delete [] model.runtime_param_.mem_base; | |||||
| model.runtime_param_.mem_base = nullptr; | model.runtime_param_.mem_base = nullptr; | ||||
| } | } | ||||
| @@ -418,6 +418,6 @@ TEST_F(UtestModelManagerModelManager, test_data_input_tensor) { | |||||
| vector<InputTensorInfo> inputs; | vector<InputTensorInfo> inputs; | ||||
| inputs.emplace_back(input_tensor); | inputs.emplace_back(input_tensor); | ||||
| auto ret = mm.DataInputTensor(model_id,inputs); | auto ret = mm.DataInputTensor(model_id,inputs); | ||||
| EXPECT_EQ(UNSUPPORTED, ret); | |||||
| EXPECT_EQ(PARAM_INVALID, ret); // HybridDavinciModel::impl_ is null. | |||||
| } | } | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -0,0 +1,249 @@ | |||||
| /** | |||||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <gtest/gtest.h> | |||||
| #include <gmock/gmock.h> | |||||
| #include <vector> | |||||
| #define private public | |||||
| #define protected public | |||||
| #include "hybrid/executor/subgraph_executor.h" | |||||
| #include "hybrid/node_executor/node_executor.h" | |||||
| #include "hybrid/node_executor/rts/rts_node_executor.h" | |||||
| #include "hybrid/node_executor/ge_local/ge_local_node_executor.h" | |||||
| #include "hybrid/model/hybrid_model_builder.h" | |||||
| #include "graph/utils/graph_utils.h" | |||||
| using namespace std; | |||||
| using namespace testing; | |||||
| namespace ge { | |||||
| using namespace hybrid; | |||||
| class UtestSubgraphExecutor : public testing::Test { | |||||
| protected: | |||||
| void SetUp() { | |||||
| NodeExecutorManager::GetInstance().engine_mapping_.clear(); | |||||
| auto &engine_mapping = NodeExecutorManager::GetInstance().engine_mapping_; | |||||
| engine_mapping.emplace("DNN_VM_RTS_OP_STORE", NodeExecutorManager::ExecutorType::RTS); | |||||
| engine_mapping.emplace("DNN_VM_GE_LOCAL_OP_STORE", NodeExecutorManager::ExecutorType::GE_LOCAL); | |||||
| NodeExecutorManager::GetInstance().executors_.clear(); | |||||
| auto &task_executor = NodeExecutorManager::GetInstance().executors_; | |||||
| task_executor.emplace(NodeExecutorManager::ExecutorType::RTS, std::unique_ptr<NodeExecutor>(new RtsNodeExecutor())); | |||||
| task_executor.emplace(NodeExecutorManager::ExecutorType::GE_LOCAL, std::unique_ptr<NodeExecutor>(new GeLocalNodeExecutor())); | |||||
| } | |||||
| void TearDown() { | |||||
| NodeExecutorManager::GetInstance().engine_mapping_.clear(); | |||||
| NodeExecutorManager::GetInstance().executors_.clear(); | |||||
| } | |||||
| }; | |||||
| static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num) { | |||||
| OpDescPtr op_desc = std::make_shared<OpDesc>(name, type); | |||||
| op_desc->SetStreamId(0); | |||||
| static int32_t index = 0; | |||||
| op_desc->SetId(index++); | |||||
| GeTensorDesc tensor(GeShape(), FORMAT_ND, DT_INT64); | |||||
| TensorUtils::SetSize(tensor, 64); | |||||
| vector<int64_t> input_offset; | |||||
| for (int i = 0; i < in_num; i++) { | |||||
| op_desc->AddInputDesc(tensor); | |||||
| input_offset.emplace_back(index * 64 + i * 64); | |||||
| } | |||||
| op_desc->SetInputOffset(input_offset); | |||||
| vector<int64_t> output_offset; | |||||
| for (int i = 0; i < out_num; i++) { | |||||
| op_desc->AddOutputDesc(tensor); | |||||
| output_offset.emplace_back(index * 64 + in_num * 64 + i * 64); | |||||
| } | |||||
| op_desc->SetOutputOffset(output_offset); | |||||
| op_desc->SetWorkspace({}); | |||||
| op_desc->SetWorkspaceBytes({}); | |||||
| op_desc->SetOpKernelLibName("DNN_VM_RTS_OP_STORE"); | |||||
| return graph.AddNode(op_desc); | |||||
| } | |||||
| static void CreateSimpleCondGraph(ComputeGraph &graph) { | |||||
| /******************************************************************************* | |||||
| * | | |||||
| * Merge | |||||
| * / \. | |||||
| * / \. | |||||
| * / \. | |||||
| * Add Sub | |||||
| * | \ / | | |||||
| * | \ _ / | | |||||
| * | / \ | | |||||
| * | / \ | | |||||
| * Switch Switch | |||||
| * | \ / | | |||||
| * | \ / | | |||||
| * | \ / | | |||||
| * | \ / | | |||||
| * | Less | | |||||
| * | / \ | | |||||
| * | / \ | | |||||
| * Data Data | |||||
| ******************************************************************************/ | |||||
| const auto data0 = CreateNode(graph, "data", DATA, 1, 1); | |||||
| const auto data1 = CreateNode(graph, "data1", DATA, 1, 1); | |||||
| data0->GetOpDesc()->SetOpKernelLibName("DNN_VM_GE_LOCAL_OP_STORE"); | |||||
| data1->GetOpDesc()->SetOpKernelLibName("DNN_VM_GE_LOCAL_OP_STORE"); | |||||
| const auto const0 = CreateNode(graph, "const", CONSTANT, 0, 1); | |||||
| const auto const1 = CreateNode(graph, "const1", CONSTANT, 0, 1); | |||||
| const0->GetOpDesc()->SetOpKernelLibName("DNN_VM_GE_LOCAL_OP_STORE"); | |||||
| const1->GetOpDesc()->SetOpKernelLibName("DNN_VM_GE_LOCAL_OP_STORE"); | |||||
| { | |||||
| uint64_t const_value = 0; | |||||
| const auto op_desc = const0->GetOpDesc(); | |||||
| auto weight = make_shared<GeTensor>(op_desc->GetOutputDesc(0), (uint8_t *)&const_value, sizeof(uint64_t)); | |||||
| AttrUtils::SetTensor(op_desc, ATTR_NAME_WEIGHTS, weight); | |||||
| } | |||||
| { | |||||
| uint64_t const_value = 1; | |||||
| const auto op_desc = const1->GetOpDesc(); | |||||
| auto weight = make_shared<GeTensor>(op_desc->GetOutputDesc(0), (uint8_t *)&const_value, sizeof(uint64_t)); | |||||
| AttrUtils::SetTensor(op_desc, ATTR_NAME_WEIGHTS, weight); | |||||
| } | |||||
| const auto less1 = CreateNode(graph, "less", ENTER, 2, 1); | |||||
| const auto active1 = CreateNode(graph, "active1", STREAMACTIVE, 0, 0); | |||||
| const auto switch_t = CreateNode(graph, "switch_t", STREAMSWITCH, 2, 0); | |||||
| const auto switch_f = CreateNode(graph, "switch_f", STREAMSWITCH, 2, 0); | |||||
| AttrUtils::SetInt(switch_t->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, 1); | |||||
| AttrUtils::SetInt(switch_f->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, 1); | |||||
| const auto add1 = CreateNode(graph, "add", ENTER, 2, 1); | |||||
| const auto sub1 = CreateNode(graph, "sub", ENTER, 2, 1); | |||||
| const auto merge1 = CreateNode(graph, "merge", STREAMMERGE, 2, 2); | |||||
| const auto active2 = CreateNode(graph, "active2", STREAMACTIVE, 0, 0); | |||||
| const auto active3 = CreateNode(graph, "active3", STREAMACTIVE, 0, 0); | |||||
| const auto output1 = CreateNode(graph, "net_output", NETOUTPUT, 1, 1); | |||||
| output1->GetOpDesc()->SetOpKernelLibName("DNN_VM_GE_LOCAL_OP_STORE"); | |||||
| GraphUtils::AddEdge(data0->GetOutDataAnchor(0), less1->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(data1->GetOutDataAnchor(0), less1->GetInDataAnchor(1)); | |||||
| GraphUtils::AddEdge(less1->GetOutDataAnchor(0), switch_t->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(const0->GetOutDataAnchor(0), switch_t->GetInDataAnchor(1)); | |||||
| GraphUtils::AddEdge(less1->GetOutDataAnchor(0), switch_f->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(const1->GetOutDataAnchor(0), switch_f->GetInDataAnchor(1)); | |||||
| GraphUtils::AddEdge(less1->GetOutControlAnchor(), active1->GetInControlAnchor()); | |||||
| GraphUtils::AddEdge(active1->GetOutControlAnchor(), switch_t->GetInControlAnchor()); | |||||
| GraphUtils::AddEdge(active1->GetOutControlAnchor(), switch_f->GetInControlAnchor()); | |||||
| GraphUtils::AddEdge(data0->GetOutDataAnchor(0), add1->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(data1->GetOutDataAnchor(0), add1->GetInDataAnchor(1)); | |||||
| GraphUtils::AddEdge(add1->GetOutDataAnchor(0), merge1->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(switch_t->GetOutControlAnchor(), add1->GetInControlAnchor()); | |||||
| GraphUtils::AddEdge(add1->GetOutControlAnchor(), active2->GetInControlAnchor()); | |||||
| GraphUtils::AddEdge(active2->GetOutControlAnchor(), merge1->GetInControlAnchor()); | |||||
| GraphUtils::AddEdge(data0->GetOutDataAnchor(0), sub1->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(data1->GetOutDataAnchor(0), sub1->GetInDataAnchor(1)); | |||||
| GraphUtils::AddEdge(sub1->GetOutDataAnchor(0), merge1->GetInDataAnchor(1)); | |||||
| GraphUtils::AddEdge(switch_f->GetOutControlAnchor(), sub1->GetInControlAnchor()); | |||||
| GraphUtils::AddEdge(sub1->GetOutControlAnchor(), active3->GetInControlAnchor()); | |||||
| GraphUtils::AddEdge(active3->GetOutControlAnchor(), merge1->GetInControlAnchor()); | |||||
| GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), output1->GetInDataAnchor(0)); | |||||
| } | |||||
| TEST_F(UtestSubgraphExecutor, simple_schedule_tasks) { | |||||
| ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||||
| const auto data0 = CreateNode(*graph, "data", DATA, 1, 1); | |||||
| const auto output0 = CreateNode(*graph, "net_output", NETOUTPUT, 1, 1); | |||||
| GraphUtils::AddEdge(data0->GetOutDataAnchor(0), output0->GetInDataAnchor(0)); | |||||
| data0->GetOpDesc()->SetOpKernelLibName("DNN_VM_GE_LOCAL_OP_STORE"); | |||||
| output0->GetOpDesc()->SetOpKernelLibName("DNN_VM_GE_LOCAL_OP_STORE"); | |||||
| GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph); | |||||
| ge_root_model->SetModelName("test_name"); | |||||
| GeModelPtr ge_sub_model = make_shared<GeModel>(); | |||||
| ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model); | |||||
| HybridModel hybrid_model(ge_root_model); | |||||
| HybridModelBuilder hybrid_model_builder(hybrid_model); | |||||
| ASSERT_EQ(hybrid_model_builder.Build(), SUCCESS); | |||||
| GraphExecutionContext graph_context; | |||||
| graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager()); | |||||
| graph_context.model = &hybrid_model; | |||||
| uint64_t value_0 = 110; | |||||
| TensorValue in_tensor0(&value_0, sizeof(value_0)); | |||||
| const std::vector<TensorValue> inputs{ in_tensor0 }; | |||||
| uint64_t value_1 = 123; | |||||
| TensorValue out_tensor0(&value_1, sizeof(value_1)); | |||||
| const std::vector<TensorValue> outputs{ out_tensor0 }; | |||||
| auto input_desc = output0->GetOpDesc()->GetInputDescPtr(0); | |||||
| const std::vector<ConstGeTensorDescPtr> input_descs{ input_desc }; | |||||
| SubgraphExecutor executor(hybrid_model.GetRootGraphItem(), &graph_context); | |||||
| ASSERT_EQ(executor.ExecuteAsync(inputs, input_descs, outputs), SUCCESS); | |||||
| } | |||||
| TEST_F(UtestSubgraphExecutor, cond_graph_schedule_tasks) { | |||||
| ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||||
| CreateSimpleCondGraph(*graph); | |||||
| GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph); | |||||
| ge_root_model->SetModelName("test_name"); | |||||
| GeModelPtr ge_sub_model = make_shared<GeModel>(); | |||||
| ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model); | |||||
| Buffer weights_buffer(1024, 0x76); | |||||
| ge_sub_model->SetWeight(weights_buffer); | |||||
| ge_sub_model->SetGraph(GraphUtils::CreateGraphFromComputeGraph(graph)); | |||||
| HybridModel hybrid_model(ge_root_model); | |||||
| HybridModelBuilder hybrid_model_builder(hybrid_model); | |||||
| ASSERT_EQ(hybrid_model_builder.Build(), SUCCESS); | |||||
| GraphExecutionContext graph_context; | |||||
| graph_context.model = &hybrid_model; | |||||
| graph_context.allocator = NpuMemoryAllocator::GetAllocator(0); | |||||
| graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager()); | |||||
| ASSERT_EQ(graph_context.callback_manager->Init(), SUCCESS); | |||||
| uint64_t value_0 = 110; | |||||
| TensorValue in_tensor0(&value_0, sizeof(value_0)); | |||||
| uint64_t value_1 = 110; | |||||
| TensorValue in_tensor1(&value_1, sizeof(value_1)); | |||||
| const std::vector<TensorValue> inputs{ in_tensor0, in_tensor1 }; | |||||
| uint64_t value_2 = 123; | |||||
| TensorValue out_tensor0(&value_2, sizeof(value_2)); | |||||
| const std::vector<TensorValue> outputs{ out_tensor0 }; | |||||
| GeTensorDescPtr tensor_desc = make_shared<GeTensorDesc>(GeShape(), FORMAT_ND, DT_INT64); | |||||
| TensorUtils::SetSize(*tensor_desc, 64); | |||||
| const std::vector<ConstGeTensorDescPtr> input_desc{ tensor_desc, tensor_desc }; | |||||
| SubgraphExecutor executor(hybrid_model.GetRootGraphItem(), &graph_context); | |||||
| ASSERT_EQ(executor.ExecuteAsync(inputs, input_desc, outputs), SUCCESS); | |||||
| ASSERT_EQ(graph_context.callback_manager->Destroy(), SUCCESS); | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -26,6 +26,7 @@ | |||||
| #include "hybrid/executor/hybrid_execution_context.h" | #include "hybrid/executor/hybrid_execution_context.h" | ||||
| #include "hybrid/executor/hybrid_model_executor.h" | #include "hybrid/executor/hybrid_model_executor.h" | ||||
| #include "hybrid/executor/worker/execution_engine.h" | #include "hybrid/executor/worker/execution_engine.h" | ||||
| #include "hybrid/executor/subgraph_executor.h" | |||||
| #undef private | #undef private | ||||
| #undef protected | #undef protected | ||||
| @@ -75,6 +76,10 @@ TEST_F(UtestExecutionEngine, ExecuteAsync_without_kernel_task) { | |||||
| node_item->output_start = 0; | node_item->output_start = 0; | ||||
| GraphExecutionContext execution_context; | GraphExecutionContext execution_context; | ||||
| GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph); | |||||
| HybridModel hybrid_model(ge_root_model); | |||||
| hybrid_model.root_graph_item_ = std::unique_ptr<GraphItem>(new(std::nothrow)GraphItem()); | |||||
| execution_context.model = &hybrid_model; | |||||
| execution_context.profiling_level = 1; | execution_context.profiling_level = 1; | ||||
| SubgraphContext subgraph_context(nullptr, &execution_context); | SubgraphContext subgraph_context(nullptr, &execution_context); | ||||
| @@ -85,7 +90,11 @@ TEST_F(UtestExecutionEngine, ExecuteAsync_without_kernel_task) { | |||||
| ExecutionEngine execution_engine; | ExecutionEngine execution_engine; | ||||
| ASSERT_TRUE(node_state.GetTaskContext() != nullptr); | ASSERT_TRUE(node_state.GetTaskContext() != nullptr); | ||||
| EXPECT_EQ(execution_engine.ExecuteAsync(node_state, node_state.GetTaskContext(), execution_context), INTERNAL_ERROR); | |||||
| std::function<void()> callback; | |||||
| SubgraphExecutor executor(hybrid_model.GetRootGraphItem(), &execution_context); | |||||
| executor.InitCallback(&node_state, callback); | |||||
| EXPECT_EQ(execution_engine.ExecuteAsync(node_state, node_state.GetTaskContext(), execution_context, callback), INTERNAL_ERROR); | |||||
| } | } | ||||
| TEST_F(UtestExecutionEngine, ExecuteAsync_without_callback_and_kernel_task) { | TEST_F(UtestExecutionEngine, ExecuteAsync_without_callback_and_kernel_task) { | ||||
| @@ -105,6 +114,7 @@ TEST_F(UtestExecutionEngine, ExecuteAsync_without_callback_and_kernel_task) { | |||||
| GraphExecutionContext execution_context; | GraphExecutionContext execution_context; | ||||
| GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph); | GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph); | ||||
| HybridModel hybrid_model(ge_root_model); | HybridModel hybrid_model(ge_root_model); | ||||
| hybrid_model.root_graph_item_ = std::unique_ptr<GraphItem>(new(std::nothrow)GraphItem()); | |||||
| execution_context.model = &hybrid_model; | execution_context.model = &hybrid_model; | ||||
| SubgraphContext subgraph_context(nullptr, &execution_context); | SubgraphContext subgraph_context(nullptr, &execution_context); | ||||
| @@ -115,5 +125,9 @@ TEST_F(UtestExecutionEngine, ExecuteAsync_without_callback_and_kernel_task) { | |||||
| ExecutionEngine execution_engine; | ExecutionEngine execution_engine; | ||||
| ASSERT_TRUE(node_state.GetTaskContext() != nullptr); | ASSERT_TRUE(node_state.GetTaskContext() != nullptr); | ||||
| EXPECT_EQ(execution_engine.ExecuteAsync(node_state, node_state.GetTaskContext(), execution_context), INTERNAL_ERROR); | |||||
| std::function<void()> callback; | |||||
| SubgraphExecutor executor(hybrid_model.GetRootGraphItem(), &execution_context); | |||||
| executor.InitCallback(&node_state, callback); | |||||
| EXPECT_EQ(execution_engine.ExecuteAsync(node_state, node_state.GetTaskContext(), execution_context, callback), INTERNAL_ERROR); | |||||
| } | } | ||||