@@ -145,8 +145,10 @@ Status InferShapePass::RePassLoopNode(const NodePtr &node) { | |||||
if (node->GetType() == MERGE || node->GetType() == REFMERGE) { | if (node->GetType() == MERGE || node->GetType() == REFMERGE) { | ||||
if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) { | if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) { | ||||
node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN); | |||||
return RePassNode({SWITCH, REFSWITCH}); // Re-Pass Switch | return RePassNode({SWITCH, REFSWITCH}); // Re-Pass Switch | ||||
} | } | ||||
return SUCCESS; | |||||
} | } | ||||
if (node->GetType() == SWITCH || node->GetType() == REFSWITCH) { | if (node->GetType() == SWITCH || node->GetType() == REFSWITCH) { | ||||
@@ -41,8 +41,9 @@ bool IsControlFlowV2Op(const std::string &op_type); | |||||
class OptionalMutexGuard { | class OptionalMutexGuard { | ||||
public: | public: | ||||
OptionalMutexGuard(std::mutex *mutex, const string &name); | |||||
OptionalMutexGuard(std::mutex *mutex, const std::string &name); | |||||
~OptionalMutexGuard(); | ~OptionalMutexGuard(); | ||||
private: | private: | ||||
std::mutex *mu_{nullptr}; | std::mutex *mu_{nullptr}; | ||||
std::string name_; | std::string name_; | ||||
@@ -816,6 +816,7 @@ set(PROFILING_MNG_TEST_FILES | |||||
set(HYBRID_TEST_FILES | set(HYBRID_TEST_FILES | ||||
"hybrid/ge_hybrid_unittest.cc" | "hybrid/ge_hybrid_unittest.cc" | ||||
"hybrid/known_node_executor_unittest.cc" | "hybrid/known_node_executor_unittest.cc" | ||||
"hybrid/executor/subgraph_executor_unittest.cc" | |||||
"hybrid/executor/worker/execution_engine_unittest.cc" | "hybrid/executor/worker/execution_engine_unittest.cc" | ||||
"hybrid/model/hybrid_model_builder_unittest.cc" | "hybrid/model/hybrid_model_builder_unittest.cc" | ||||
"hybrid/node_executor/rts/rts_node_task_unittest.cc" | "hybrid/node_executor/rts/rts_node_task_unittest.cc" | ||||
@@ -834,6 +835,8 @@ list(APPEND COMMON_SHARED_LIBRARIES | |||||
mmpa_stub | mmpa_stub | ||||
hccl_stub | hccl_stub | ||||
error_manager_stub | error_manager_stub | ||||
ascend_protobuf | |||||
json | |||||
) | ) | ||||
# build graph | # build graph | ||||
@@ -879,7 +882,7 @@ target_link_libraries(ge_ut_common PRIVATE | |||||
) | ) | ||||
# build common format | # build common format | ||||
add_library(ge_ut_common_format STATIC ${COMMON_SRC_FILES} ${COMMON_FORMAT_SRC_FILES} ${PROTO_HDRS}) | |||||
add_library(ge_ut_common_format STATIC ${COMMON_FORMAT_SRC_FILES} ${PROTO_HDRS}) | |||||
target_compile_definitions(ge_ut_common_format PRIVATE | target_compile_definitions(ge_ut_common_format PRIVATE | ||||
google=ascend_private | google=ascend_private | ||||
@@ -1056,7 +1059,6 @@ target_link_libraries(ge_single_op PRIVATE | |||||
# libge_mutiparts_utest | # libge_mutiparts_utest | ||||
add_executable(ut_libge_multiparts_utest | add_executable(ut_libge_multiparts_utest | ||||
${COMMON_TEST_FILES} | ${COMMON_TEST_FILES} | ||||
${COMMON_FORMAT_SRC_FILES} | |||||
${MULTI_PARTS_TEST_FILES} | ${MULTI_PARTS_TEST_FILES} | ||||
) | ) | ||||
@@ -1071,14 +1073,14 @@ target_compile_definitions(ut_libge_multiparts_utest PRIVATE | |||||
target_link_libraries(ut_libge_multiparts_utest | target_link_libraries(ut_libge_multiparts_utest | ||||
$<BUILD_INTERFACE:intf_pub> | $<BUILD_INTERFACE:intf_pub> | ||||
ge_build_common ge_load_common ge_execute_common ge_optimize_common ge_partition_common ge_prepare_common ge_single_op ge_ut_common | |||||
gtest gtest_main gmock gmock_main ascend_protobuf ${COMMON_SHARED_LIBRARIES} json -lrt -ldl -lgcov | |||||
ge_build_common ge_load_common ge_execute_common ge_optimize_common ge_partition_common ge_prepare_common | |||||
ge_single_op ge_ut_common_format ge_ut_common | |||||
gtest gtest_main gmock gmock_main ${COMMON_SHARED_LIBRARIES} -lrt -ldl -lgcov | |||||
) | ) | ||||
# libge_others_utest | # libge_others_utest | ||||
add_executable(ut_libge_others_utest | add_executable(ut_libge_others_utest | ||||
${COMMON_TEST_FILES} | ${COMMON_TEST_FILES} | ||||
${COMMON_FORMAT_SRC_FILES} | |||||
${PASS_TEST_FILES} | ${PASS_TEST_FILES} | ||||
${EXECUTE_TEST_FILES} | ${EXECUTE_TEST_FILES} | ||||
${OTHERS_TEST_FILES} | ${OTHERS_TEST_FILES} | ||||
@@ -1091,16 +1093,15 @@ target_compile_options(ut_libge_others_utest PRIVATE | |||||
target_link_libraries(ut_libge_others_utest | target_link_libraries(ut_libge_others_utest | ||||
$<BUILD_INTERFACE:intf_pub> | $<BUILD_INTERFACE:intf_pub> | ||||
ge_load_common ge_execute_common ge_ut_common | |||||
gtest gtest_main gmock gmock_main ascend_protobuf ${COMMON_SHARED_LIBRARIES} json -lrt -ldl -lgcov | |||||
ge_load_common ge_execute_common ge_ut_common ge_ut_common_format | |||||
gtest gtest_main gmock gmock_main ${COMMON_SHARED_LIBRARIES} -lrt -ldl -lgcov | |||||
) | ) | ||||
# libge_kernel_utest | # libge_kernel_utest | ||||
add_executable(ut_libge_kernel_utest | add_executable(ut_libge_kernel_utest | ||||
${COMMON_TEST_FILES} | |||||
${COMMON_FORMAT_SRC_FILES} | |||||
${KERNEL_TEST_FILES} | |||||
${KERNEL_SRC_FILES} | |||||
${COMMON_TEST_FILES} | |||||
${KERNEL_TEST_FILES} | |||||
${KERNEL_SRC_FILES} | |||||
) | ) | ||||
target_compile_options(ut_libge_kernel_utest PRIVATE | target_compile_options(ut_libge_kernel_utest PRIVATE | ||||
@@ -1110,8 +1111,8 @@ target_compile_options(ut_libge_kernel_utest PRIVATE | |||||
target_link_libraries(ut_libge_kernel_utest | target_link_libraries(ut_libge_kernel_utest | ||||
$<BUILD_INTERFACE:intf_pub> | $<BUILD_INTERFACE:intf_pub> | ||||
ge_load_common ge_ut_common | |||||
gtest gtest_main gmock gmock_main ascend_protobuf ${COMMON_SHARED_LIBRARIES} json -lrt -ldl -lgcov | |||||
ge_load_common ge_ut_common ge_ut_common_format | |||||
gtest gtest_main gmock gmock_main ${COMMON_SHARED_LIBRARIES} -lrt -ldl -lgcov | |||||
) | ) | ||||
# libge_distinct_load_utest | # libge_distinct_load_utest | ||||
@@ -1137,10 +1138,11 @@ target_compile_definitions(ut_libge_distinct_load_utest PRIVATE | |||||
) | ) | ||||
target_link_libraries(ut_libge_distinct_load_utest | target_link_libraries(ut_libge_distinct_load_utest | ||||
${COMMON_SHARED_LIBRARIES} | |||||
$<BUILD_INTERFACE:intf_pub> | $<BUILD_INTERFACE:intf_pub> | ||||
ge_execute_common ge_ut_common_format ge_load_common | |||||
ge_single_op ge_prepare_common | |||||
ge_optimize_common ge_build_common ge_partition_common ge_ut_common | |||||
gtest gtest_main gmock gmock_main ascend_protobuf json c_sec -lrt -ldl -lpthread -lgcov | |||||
-Wl,--whole-archive | |||||
ge_single_op | |||||
-Wl,--no-whole-archive | |||||
ge_execute_common ge_load_common | |||||
ge_prepare_common ge_optimize_common ge_build_common ge_partition_common ge_ut_common ge_ut_common_format | |||||
gtest gtest_main gmock gmock_main ${COMMON_SHARED_LIBRARIES} -lrt -ldl -lpthread -lgcov | |||||
) | ) |
@@ -115,7 +115,7 @@ TEST_F(UtestGeExecutor, load_data_from_file) { | |||||
string test_smap = "/tmp/" + std::to_string(getpid()) + "_maps"; | string test_smap = "/tmp/" + std::to_string(getpid()) + "_maps"; | ||||
string self_smap = "/proc/" + std::to_string(getpid()) + "/maps"; | string self_smap = "/proc/" + std::to_string(getpid()) + "/maps"; | ||||
string copy_smap = "cp " + self_smap + " " + test_smap; | |||||
string copy_smap = "cp -f " + self_smap + " " + test_smap; | |||||
EXPECT_EQ(system(copy_smap.c_str()), 0); | EXPECT_EQ(system(copy_smap.c_str()), 0); | ||||
ModelData model_data; | ModelData model_data; | ||||
@@ -91,8 +91,8 @@ TEST_F(UtestKernelExTaskInfo, success_kernel_ex_task_release) { | |||||
// test kernel_ex_task_Release | // test kernel_ex_task_Release | ||||
TEST_F(UtestKernelExTaskInfo, success_kernel_ex_task_info_copy) { | TEST_F(UtestKernelExTaskInfo, success_kernel_ex_task_info_copy) { | ||||
DavinciModel model(0, nullptr); | DavinciModel model(0, nullptr); | ||||
model.runtime_param_.mem_base = (uint8_t *)0x12345; | |||||
model.runtime_param_.mem_size = 100332000; | |||||
model.runtime_param_.mem_size = 10240; | |||||
model.runtime_param_.mem_base = new uint8_t[model.runtime_param_.mem_size]; | |||||
rtStream_t stream = nullptr; | rtStream_t stream = nullptr; | ||||
rtStreamCreate(&stream, 0); | rtStreamCreate(&stream, 0); | ||||
@@ -108,19 +108,20 @@ TEST_F(UtestKernelExTaskInfo, success_kernel_ex_task_info_copy) { | |||||
EXPECT_EQ(kernel_ex_task_info.Init(task_def, &model), FAILED); // workspace empty. | EXPECT_EQ(kernel_ex_task_info.Init(task_def, &model), FAILED); // workspace empty. | ||||
model.op_list_[0]->SetWorkspace({100331008}); // offset | |||||
model.op_list_[0]->SetWorkspace({1008}); // offset | |||||
model.op_list_[0]->SetWorkspaceBytes({0}); // length | model.op_list_[0]->SetWorkspaceBytes({0}); // length | ||||
EXPECT_EQ(kernel_ex_task_info.Init(task_def, &model), FAILED); // workspace addr is null. | EXPECT_EQ(kernel_ex_task_info.Init(task_def, &model), FAILED); // workspace addr is null. | ||||
model.op_list_[0]->SetWorkspace({100331008}); // offset | |||||
model.op_list_[0]->SetWorkspace({1208}); // offset | |||||
model.op_list_[0]->SetWorkspaceBytes({10}); // length | model.op_list_[0]->SetWorkspaceBytes({10}); // length | ||||
EXPECT_EQ(kernel_ex_task_info.Init(task_def, &model), FAILED); // workspace addr is small. | EXPECT_EQ(kernel_ex_task_info.Init(task_def, &model), FAILED); // workspace addr is small. | ||||
model.op_list_[0]->SetWorkspace({100331008}); // offset | |||||
model.op_list_[0]->SetWorkspace({1308}); // offset | |||||
model.op_list_[0]->SetWorkspaceBytes({150}); // length | model.op_list_[0]->SetWorkspaceBytes({150}); // length | ||||
EXPECT_EQ(kernel_ex_task_info.Init(task_def, &model), SUCCESS); | EXPECT_EQ(kernel_ex_task_info.Init(task_def, &model), SUCCESS); | ||||
task_def.clear_kernel_ex(); | task_def.clear_kernel_ex(); | ||||
delete [] model.runtime_param_.mem_base; | |||||
model.runtime_param_.mem_base = nullptr; | model.runtime_param_.mem_base = nullptr; | ||||
} | } | ||||
@@ -418,6 +418,6 @@ TEST_F(UtestModelManagerModelManager, test_data_input_tensor) { | |||||
vector<InputTensorInfo> inputs; | vector<InputTensorInfo> inputs; | ||||
inputs.emplace_back(input_tensor); | inputs.emplace_back(input_tensor); | ||||
auto ret = mm.DataInputTensor(model_id,inputs); | auto ret = mm.DataInputTensor(model_id,inputs); | ||||
EXPECT_EQ(UNSUPPORTED, ret); | |||||
EXPECT_EQ(PARAM_INVALID, ret); // HybridDavinciModel::impl_ is null. | |||||
} | } | ||||
} // namespace ge | } // namespace ge |
@@ -0,0 +1,249 @@ | |||||
/** | |||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include <gtest/gtest.h> | |||||
#include <gmock/gmock.h> | |||||
#include <vector> | |||||
#define private public | |||||
#define protected public | |||||
#include "hybrid/executor/subgraph_executor.h" | |||||
#include "hybrid/node_executor/node_executor.h" | |||||
#include "hybrid/node_executor/rts/rts_node_executor.h" | |||||
#include "hybrid/node_executor/ge_local/ge_local_node_executor.h" | |||||
#include "hybrid/model/hybrid_model_builder.h" | |||||
#include "graph/utils/graph_utils.h" | |||||
using namespace std; | |||||
using namespace testing; | |||||
namespace ge { | |||||
using namespace hybrid; | |||||
class UtestSubgraphExecutor : public testing::Test { | |||||
protected: | |||||
void SetUp() { | |||||
NodeExecutorManager::GetInstance().engine_mapping_.clear(); | |||||
auto &engine_mapping = NodeExecutorManager::GetInstance().engine_mapping_; | |||||
engine_mapping.emplace("DNN_VM_RTS_OP_STORE", NodeExecutorManager::ExecutorType::RTS); | |||||
engine_mapping.emplace("DNN_VM_GE_LOCAL_OP_STORE", NodeExecutorManager::ExecutorType::GE_LOCAL); | |||||
NodeExecutorManager::GetInstance().executors_.clear(); | |||||
auto &task_executor = NodeExecutorManager::GetInstance().executors_; | |||||
task_executor.emplace(NodeExecutorManager::ExecutorType::RTS, std::unique_ptr<NodeExecutor>(new RtsNodeExecutor())); | |||||
task_executor.emplace(NodeExecutorManager::ExecutorType::GE_LOCAL, std::unique_ptr<NodeExecutor>(new GeLocalNodeExecutor())); | |||||
} | |||||
void TearDown() { | |||||
NodeExecutorManager::GetInstance().engine_mapping_.clear(); | |||||
NodeExecutorManager::GetInstance().executors_.clear(); | |||||
} | |||||
}; | |||||
static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num) { | |||||
OpDescPtr op_desc = std::make_shared<OpDesc>(name, type); | |||||
op_desc->SetStreamId(0); | |||||
static int32_t index = 0; | |||||
op_desc->SetId(index++); | |||||
GeTensorDesc tensor(GeShape(), FORMAT_ND, DT_INT64); | |||||
TensorUtils::SetSize(tensor, 64); | |||||
vector<int64_t> input_offset; | |||||
for (int i = 0; i < in_num; i++) { | |||||
op_desc->AddInputDesc(tensor); | |||||
input_offset.emplace_back(index * 64 + i * 64); | |||||
} | |||||
op_desc->SetInputOffset(input_offset); | |||||
vector<int64_t> output_offset; | |||||
for (int i = 0; i < out_num; i++) { | |||||
op_desc->AddOutputDesc(tensor); | |||||
output_offset.emplace_back(index * 64 + in_num * 64 + i * 64); | |||||
} | |||||
op_desc->SetOutputOffset(output_offset); | |||||
op_desc->SetWorkspace({}); | |||||
op_desc->SetWorkspaceBytes({}); | |||||
op_desc->SetOpKernelLibName("DNN_VM_RTS_OP_STORE"); | |||||
return graph.AddNode(op_desc); | |||||
} | |||||
static void CreateSimpleCondGraph(ComputeGraph &graph) { | |||||
/******************************************************************************* | |||||
* | | |||||
* Merge | |||||
* / \. | |||||
* / \. | |||||
* / \. | |||||
* Add Sub | |||||
* | \ / | | |||||
* | \ _ / | | |||||
* | / \ | | |||||
* | / \ | | |||||
* Switch Switch | |||||
* | \ / | | |||||
* | \ / | | |||||
* | \ / | | |||||
* | \ / | | |||||
* | Less | | |||||
* | / \ | | |||||
* | / \ | | |||||
* Data Data | |||||
******************************************************************************/ | |||||
const auto data0 = CreateNode(graph, "data", DATA, 1, 1); | |||||
const auto data1 = CreateNode(graph, "data1", DATA, 1, 1); | |||||
data0->GetOpDesc()->SetOpKernelLibName("DNN_VM_GE_LOCAL_OP_STORE"); | |||||
data1->GetOpDesc()->SetOpKernelLibName("DNN_VM_GE_LOCAL_OP_STORE"); | |||||
const auto const0 = CreateNode(graph, "const", CONSTANT, 0, 1); | |||||
const auto const1 = CreateNode(graph, "const1", CONSTANT, 0, 1); | |||||
const0->GetOpDesc()->SetOpKernelLibName("DNN_VM_GE_LOCAL_OP_STORE"); | |||||
const1->GetOpDesc()->SetOpKernelLibName("DNN_VM_GE_LOCAL_OP_STORE"); | |||||
{ | |||||
uint64_t const_value = 0; | |||||
const auto op_desc = const0->GetOpDesc(); | |||||
auto weight = make_shared<GeTensor>(op_desc->GetOutputDesc(0), (uint8_t *)&const_value, sizeof(uint64_t)); | |||||
AttrUtils::SetTensor(op_desc, ATTR_NAME_WEIGHTS, weight); | |||||
} | |||||
{ | |||||
uint64_t const_value = 1; | |||||
const auto op_desc = const1->GetOpDesc(); | |||||
auto weight = make_shared<GeTensor>(op_desc->GetOutputDesc(0), (uint8_t *)&const_value, sizeof(uint64_t)); | |||||
AttrUtils::SetTensor(op_desc, ATTR_NAME_WEIGHTS, weight); | |||||
} | |||||
const auto less1 = CreateNode(graph, "less", ENTER, 2, 1); | |||||
const auto active1 = CreateNode(graph, "active1", STREAMACTIVE, 0, 0); | |||||
const auto switch_t = CreateNode(graph, "switch_t", STREAMSWITCH, 2, 0); | |||||
const auto switch_f = CreateNode(graph, "switch_f", STREAMSWITCH, 2, 0); | |||||
AttrUtils::SetInt(switch_t->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, 1); | |||||
AttrUtils::SetInt(switch_f->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, 1); | |||||
const auto add1 = CreateNode(graph, "add", ENTER, 2, 1); | |||||
const auto sub1 = CreateNode(graph, "sub", ENTER, 2, 1); | |||||
const auto merge1 = CreateNode(graph, "merge", STREAMMERGE, 2, 2); | |||||
const auto active2 = CreateNode(graph, "active2", STREAMACTIVE, 0, 0); | |||||
const auto active3 = CreateNode(graph, "active3", STREAMACTIVE, 0, 0); | |||||
const auto output1 = CreateNode(graph, "net_output", NETOUTPUT, 1, 1); | |||||
output1->GetOpDesc()->SetOpKernelLibName("DNN_VM_GE_LOCAL_OP_STORE"); | |||||
GraphUtils::AddEdge(data0->GetOutDataAnchor(0), less1->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(data1->GetOutDataAnchor(0), less1->GetInDataAnchor(1)); | |||||
GraphUtils::AddEdge(less1->GetOutDataAnchor(0), switch_t->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(const0->GetOutDataAnchor(0), switch_t->GetInDataAnchor(1)); | |||||
GraphUtils::AddEdge(less1->GetOutDataAnchor(0), switch_f->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(const1->GetOutDataAnchor(0), switch_f->GetInDataAnchor(1)); | |||||
GraphUtils::AddEdge(less1->GetOutControlAnchor(), active1->GetInControlAnchor()); | |||||
GraphUtils::AddEdge(active1->GetOutControlAnchor(), switch_t->GetInControlAnchor()); | |||||
GraphUtils::AddEdge(active1->GetOutControlAnchor(), switch_f->GetInControlAnchor()); | |||||
GraphUtils::AddEdge(data0->GetOutDataAnchor(0), add1->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(data1->GetOutDataAnchor(0), add1->GetInDataAnchor(1)); | |||||
GraphUtils::AddEdge(add1->GetOutDataAnchor(0), merge1->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(switch_t->GetOutControlAnchor(), add1->GetInControlAnchor()); | |||||
GraphUtils::AddEdge(add1->GetOutControlAnchor(), active2->GetInControlAnchor()); | |||||
GraphUtils::AddEdge(active2->GetOutControlAnchor(), merge1->GetInControlAnchor()); | |||||
GraphUtils::AddEdge(data0->GetOutDataAnchor(0), sub1->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(data1->GetOutDataAnchor(0), sub1->GetInDataAnchor(1)); | |||||
GraphUtils::AddEdge(sub1->GetOutDataAnchor(0), merge1->GetInDataAnchor(1)); | |||||
GraphUtils::AddEdge(switch_f->GetOutControlAnchor(), sub1->GetInControlAnchor()); | |||||
GraphUtils::AddEdge(sub1->GetOutControlAnchor(), active3->GetInControlAnchor()); | |||||
GraphUtils::AddEdge(active3->GetOutControlAnchor(), merge1->GetInControlAnchor()); | |||||
GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), output1->GetInDataAnchor(0)); | |||||
} | |||||
TEST_F(UtestSubgraphExecutor, simple_schedule_tasks) { | |||||
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||||
const auto data0 = CreateNode(*graph, "data", DATA, 1, 1); | |||||
const auto output0 = CreateNode(*graph, "net_output", NETOUTPUT, 1, 1); | |||||
GraphUtils::AddEdge(data0->GetOutDataAnchor(0), output0->GetInDataAnchor(0)); | |||||
data0->GetOpDesc()->SetOpKernelLibName("DNN_VM_GE_LOCAL_OP_STORE"); | |||||
output0->GetOpDesc()->SetOpKernelLibName("DNN_VM_GE_LOCAL_OP_STORE"); | |||||
GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph); | |||||
ge_root_model->SetModelName("test_name"); | |||||
GeModelPtr ge_sub_model = make_shared<GeModel>(); | |||||
ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model); | |||||
HybridModel hybrid_model(ge_root_model); | |||||
HybridModelBuilder hybrid_model_builder(hybrid_model); | |||||
ASSERT_EQ(hybrid_model_builder.Build(), SUCCESS); | |||||
GraphExecutionContext graph_context; | |||||
graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager()); | |||||
graph_context.model = &hybrid_model; | |||||
uint64_t value_0 = 110; | |||||
TensorValue in_tensor0(&value_0, sizeof(value_0)); | |||||
const std::vector<TensorValue> inputs{ in_tensor0 }; | |||||
uint64_t value_1 = 123; | |||||
TensorValue out_tensor0(&value_1, sizeof(value_1)); | |||||
const std::vector<TensorValue> outputs{ out_tensor0 }; | |||||
auto input_desc = output0->GetOpDesc()->GetInputDescPtr(0); | |||||
const std::vector<ConstGeTensorDescPtr> input_descs{ input_desc }; | |||||
SubgraphExecutor executor(hybrid_model.GetRootGraphItem(), &graph_context); | |||||
ASSERT_EQ(executor.ExecuteAsync(inputs, input_descs, outputs), SUCCESS); | |||||
} | |||||
TEST_F(UtestSubgraphExecutor, cond_graph_schedule_tasks) { | |||||
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||||
CreateSimpleCondGraph(*graph); | |||||
GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph); | |||||
ge_root_model->SetModelName("test_name"); | |||||
GeModelPtr ge_sub_model = make_shared<GeModel>(); | |||||
ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model); | |||||
Buffer weights_buffer(1024, 0x76); | |||||
ge_sub_model->SetWeight(weights_buffer); | |||||
ge_sub_model->SetGraph(GraphUtils::CreateGraphFromComputeGraph(graph)); | |||||
HybridModel hybrid_model(ge_root_model); | |||||
HybridModelBuilder hybrid_model_builder(hybrid_model); | |||||
ASSERT_EQ(hybrid_model_builder.Build(), SUCCESS); | |||||
GraphExecutionContext graph_context; | |||||
graph_context.model = &hybrid_model; | |||||
graph_context.allocator = NpuMemoryAllocator::GetAllocator(0); | |||||
graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager()); | |||||
ASSERT_EQ(graph_context.callback_manager->Init(), SUCCESS); | |||||
uint64_t value_0 = 110; | |||||
TensorValue in_tensor0(&value_0, sizeof(value_0)); | |||||
uint64_t value_1 = 110; | |||||
TensorValue in_tensor1(&value_1, sizeof(value_1)); | |||||
const std::vector<TensorValue> inputs{ in_tensor0, in_tensor1 }; | |||||
uint64_t value_2 = 123; | |||||
TensorValue out_tensor0(&value_2, sizeof(value_2)); | |||||
const std::vector<TensorValue> outputs{ out_tensor0 }; | |||||
GeTensorDescPtr tensor_desc = make_shared<GeTensorDesc>(GeShape(), FORMAT_ND, DT_INT64); | |||||
TensorUtils::SetSize(*tensor_desc, 64); | |||||
const std::vector<ConstGeTensorDescPtr> input_desc{ tensor_desc, tensor_desc }; | |||||
SubgraphExecutor executor(hybrid_model.GetRootGraphItem(), &graph_context); | |||||
ASSERT_EQ(executor.ExecuteAsync(inputs, input_desc, outputs), SUCCESS); | |||||
ASSERT_EQ(graph_context.callback_manager->Destroy(), SUCCESS); | |||||
} | |||||
} // namespace ge |
@@ -26,6 +26,7 @@ | |||||
#include "hybrid/executor/hybrid_execution_context.h" | #include "hybrid/executor/hybrid_execution_context.h" | ||||
#include "hybrid/executor/hybrid_model_executor.h" | #include "hybrid/executor/hybrid_model_executor.h" | ||||
#include "hybrid/executor/worker/execution_engine.h" | #include "hybrid/executor/worker/execution_engine.h" | ||||
#include "hybrid/executor/subgraph_executor.h" | |||||
#undef private | #undef private | ||||
#undef protected | #undef protected | ||||
@@ -75,6 +76,10 @@ TEST_F(UtestExecutionEngine, ExecuteAsync_without_kernel_task) { | |||||
node_item->output_start = 0; | node_item->output_start = 0; | ||||
GraphExecutionContext execution_context; | GraphExecutionContext execution_context; | ||||
GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph); | |||||
HybridModel hybrid_model(ge_root_model); | |||||
hybrid_model.root_graph_item_ = std::unique_ptr<GraphItem>(new(std::nothrow)GraphItem()); | |||||
execution_context.model = &hybrid_model; | |||||
execution_context.profiling_level = 1; | execution_context.profiling_level = 1; | ||||
SubgraphContext subgraph_context(nullptr, &execution_context); | SubgraphContext subgraph_context(nullptr, &execution_context); | ||||
@@ -85,7 +90,11 @@ TEST_F(UtestExecutionEngine, ExecuteAsync_without_kernel_task) { | |||||
ExecutionEngine execution_engine; | ExecutionEngine execution_engine; | ||||
ASSERT_TRUE(node_state.GetTaskContext() != nullptr); | ASSERT_TRUE(node_state.GetTaskContext() != nullptr); | ||||
EXPECT_EQ(execution_engine.ExecuteAsync(node_state, node_state.GetTaskContext(), execution_context), INTERNAL_ERROR); | |||||
std::function<void()> callback; | |||||
SubgraphExecutor executor(hybrid_model.GetRootGraphItem(), &execution_context); | |||||
executor.InitCallback(&node_state, callback); | |||||
EXPECT_EQ(execution_engine.ExecuteAsync(node_state, node_state.GetTaskContext(), execution_context, callback), INTERNAL_ERROR); | |||||
} | } | ||||
TEST_F(UtestExecutionEngine, ExecuteAsync_without_callback_and_kernel_task) { | TEST_F(UtestExecutionEngine, ExecuteAsync_without_callback_and_kernel_task) { | ||||
@@ -105,6 +114,7 @@ TEST_F(UtestExecutionEngine, ExecuteAsync_without_callback_and_kernel_task) { | |||||
GraphExecutionContext execution_context; | GraphExecutionContext execution_context; | ||||
GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph); | GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph); | ||||
HybridModel hybrid_model(ge_root_model); | HybridModel hybrid_model(ge_root_model); | ||||
hybrid_model.root_graph_item_ = std::unique_ptr<GraphItem>(new(std::nothrow)GraphItem()); | |||||
execution_context.model = &hybrid_model; | execution_context.model = &hybrid_model; | ||||
SubgraphContext subgraph_context(nullptr, &execution_context); | SubgraphContext subgraph_context(nullptr, &execution_context); | ||||
@@ -115,5 +125,9 @@ TEST_F(UtestExecutionEngine, ExecuteAsync_without_callback_and_kernel_task) { | |||||
ExecutionEngine execution_engine; | ExecutionEngine execution_engine; | ||||
ASSERT_TRUE(node_state.GetTaskContext() != nullptr); | ASSERT_TRUE(node_state.GetTaskContext() != nullptr); | ||||
EXPECT_EQ(execution_engine.ExecuteAsync(node_state, node_state.GetTaskContext(), execution_context), INTERNAL_ERROR); | |||||
std::function<void()> callback; | |||||
SubgraphExecutor executor(hybrid_model.GetRootGraphItem(), &execution_context); | |||||
executor.InitCallback(&node_state, callback); | |||||
EXPECT_EQ(execution_engine.ExecuteAsync(node_state, node_state.GetTaskContext(), execution_context, callback), INTERNAL_ERROR); | |||||
} | } |