| @@ -6,7 +6,6 @@ if (NOT ENABLE_D AND NOT ENABLE_ACL AND NOT ENABLE_MS_TESTCASES) | |||||
| add_subdirectory(offline) | add_subdirectory(offline) | ||||
| elseif (ENABLE_D) | elseif (ENABLE_D) | ||||
| add_subdirectory(common) | add_subdirectory(common) | ||||
| add_subdirectory(ge_runtime) | |||||
| endif () | endif () | ||||
| set(GRAPHENGINE_PROTO_LIST | set(GRAPHENGINE_PROTO_LIST | ||||
| @@ -1,78 +0,0 @@ | |||||
# Build script for the ge_runtime shared library (libge_runtime.so):
# declares the source list, compile/link settings and the install rule.
| ############ libge_runtime.so ############ | |||||
# Translation units compiled into ge_runtime: the model runner, the runtime
# model, the output helper and one file per task kind under task/.
| set(GE_SRC_LIST | |||||
| "model_runner.cc" | |||||
| "runtime_model.cc" | |||||
| "output.cc" | |||||
| "task/aicpu_task.cc" | |||||
| "task/cce_task.cc" | |||||
| "task/tbe_task.cc" | |||||
| "task/event_record_task.cc" | |||||
| "task/event_wait_task.cc" | |||||
| "task/stream_active_task.cc" | |||||
| "task/stream_switch_task.cc" | |||||
| "task/hccl_task.cc" | |||||
| "task/memcpy_async_task.cc" | |||||
| "task/profiler_task.cc" | |||||
| "task/label_goto_task.cc" | |||||
| "task/label_set_task.cc" | |||||
| "task/label_switch_task.cc" | |||||
| ) | |||||
| add_library(ge_runtime SHARED ${GE_SRC_LIST}) | |||||
# Warnings are errors (-Werror) except for deprecated-declaration warnings,
# which are silenced; code is built optimized (-O2).
| target_compile_options(ge_runtime PRIVATE | |||||
| -Werror | |||||
| -O2 | |||||
| -Wno-deprecated-declarations | |||||
| -fno-common | |||||
| ) | |||||
| target_compile_definitions(ge_runtime PRIVATE | |||||
| PROTOBUF_INLINE_NOT_IN_HEADERS=0 | |||||
| LOG_CPP | |||||
| ) | |||||
# Header search path: this directory, the GE and metadef include trees and
# the build directory (which holds the generated proto/ge headers).
| target_include_directories(ge_runtime PRIVATE | |||||
| ${CMAKE_CURRENT_LIST_DIR} | |||||
| ${GE_CODE_DIR} | |||||
| ${GE_CODE_DIR}/ge | |||||
| ${GE_CODE_DIR}/inc | |||||
| ${GE_CODE_DIR}/inc/graph | |||||
| ${GE_CODE_DIR}/inc/external | |||||
| ${GE_CODE_DIR}/inc/framework | |||||
| ${GE_CODE_DIR}/inc/framework/common | |||||
| ${GE_CODE_DIR}/inc/framework/ge_runtime | |||||
| ${GE_CODE_DIR}/inc/cce | |||||
| ${GE_CODE_DIR}/third_party/fwkacllib/inc | |||||
| ${METADEF_DIR} | |||||
| ${METADEF_DIR}/inc | |||||
| ${METADEF_DIR}/inc/external/graph | |||||
| ${METADEF_DIR}/inc/external | |||||
| ${METADEF_DIR}/inc/graph | |||||
| ${CMAKE_BINARY_DIR} | |||||
| ${CMAKE_BINARY_DIR}/proto/ge | |||||
| ) | |||||
# -Bsymbolic: references to global symbols defined in this library are bound
# to the library's own definitions at link time.
| target_link_options(ge_runtime PRIVATE | |||||
| -Wl,-Bsymbolic | |||||
| ) | |||||
# The order below matters: slog, runtime, c_sec and graph sit between
# --no-as-needed and --as-needed, so they are recorded as dependencies even
# if the linker sees no direct symbol reference to them.
| target_link_libraries(ge_runtime PRIVATE | |||||
| $<BUILD_INTERFACE:intf_pub> | |||||
| -Wl,--no-as-needed | |||||
| slog | |||||
| runtime | |||||
| c_sec | |||||
| graph | |||||
| -Wl,--as-needed | |||||
| -lrt | |||||
| -ldl | |||||
| ) | |||||
| ############ install ############ | |||||
# NOTE(review): INSTALL_BASE_DIR is set to an empty string and is not
# referenced again in this file — confirm whether anything else reads it.
| set(INSTALL_BASE_DIR "") | |||||
| set(INSTALL_LIBRARY_DIR lib) | |||||
# OPTIONAL: installation does not fail if the library was not built.
| install(TARGETS ge_runtime OPTIONAL | |||||
| LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR} | |||||
| ) | |||||
| @@ -1,62 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_GE_RUNTIME_MODEL_CONTEXT_H_ | |||||
| #define GE_GE_RUNTIME_MODEL_CONTEXT_H_ | |||||
| #include <vector> | |||||
| #include "runtime/rt_model.h" | |||||
| namespace ge { | |||||
| namespace model_runner { | |||||
| class ModelContext { | |||||
| public: | |||||
| ModelContext(uint32_t device_id, uint64_t session_id, int32_t priority, rtModel_t rt_model_handle, | |||||
| rtStream_t rt_model_stream, const std::vector<rtStream_t> &stream_list, | |||||
| const std::vector<rtLabel_t> &label_list, const std::vector<rtEvent_t> &event_list) | |||||
| : device_id_(device_id), | |||||
| session_id_(session_id), | |||||
| priority_(priority), | |||||
| rt_model_handle_(rt_model_handle), | |||||
| rt_model_stream_(rt_model_stream), | |||||
| stream_list_(stream_list), | |||||
| label_list_(label_list), | |||||
| event_list_(event_list) {} | |||||
| ~ModelContext() {} | |||||
| uint64_t device_id() const { return device_id_; } | |||||
| uint64_t session_id() const { return session_id_; } | |||||
| int32_t priority() const { return priority_; } | |||||
| const rtModel_t &rt_model_handle() const { return rt_model_handle_; } | |||||
| const rtStream_t &rt_model_stream() const { return rt_model_stream_; } | |||||
| const std::vector<rtStream_t> &stream_list() const { return stream_list_; } | |||||
| const std::vector<rtLabel_t> &label_list() const { return label_list_; } | |||||
| const std::vector<rtEvent_t> &event_list() const { return event_list_; } | |||||
| private: | |||||
| uint32_t device_id_; | |||||
| uint64_t session_id_; | |||||
| int32_t priority_; | |||||
| rtModel_t rt_model_handle_; | |||||
| rtStream_t rt_model_stream_; | |||||
| std::vector<rtStream_t> stream_list_; | |||||
| std::vector<rtLabel_t> label_list_; | |||||
| std::vector<rtEvent_t> event_list_; | |||||
| }; | |||||
| } // namespace model_runner | |||||
| } // namespace ge | |||||
| #endif // GE_GE_RUNTIME_MODEL_CONTEXT_H_ | |||||
| @@ -1,173 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ge_runtime/model_runner.h" | |||||
| #include "./runtime_model.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "common/ge_inner_error_codes.h" | |||||
| #include "common/ge/ge_util.h" | |||||
| #include "ge_runtime/davinci_model.h" | |||||
| #include "graph/op_desc.h" | |||||
| namespace ge { | |||||
| namespace model_runner { | |||||
| using RuntimeModelPtr = std::shared_ptr<RuntimeModel>; | |||||
| using DavinciModelPtr = std::shared_ptr<DavinciModel>; | |||||
| ModelRunner &ModelRunner::Instance() { | |||||
| static ModelRunner instance; // Guaranteed to be destroyed. | |||||
| return instance; | |||||
| } | |||||
| bool ModelRunner::LoadDavinciModel(uint32_t device_id, uint64_t session_id, uint32_t model_id, | |||||
| std::shared_ptr<DavinciModel> davinci_model, | |||||
| std::shared_ptr<ModelListener> listener) { | |||||
| std::shared_ptr<RuntimeModel> model = MakeShared<RuntimeModel>(); | |||||
| if (model == nullptr) { | |||||
| return false; | |||||
| } | |||||
| bool status = model->Load(device_id, session_id, davinci_model); | |||||
| if (!status) { | |||||
| return false; | |||||
| } | |||||
| runtime_models_[model_id] = model; | |||||
| return true; | |||||
| } | |||||
| bool ModelRunner::DistributeTask(uint32_t model_id) { | |||||
| auto model_iter = runtime_models_.find(model_id); | |||||
| if (model_iter == runtime_models_.end()) { | |||||
| GELOGE(PARAM_INVALID, "Model id %u not found.", model_id); | |||||
| return false; | |||||
| } | |||||
| return model_iter->second->DistributeTask(); | |||||
| } | |||||
| bool ModelRunner::LoadModelComplete(uint32_t model_id) { | |||||
| auto model_iter = runtime_models_.find(model_id); | |||||
| if (model_iter == runtime_models_.end()) { | |||||
| GELOGE(PARAM_INVALID, "Model id %u not found.", model_id); | |||||
| return false; | |||||
| } | |||||
| return model_iter->second->LoadComplete(); | |||||
| } | |||||
| const std::vector<uint32_t> &ModelRunner::GetTaskIdList(uint32_t model_id) const { | |||||
| auto model_iter = runtime_models_.find(model_id); | |||||
| if (model_iter == runtime_models_.end()) { | |||||
| GELOGE(PARAM_INVALID, "Model id %u not found.", model_id); | |||||
| static const std::vector<uint32_t> empty_ret; | |||||
| return empty_ret; | |||||
| } | |||||
| return model_iter->second->GetTaskIdList(); | |||||
| } | |||||
| const std::vector<uint32_t> &ModelRunner::GetStreamIdList(uint32_t model_id) const { | |||||
| auto model_iter = runtime_models_.find(model_id); | |||||
| if (model_iter == runtime_models_.end()) { | |||||
| GELOGE(PARAM_INVALID, "Model id %u not found.", model_id); | |||||
| static const std::vector<uint32_t> empty_ret; | |||||
| return empty_ret; | |||||
| } | |||||
| return model_iter->second->GetStreamIdList(); | |||||
| } | |||||
| const std::map<std::string, std::shared_ptr<RuntimeInfo>> &ModelRunner::GetRuntimeInfoMap(uint32_t model_id) const { | |||||
| auto model_iter = runtime_models_.find(model_id); | |||||
| if (model_iter == runtime_models_.end()) { | |||||
| GELOGW("Model id %u not found.", model_id); | |||||
| static const std::map<std::string, std::shared_ptr<RuntimeInfo>> empty_ret; | |||||
| return empty_ret; | |||||
| } | |||||
| return model_iter->second->GetRuntimeInfoMap(); | |||||
| } | |||||
| void *ModelRunner::GetModelHandle(uint32_t model_id) const { | |||||
| auto model_iter = runtime_models_.find(model_id); | |||||
| if (model_iter == runtime_models_.end()) { | |||||
| GELOGW("Model id %u not found.", model_id); | |||||
| return nullptr; | |||||
| } | |||||
| return model_iter->second->GetModelHandle(); | |||||
| } | |||||
| bool ModelRunner::UnloadModel(uint32_t model_id) { | |||||
| auto iter = runtime_models_.find(model_id); | |||||
| if (iter != runtime_models_.end()) { | |||||
| (void)runtime_models_.erase(iter); | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| bool ModelRunner::RunModel(uint32_t model_id, const InputData &input_data, OutputData *output_data) { | |||||
| if (output_data == nullptr) { | |||||
| GELOGW("Output data point is null."); | |||||
| } | |||||
| auto model_iter = runtime_models_.find(model_id); | |||||
| if (model_iter == runtime_models_.end()) { | |||||
| GELOGE(PARAM_INVALID, "Model id %u not found.", model_id); | |||||
| return false; | |||||
| } | |||||
| bool status = model_iter->second->CopyInputData(input_data); | |||||
| if (!status) { | |||||
| GELOGE(FAILED, "Copy input data fail."); | |||||
| return false; | |||||
| } | |||||
| status = model_iter->second->Run(); | |||||
| if (!status) { | |||||
| GELOGE(FAILED, "Run model fail."); | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool ModelRunner::GetInputOutputDescInfo(uint32_t model_id, bool zero_copy, | |||||
| std::vector<InputOutputDescInfo> *input_desc, | |||||
| std::vector<InputOutputDescInfo> *output_desc, | |||||
| std::vector<uint32_t> *input_format, std::vector<uint32_t> *output_format) { | |||||
| if (runtime_models_.find(model_id) == runtime_models_.end()) { | |||||
| GELOGE(PARAM_INVALID, "Model id %u not found.", model_id); | |||||
| return false; | |||||
| } | |||||
| auto model = runtime_models_[model_id]; | |||||
| if (input_desc == nullptr || output_desc == nullptr) { | |||||
| GELOGE(PARAM_INVALID, "input_desc or output_desc is null."); | |||||
| return false; | |||||
| } | |||||
| bool status = model->GetInputOutputDescInfo(zero_copy, input_desc, output_desc, input_format, output_format); | |||||
| if (!status) { | |||||
| GELOGE(FAILED, "Get input output desc info fail."); | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| } // namespace model_runner | |||||
| } // namespace ge | |||||
| @@ -1,66 +0,0 @@ | |||||
LOCAL_PATH := $(call my-dir)

# Source files compiled into libge_runtime (shared by the device and host
# modules below).
# task.proto is old task, add it for ops_kernel_info_store
# The last entry deliberately has no trailing backslash, so the list cannot
# continue into whatever line follows it.
local_ge_runtime_src_files := \
    model_runner.cc \
    runtime_model.cc \
    output.cc \
    task/aicpu_task.cc \
    task/cce_task.cc \
    task/tbe_task.cc \
    task/event_record_task.cc \
    task/event_wait_task.cc \
    task/stream_active_task.cc \
    task/stream_switch_task.cc \
    task/hccl_task.cc \
    task/memcpy_async_task.cc \
    task/profiler_task.cc

# Header search directories used by both libge_runtime modules.
| local_ge_runtime_include := \ | |||||
| $(LOCAL_PATH)/ \ | |||||
| $(TOPDIR)libc_sec/include \ | |||||
| $(TOPDIR)inc/external \ | |||||
| $(TOPDIR)inc/external/graph \ | |||||
| $(TOPDIR)inc/framework \ | |||||
| $(TOPDIR)inc/graph \ | |||||
| $(TOPDIR)inc \ | |||||
| $(LOCAL_PATH)/../ \ | |||||
| third_party/protobuf/include | |||||
# Shared libraries both modules link against.
| local_ge_runtime_shared_library := \ | |||||
| libruntime \ | |||||
| libslog \ | |||||
| libc_sec | |||||
# Extra linker flags: librt and libdl.
| local_ge_runtime_ldflags := -lrt -ldl | |||||
# Device module: always built with -O2 and -Werror.
| # compile device libge_runtime | |||||
| include $(CLEAR_VARS) | |||||
| LOCAL_MODULE := libge_runtime | |||||
| LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -O2 | |||||
| LOCAL_CFLAGS += -Werror | |||||
| LOCAL_SRC_FILES := $(local_ge_runtime_src_files) | |||||
| LOCAL_C_INCLUDES := $(local_ge_runtime_include) | |||||
| LOCAL_SHARED_LIBRARIES := $(local_ge_runtime_shared_library) | |||||
| LOCAL_LDFLAGS += $(local_ge_runtime_ldflags) | |||||
| include $(BUILD_SHARED_LIBRARY) | |||||
# Host module: same sources and settings as the device module, except that
# DEBUG=1 selects -g -O0 instead of -O2.
| # compile host libge_runtime | |||||
| include $(CLEAR_VARS) | |||||
| LOCAL_MODULE := libge_runtime | |||||
| LOCAL_CFLAGS += -Werror | |||||
| LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 | |||||
| ifeq ($(DEBUG), 1) | |||||
| LOCAL_CFLAGS += -g -O0 | |||||
| else | |||||
| LOCAL_CFLAGS += -O2 | |||||
| endif | |||||
| LOCAL_SRC_FILES := $(local_ge_runtime_src_files) | |||||
| LOCAL_C_INCLUDES := $(local_ge_runtime_include) | |||||
| LOCAL_SHARED_LIBRARIES := $(local_ge_runtime_shared_library) | |||||
| LOCAL_LDFLAGS += $(local_ge_runtime_ldflags) | |||||
| include $(BUILD_HOST_SHARED_LIBRARY) | |||||
| @@ -1,94 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ge_runtime/output.h" | |||||
| #include "common/ge_inner_error_codes.h" | |||||
| #include "common/util.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| namespace ge { | |||||
| namespace model_runner { | |||||
| Output::Output(const OpInfoPtr &op_info, const std::shared_ptr<DavinciModel> &model) | |||||
| : model_(model), op_info_(op_info), input_num_(0) {} | |||||
| Output::~Output() {} | |||||
| bool Output::Init() { | |||||
| if (op_info_ == nullptr || model_ == nullptr) { | |||||
| GELOGE(INTERNAL_ERROR, "The op_desc_ or model_ is nullptr."); | |||||
| return false; | |||||
| } | |||||
| input_num_ = op_info_->input_tensors.size(); | |||||
| v_input_size_.clear(); | |||||
| v_input_data_addr_.clear(); | |||||
| auto input_vector = op_info_->input_addrs; | |||||
| if (input_num_ != input_vector.size()) { | |||||
| GELOGE(INTERNAL_ERROR, "The input desc size: %zu != input addr size: %zu.", input_num_, input_vector.size()); | |||||
| return false; | |||||
| } | |||||
| for (size_t i = 0; i < input_num_; i++) { | |||||
| uint32_t tensorSize = 0; | |||||
| const auto &input_info = op_info_->input_tensors.at(i); | |||||
| tensorSize = input_info.size; | |||||
| v_input_size_.push_back(tensorSize); | |||||
| v_input_data_addr_.push_back(reinterpret_cast<uint8_t *>(input_vector.at(i))); | |||||
| } | |||||
| GELOGI("Init output:%zu, %zu, %zu", input_num_, v_input_size_.size(), v_input_data_addr_.size()); | |||||
| return true; | |||||
| } | |||||
| /// | |||||
| /// @ingroup domi_ome | |||||
| /// @brief Copy Op Output to user space. | |||||
| /// @brief when model running, Add one DataOp as input node, Add one Output Op as output node. | |||||
| /// @return Status | |||||
| /// | |||||
| bool Output::CopyRslt(OutputData *rslt, uint32_t data_begin, uint32_t &data_index, bool support_mem_share) { | |||||
| if (rslt == nullptr) { | |||||
| GELOGE(FAILED, "OutputData is null."); | |||||
| return false; | |||||
| } | |||||
| uint32_t data_count = 0; | |||||
| if (v_input_size_.empty() || v_input_data_addr_.empty()) { | |||||
| GELOGE(INTERNAL_ERROR, "v_output_size_ or v_output_data_addr_ is empty!"); | |||||
| return false; | |||||
| } | |||||
| for (size_t i = 0; i < input_num_; i++) { | |||||
| DataBuffer data_buf = rslt->blobs[data_begin + data_count]; | |||||
| bool ret = SetDataBuf(data_buf, data_begin, data_count, i, support_mem_share); | |||||
| if (!ret) { | |||||
| GELOGE(FAILED, "Copy data to host failed. index: %lu, addr: %p", i, v_input_data_addr_[i]); | |||||
| return ret; | |||||
| } | |||||
| data_index = data_begin + data_count; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool Output::SetDataBuf(DataBuffer &data_buf, uint32_t data_begin, uint32_t &data_count, size_t i, | |||||
| bool support_mem_share) { | |||||
| return true; | |||||
| } | |||||
| } // namespace model_runner | |||||
| } // namespace ge | |||||
| @@ -1,53 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_GE_RUNTIME_OUTPUT_H_ | |||||
| #define GE_GE_RUNTIME_OUTPUT_H_ | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "ge_runtime/davinci_model.h" | |||||
| #include "common/ge_types.h" | |||||
| namespace ge { | |||||
| namespace model_runner { | |||||
| class Output { | |||||
| public: | |||||
| Output(const OpInfoPtr &op_info, const std::shared_ptr<DavinciModel> &model); | |||||
| virtual ~Output(); | |||||
| bool Init(); | |||||
| bool CopyRslt(OutputData *rslt, uint32_t data_begin, uint32_t &data_index, bool support_mem_share); | |||||
| bool SetDataBuf(DataBuffer &data_buf, uint32_t data_begin, uint32_t &data_count, size_t i, bool support_mem_share); | |||||
| // Copy assignment operator and copy constructor are deleted | |||||
| Output &operator=(const Output &output) = delete; | |||||
| Output(const Output &output) = delete; | |||||
| protected: | |||||
| std::shared_ptr<DavinciModel> model_; | |||||
| OpInfoPtr op_info_; | |||||
| // Input descriptions | |||||
| size_t input_num_; | |||||
| vector<void *> v_input_data_addr_; // Init as:buf_base + op_def_->input(i)); | |||||
| vector<uint32_t> v_input_size_; | |||||
| }; | |||||
| } // namespace model_runner | |||||
| } // namespace ge | |||||
| #endif // GE_GE_RUNTIME_OUTPUT_H_ | |||||
| @@ -1,27 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| // Generated by the protocol buffer compiler. DO NOT EDIT! | |||||
| // source: task.proto | |||||
| #ifndef STUB_TASK_PROTO_H | |||||
| #define STUB_TASK_PROTO_H | |||||
// Stub: only a forward declaration of domi::TaskDef, for code that needs the
// name without the full definition.
namespace domi { class TaskDef; }  // namespace domi
| #endif // STUB_TASK_PROTO_H | |||||
| @@ -1,547 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ge_runtime/runtime_model.h" | |||||
| #include <set> | |||||
| #include "./model_context.h" | |||||
| #include "./task/task.h" | |||||
| #include "common/ge_inner_error_codes.h" | |||||
| #include "common/types.h" | |||||
| #include "common/util.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "framework/common/op/op_parser_util.h" | |||||
| #include "graph/types.h" | |||||
| #include "task/task_factory.h" | |||||
| #include "ge/common/math/math_util.h" | |||||
| namespace ge { | |||||
| namespace model_runner { | |||||
namespace {
// File-local constants.
constexpr int kOffsetUnit = 8;
constexpr uint32_t kStringHeadElems = 2;
}  // namespace
| RuntimeModel::~RuntimeModel() { | |||||
| GELOGI("RuntimeModel destructor start"); | |||||
| // Unbind rtModel from all task related streams | |||||
| RtModelUnbindStream(); | |||||
| // Release task first, hccl task hold stream | |||||
| task_list_.clear(); | |||||
| // Release all task related streams | |||||
| RtStreamDestory(); | |||||
| // Release rtlabel resource | |||||
| RtLabelDestory(); | |||||
| // Release rtEvent resourece | |||||
| RtEventDestory(); | |||||
| GELOGI("Do RtModelDestory"); | |||||
| // Release all rt_model | |||||
| RtModelDestory(); | |||||
| } | |||||
| bool RuntimeModel::InitStream(std::shared_ptr<DavinciModel> &davinci_model) { | |||||
| if (davinci_model == nullptr) { | |||||
| GELOGE(PARAM_INVALID, "Davinci model is null."); | |||||
| return false; | |||||
| } | |||||
| std::set<int64_t> wait_active_streams; | |||||
| std::set<int64_t> force_copy_streams; | |||||
| for (const auto &stream_id : davinci_model->GetWaitActiveStreams()) { | |||||
| GELOGI("stream id %u is wait active stream.", stream_id); | |||||
| (void)wait_active_streams.insert(stream_id); | |||||
| } | |||||
| for (const auto &stream_id : davinci_model->GetForceCopyStreams()) { | |||||
| GELOGI("stream id %u is force copy stream.", stream_id); | |||||
| (void)force_copy_streams.insert(stream_id); | |||||
| } | |||||
| GELOGI("stream number:%u", davinci_model->GetStreamNum()); | |||||
| for (uint32_t i = 0; i < davinci_model->GetStreamNum(); ++i) { | |||||
| rtStream_t stream = nullptr; | |||||
| uint32_t flag = (force_copy_streams.find(i) != force_copy_streams.end()) | |||||
| ? (RT_STREAM_PERSISTENT | RT_STREAM_FORCE_COPY) | |||||
| : (RT_STREAM_PERSISTENT); | |||||
| rtError_t rt_ret = rtStreamCreateWithFlags(&stream, davinci_model->GetPriority(), flag); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api rtStreamCreate failed, ret: 0x%X", rt_ret); | |||||
| return false; | |||||
| } | |||||
| GELOGI("rtStreamCreateWithFlags end."); | |||||
| stream_list_.emplace_back(stream); | |||||
| // Bind rt_model_handle_ to all task related streams | |||||
| flag = (wait_active_streams.find(i) != wait_active_streams.end()) ? (static_cast<uint32_t>(RT_INVALID_FLAG)) | |||||
| : (static_cast<uint32_t>(RT_HEAD_STREAM)); | |||||
| rt_ret = rtModelBindStream(rt_model_handle_, stream, flag); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api rtModelBindStream failed, ret: 0x%X", rt_ret); | |||||
| return false; | |||||
| } | |||||
| GELOGI("stream index:%u, stream:%p.", i, stream); | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool RuntimeModel::InitEvent(uint32_t event_num) { | |||||
| GELOGI("event number:%u.", event_num); | |||||
| for (uint32_t i = 0; i < event_num; ++i) { | |||||
| rtEvent_t rt_event; | |||||
| rtError_t rt_ret = rtEventCreate(&rt_event); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api rtEventCreate failed, i; %u; ret: 0x%X", i, rt_ret); | |||||
| return false; | |||||
| } | |||||
| event_list_.push_back(rt_event); | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool RuntimeModel::InitLabel(std::shared_ptr<DavinciModel> &davinci_model) { | |||||
| GELOGI("batch number:%u.", davinci_model->GetBatchNum()); | |||||
| label_list_.resize(davinci_model->GetBatchNum()); | |||||
| for (auto &task_info : davinci_model->GetTaskInfoList()) { | |||||
| if (task_info == nullptr) { | |||||
| GELOGE(PARAM_INVALID, "task_info is null."); | |||||
| continue; | |||||
| } | |||||
| if (task_info->type() != TaskInfoType::LABEL_SET) { | |||||
| continue; | |||||
| } | |||||
| auto label_set_task_info = std::static_pointer_cast<LabelSetTaskInfo>(task_info); | |||||
| if (label_set_task_info->stream_id() >= stream_list_.size()) { | |||||
| GELOGE(PARAM_INVALID, "Invalid stream id."); | |||||
| return false; | |||||
| } | |||||
| rtLabel_t rt_label = nullptr; | |||||
| rtError_t rt_ret = rtLabelCreateEx(&rt_label, stream_list_[label_set_task_info->stream_id()]); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api rtLabelCreate failed, ret: 0x%X", rt_ret); | |||||
| return false; | |||||
| } | |||||
| label_list_[label_set_task_info->label_id()] = rt_label; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool RuntimeModel::InitResource(std::shared_ptr<DavinciModel> &davinci_model) { | |||||
| GELOGI("InitResource start"); | |||||
| if (davinci_model == nullptr) { | |||||
| GELOGE(PARAM_INVALID, "davinci model is null"); | |||||
| return false; | |||||
| } | |||||
| rtError_t rt_ret = rtModelCreate(&rt_model_handle_, 0); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api rtModelCreate failed, ret: 0x%X", rt_ret); | |||||
| return false; | |||||
| } | |||||
| // Create rtStream for rt_model_handle_ | |||||
| rt_ret = rtStreamCreate(&rt_model_stream_, davinci_model->GetPriority()); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api rtStreamCreate failed, ret: 0x%X", rt_ret); | |||||
| return false; | |||||
| } | |||||
| GELOGI("rtStreamCreate end"); | |||||
| if (!InitStream(davinci_model)) { | |||||
| return false; | |||||
| } | |||||
| if (!InitEvent(davinci_model->GetEventNum())) { | |||||
| return false; | |||||
| } | |||||
| if (!InitLabel(davinci_model)) { | |||||
| return false; | |||||
| } | |||||
| GELOGI("InitResource succ"); | |||||
| return true; | |||||
| } | |||||
| void RuntimeModel::GenerateTask(uint32_t device_id, uint64_t session_id, std::shared_ptr<DavinciModel> &davinci_model) { | |||||
| GELOGI("GenerateTask start."); | |||||
| if (davinci_model == nullptr) { | |||||
| GELOGE(PARAM_INVALID, "davinci model is null"); | |||||
| return; | |||||
| } | |||||
| auto task_infos = davinci_model->GetTaskInfoList(); | |||||
| ModelContext model_context(device_id, session_id, davinci_model->GetPriority(), rt_model_handle_, rt_model_stream_, | |||||
| stream_list_, label_list_, event_list_); | |||||
| for (auto &task_info : task_infos) { | |||||
| auto task = TaskFactory::GetInstance().Create(model_context, task_info); | |||||
| task_list_.push_back(task); | |||||
| } | |||||
| GELOGI("GenerateTask succ."); | |||||
| } | |||||
| bool RuntimeModel::LoadTask() { | |||||
| GELOGI("LoadTask start."); | |||||
| for (auto &task : task_list_) { | |||||
| if (task == nullptr) { | |||||
| GELOGE(PARAM_INVALID, "task is null."); | |||||
| continue; | |||||
| } | |||||
| bool ret = task->Distribute(); | |||||
| if (!ret) { | |||||
| GELOGE(FAILED, "task distribute fail."); | |||||
| return false; | |||||
| } | |||||
| uint32_t task_id = 0; | |||||
| uint32_t stream_id = 0; | |||||
| rtError_t rt_ret = rtModelGetTaskId(rt_model_handle_, &task_id, &stream_id); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X.", rt_ret); | |||||
| return false; | |||||
| } | |||||
| task_id_list_.push_back(task_id); | |||||
| stream_id_list_.push_back(stream_id); | |||||
| if (task->Args() != nullptr) { | |||||
| std::shared_ptr<RuntimeInfo> runtime_tuple = nullptr; | |||||
| GE_MAKE_SHARED(runtime_tuple = std::make_shared<RuntimeInfo>(task_id, stream_id, task->Args()), return false); | |||||
| auto emplace_ret = runtime_info_map_.emplace(task->task_name(), runtime_tuple); | |||||
| if (!emplace_ret.second) { | |||||
| GELOGW("Task name exist:%s", task->task_name().c_str()); | |||||
| } | |||||
| } | |||||
| } | |||||
| if (task_list_.empty()) { | |||||
| GELOGE(FAILED, "Task list is empty"); | |||||
| return false; | |||||
| } | |||||
| GELOGI("LoadTask succ."); | |||||
| return true; | |||||
| } | |||||
| bool RuntimeModel::LoadComplete() { | |||||
| uint32_t task_id = 0; | |||||
| uint32_t stream_id = 0; | |||||
| auto rt_ret = rtModelGetTaskId(rt_model_handle_, &task_id, &stream_id); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rtModelGetTaskId failed, ret:0x%X", rt_ret); | |||||
| return RT_FAILED; | |||||
| } | |||||
| task_id_list_.push_back(task_id); | |||||
| stream_id_list_.push_back(stream_id); | |||||
| rt_ret = rtModelLoadComplete(rt_model_handle_); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api rtModelLoadComplete failed, ret: 0x%X.", rt_ret); | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool RuntimeModel::Load(uint32_t device_id, uint64_t session_id, std::shared_ptr<DavinciModel> &davinci_model) { | |||||
| bool status = InitResource(davinci_model); | |||||
| if (!status) { | |||||
| GELOGE(FAILED, "InitResource failed."); | |||||
| return status; | |||||
| } | |||||
| status = InitDataInfo(davinci_model); | |||||
| if (!status) { | |||||
| GELOGE(FAILED, "InitDataInfo failed."); | |||||
| return status; | |||||
| } | |||||
| status = InitOutputInfo(davinci_model); | |||||
| if (!status) { | |||||
| GELOGE(FAILED, "InitOutputInfo failed."); | |||||
| return status; | |||||
| } | |||||
| status = InitConstantInfo(davinci_model); | |||||
| if (!status) { | |||||
| GELOGE(FAILED, "InitConstantInfo failed."); | |||||
| return status; | |||||
| } | |||||
| GenerateTask(device_id, session_id, davinci_model); | |||||
| return status; | |||||
| } | |||||
| bool RuntimeModel::DistributeTask() { | |||||
| bool status = LoadTask(); | |||||
| if (!status) { | |||||
| GELOGE(FAILED, "DistributeTask failed"); | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool RuntimeModel::Run() { | |||||
| GELOGI("Davinci task run start"); | |||||
| rtError_t ret = rtModelExecute(rt_model_handle_, rt_model_stream_, 0); | |||||
| if (ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Model execute failed, ret = 0x%X", ret); | |||||
| return false; | |||||
| } | |||||
| GELOGI("Run rtModelExecute success, ret = 0x%X", ret); | |||||
| ret = rtStreamSynchronize(rt_model_stream_); | |||||
| if (ret != RT_ERROR_NONE) { | |||||
| if (ret == ACL_ERROR_RT_END_OF_SEQUENCE) { | |||||
| GELOGI("Model stream ACL_ERROR_RT_END_OF_SEQUENCE signal received, ret = 0x%X", ret); | |||||
| return true; | |||||
| } | |||||
| GELOGE(RT_FAILED, "Model stream sync failed, ret = 0x%X", ret); | |||||
| return false; | |||||
| } | |||||
| GELOGI("Davinci task run succ."); | |||||
| return true; | |||||
| } | |||||
| void RuntimeModel::RtModelUnbindStream() noexcept { | |||||
| for (size_t i = 0; i < stream_list_.size(); i++) { | |||||
| if (rtModelUnbindStream(rt_model_handle_, stream_list_[i]) != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Unbind stream from model failed! Index: %zu", i); | |||||
| return; | |||||
| } | |||||
| } | |||||
| } | |||||
| void RuntimeModel::RtStreamDestory() noexcept { | |||||
| if (rtStreamDestroy(rt_model_stream_) != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Destroy stream for rt_model failed!"); | |||||
| return; | |||||
| } | |||||
| for (size_t i = 0; i < stream_list_.size(); i++) { | |||||
| if (rtStreamDestroy(stream_list_[i]) != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Destroy stream failed! Index: %zu", i); | |||||
| return; | |||||
| } | |||||
| } | |||||
| } | |||||
| void RuntimeModel::RtLabelDestory() noexcept { | |||||
| for (size_t i = 0; i < label_list_.size(); i++) { | |||||
| if (label_list_[i] == nullptr) { | |||||
| continue; | |||||
| } | |||||
| if (rtLabelDestroy(label_list_[i]) != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Destroy label failed! Index: %zu.", i); | |||||
| return; | |||||
| } | |||||
| } | |||||
| } | |||||
| void RuntimeModel::RtModelDestory() noexcept { | |||||
| rtError_t ret = rtModelDestroy(rt_model_handle_); | |||||
| if (ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", ret); | |||||
| return; | |||||
| } | |||||
| } | |||||
| void RuntimeModel::RtEventDestory() noexcept { | |||||
| for (size_t i = 0; i < event_list_.size(); i++) { | |||||
| if (rtEventDestroy(event_list_[i]) != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Destroy event failed! Index: %zu", i); | |||||
| return; | |||||
| } | |||||
| } | |||||
| } | |||||
| bool RuntimeModel::InitDataInfo(std::shared_ptr<DavinciModel> &davinci_model) { return true; } | |||||
| bool RuntimeModel::InitOutputInfo(std::shared_ptr<DavinciModel> &davinci_model) { | |||||
| if (davinci_model == nullptr) { | |||||
| GELOGE(PARAM_INVALID, "davinci model is null"); | |||||
| return false; | |||||
| } | |||||
| output_info_list_ = davinci_model->GetOutputInfoList(); | |||||
| return true; | |||||
| } | |||||
| bool RuntimeModel::CopyInputData(const InputData &input_data) { | |||||
| if (input_data.blobs.size() != data_info_list_.size()) { | |||||
| GELOGE(PARAM_INVALID, "The input data list size (%zu) does not match the model input list size (%zu)", | |||||
| input_data.blobs.size(), data_info_list_.size()); | |||||
| return false; | |||||
| } | |||||
| for (const auto &data_info : data_info_list_) { | |||||
| if (data_info == nullptr) { | |||||
| GELOGE(PARAM_INVALID, "data info is null."); | |||||
| return false; | |||||
| } | |||||
| bool ret = CopyInputDataToModel(input_data.blobs, data_info); | |||||
| if (!ret) { | |||||
| GELOGE(FAILED, "Copy input data to model ret fail, data_info: %s, model id: %u", data_info->name.c_str(), | |||||
| input_data.model_id); | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool RuntimeModel::CopyInputDataToModel(const std::vector<DataBuffer> &data, const std::shared_ptr<OpInfo> &data_info) { | |||||
| return true; | |||||
| } | |||||
| bool RuntimeModel::CopyHostData(const std::vector<DataBuffer> &data, const std::shared_ptr<OpInfo> &data_info) const { | |||||
| GELOGI("Start CopyHostData."); | |||||
| if (data.empty()) { | |||||
| GELOGE(PARAM_INVALID, "data buffer is empty."); | |||||
| return false; | |||||
| } | |||||
| if (data_info == nullptr) { | |||||
| GELOGE(PARAM_INVALID, "data info is null."); | |||||
| return false; | |||||
| } | |||||
| void *host_data_addr = data[data_info->index].data; | |||||
| uint32_t copy_size = data[data_info->index].length; | |||||
| GELOGD("data output tensor is aipp tensor,copy data only."); | |||||
| const std::vector<uintptr_t> &outputs = data_info->output_addrs; | |||||
| if (outputs.empty()) { | |||||
| GELOGE(PARAM_INVALID, "Output addrs is empty."); | |||||
| return false; | |||||
| } | |||||
| // Copy input data to data nodes | |||||
| void *data_out_addr = reinterpret_cast<void *>(outputs[0]); | |||||
| rtError_t rt_ret = rtMemcpy(data_out_addr, copy_size, host_data_addr, copy_size, RT_MEMCPY_HOST_TO_DEVICE); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool RuntimeModel::CopyTransData(const std::vector<DataBuffer> &data, const std::shared_ptr<OpInfo> &data_info) { | |||||
| return true; | |||||
| } | |||||
| bool RuntimeModel::InitConstantInfo(std::shared_ptr<DavinciModel> &davinci_model) { | |||||
| // Const no input, only 1 output, and this output has no data | |||||
| // weight data copy to output mem | |||||
| if (davinci_model == nullptr) { | |||||
| GELOGE(PARAM_INVALID, "Davinci model is null."); | |||||
| return false; | |||||
| } | |||||
| constant_info_list_ = davinci_model->GetConstantInfoList(); | |||||
| for (const auto &constant : constant_info_list_) { | |||||
| if (constant == nullptr) { | |||||
| GELOGE(PARAM_INVALID, "constant is null"); | |||||
| continue; | |||||
| } | |||||
| if (constant->output_tensors.empty()) { | |||||
| GELOGE(PARAM_INVALID, "Output tensors is empty"); | |||||
| return false; | |||||
| } | |||||
| if (constant->weight_tensors.empty()) { | |||||
| GELOGE(PARAM_INVALID, "Weight tensors is empty"); | |||||
| return false; | |||||
| } | |||||
| if (constant->output_tensors[0].size < constant->weight_data.size()) { | |||||
| GELOGE(PARAM_INVALID, "Output size:%u is less than weight data size:%zu", constant->output_tensors[0].size, | |||||
| constant->weight_data.size()); | |||||
| return false; | |||||
| } | |||||
| if (constant->weight_data.empty()) { | |||||
| GELOGW("Const op:%s has no weight data.", constant->name.c_str()); | |||||
| continue; | |||||
| } | |||||
| if (constant->weight_tensors[0].datatype == DT_STRING) { | |||||
| /// If tensor is a scaler, it's shape size if zero, according ge_tensor.cc. | |||||
| /// The logic of GetShapeSize is wrong, the scaler tensor's GetShapeSize is zero | |||||
| /// and that of unknown shape is zero too. | |||||
| /// Unknown shape will not appear here, so we can use zero judge a tensor is scaler or not. | |||||
| int64_t elem_num = | |||||
| (constant->weight_tensors[0].GetShapeSize() == 0) ? 1 : constant->weight_tensors[0].GetShapeSize(); | |||||
| if (constant->weight_data.size() < sizeof(uint64_t)) { | |||||
| GELOGE(FAILED, "weight_data size is smaller than sizeof(uint64_t)"); | |||||
| return false; | |||||
| } | |||||
| uint64_t *buff = reinterpret_cast<uint64_t *>(const_cast<char *>(constant->weight_data.data())); | |||||
| uint32_t head_len = kOffsetUnit * kStringHeadElems; | |||||
| if (CheckInt64Uint32MulOverflow(elem_num, head_len) != SUCCESS) { | |||||
| GELOGE(FAILED, "Shape size is invalid"); | |||||
| return false; | |||||
| } | |||||
| int64_t offset = elem_num * head_len; | |||||
| uintptr_t hbm_raw_data_base_addr = reinterpret_cast<uintptr_t>(constant->output_addrs[0]) + offset; | |||||
| for (int64_t i = elem_num - 1; i >= 0; --i) { | |||||
| buff[i * kStringHeadElems] = hbm_raw_data_base_addr + (buff[i * kStringHeadElems] - buff[0]); | |||||
| } | |||||
| } | |||||
| rtError_t rt_ret = rtMemcpy(reinterpret_cast<void *>(constant->output_addrs[0]), constant->output_tensors[0].size, | |||||
| constant->weight_data.data(), constant->weight_data.size(), RT_MEMCPY_HOST_TO_DEVICE); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "rtGetFunctionByName failed, ret: 0x%X", rt_ret); | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool RuntimeModel::GetInputOutputDescInfo(bool zero_copy, std::vector<InputOutputDescInfo> *input_desc, | |||||
| std::vector<InputOutputDescInfo> *output_desc, | |||||
| std::vector<uint32_t> *input_format, std::vector<uint32_t> *output_format) { | |||||
| return true; | |||||
| } | |||||
| bool RuntimeModel::GetInputDescInfo(std::vector<InputOutputDescInfo> *input_desc, std::vector<uint32_t> *formats) { | |||||
| return true; | |||||
| } | |||||
| bool RuntimeModel::GetOutputDescInfo(std::vector<InputOutputDescInfo> *output_desc, std::vector<uint32_t> *formats) { | |||||
| return true; | |||||
| } | |||||
| void RuntimeModel::CreateOutput(uint32_t index, const OpInfo &op_info, InputOutputDescInfo *output, | |||||
| uint32_t *format_result) {} | |||||
| const std::vector<uint32_t> &RuntimeModel::GetTaskIdList() const { return task_id_list_; } | |||||
| const std::vector<uint32_t> &RuntimeModel::GetStreamIdList() const { return stream_id_list_; } | |||||
| } // namespace model_runner | |||||
| } // namespace ge | |||||
| @@ -1,92 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_GE_RUNTIME_RUNTIME_MODEL_H_ | |||||
| #define GE_GE_RUNTIME_RUNTIME_MODEL_H_ | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "ge_runtime/davinci_model.h" | |||||
| #include "common/ge_types.h" | |||||
| #include "runtime/base.h" | |||||
| #include "runtime/rt_model.h" | |||||
| namespace ge { | |||||
| namespace model_runner { | |||||
| using RuntimeInfo = std::tuple<uint32_t, uint32_t, void *>; | |||||
| class Task; | |||||
| class RuntimeModel { | |||||
| public: | |||||
| RuntimeModel() = default; | |||||
| ~RuntimeModel(); | |||||
| bool Load(uint32_t device_id, uint64_t session_id, std::shared_ptr<DavinciModel> &davinci_model); | |||||
| bool DistributeTask(); | |||||
| bool LoadComplete(); | |||||
| const std::vector<uint32_t> &GetTaskIdList() const; | |||||
| const std::vector<uint32_t> &GetStreamIdList() const; | |||||
| const std::map<std::string, std::shared_ptr<RuntimeInfo>> &GetRuntimeInfoMap() const { return runtime_info_map_; } | |||||
| rtModel_t GetModelHandle() const { return rt_model_handle_; } | |||||
| bool Run(); | |||||
| bool CopyInputData(const InputData &input_data); | |||||
| bool GetInputOutputDescInfo(bool zero_copy, std::vector<InputOutputDescInfo> *input_desc, | |||||
| std::vector<InputOutputDescInfo> *output_desc, std::vector<uint32_t> *input_format, | |||||
| std::vector<uint32_t> *output_format); | |||||
| private: | |||||
| bool InitResource(std::shared_ptr<DavinciModel> &davinci_model); | |||||
| void GenerateTask(uint32_t device_id, uint64_t session_id, std::shared_ptr<DavinciModel> &davinci_model); | |||||
| bool LoadTask(); | |||||
| bool InitStream(std::shared_ptr<DavinciModel> &davinci_model); | |||||
| bool InitEvent(uint32_t event_num); | |||||
| bool InitLabel(std::shared_ptr<DavinciModel> &davinci_model); | |||||
| bool InitDataInfo(std::shared_ptr<DavinciModel> &davinci_model); | |||||
| bool InitOutputInfo(std::shared_ptr<DavinciModel> &davinci_model); | |||||
| bool InitConstantInfo(std::shared_ptr<DavinciModel> &davinci_model); | |||||
| void RtModelUnbindStream() noexcept; | |||||
| void RtStreamDestory() noexcept; | |||||
| void RtModelDestory() noexcept; | |||||
| void RtLabelDestory() noexcept; | |||||
| void RtEventDestory() noexcept; | |||||
| bool CopyInputDataToModel(const std::vector<DataBuffer> &data, const std::shared_ptr<OpInfo> &data_info); | |||||
| bool CopyHostData(const std::vector<DataBuffer> &data, const std::shared_ptr<OpInfo> &data_info) const; | |||||
| bool CopyTransData(const std::vector<DataBuffer> &data, const std::shared_ptr<OpInfo> &data_info); | |||||
| bool GetInputDescInfo(std::vector<InputOutputDescInfo> *input_desc, std::vector<uint32_t> *formats); | |||||
| bool GetOutputDescInfo(std::vector<InputOutputDescInfo> *output_desc, std::vector<uint32_t> *formats); | |||||
| void CreateOutput(uint32_t index, const OpInfo &op_info, InputOutputDescInfo *output, uint32_t *format); | |||||
| rtModel_t rt_model_handle_{}; | |||||
| rtStream_t rt_model_stream_{}; | |||||
| std::vector<rtStream_t> stream_list_{}; | |||||
| std::vector<rtLabel_t> label_list_{}; | |||||
| std::vector<rtEvent_t> event_list_{}; | |||||
| std::vector<std::shared_ptr<Task>> task_list_{}; | |||||
| std::vector<std::shared_ptr<OpInfo>> data_info_list_{}; | |||||
| std::vector<std::shared_ptr<OpInfo>> output_info_list_{}; | |||||
| std::vector<std::shared_ptr<OpInfo>> constant_info_list_{}; | |||||
| std::vector<uint32_t> task_id_list_{}; | |||||
| std::vector<uint32_t> stream_id_list_{}; | |||||
| std::map<std::string, std::shared_ptr<RuntimeInfo>> runtime_info_map_; | |||||
| }; | |||||
| } // namespace model_runner | |||||
| } // namespace ge | |||||
| #endif // GE_GE_RUNTIME_RUNTIME_MODEL_H_ | |||||
| @@ -1,168 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ge_runtime/task/aicpu_task.h" | |||||
| #include <vector> | |||||
| #include "ge_runtime/task/task_factory.h" | |||||
| #include "aicpu/common/aicpu_task_struct.h" | |||||
| namespace ge { | |||||
| namespace model_runner { | |||||
| AicpuTask::AicpuTask(const ModelContext &model_context, const std::shared_ptr<AicpuTaskInfo> &task_info) | |||||
| : TaskRepeater<AicpuTaskInfo>(model_context, task_info), | |||||
| task_info_(task_info), | |||||
| stream_(nullptr), | |||||
| args_(nullptr), | |||||
| ext_info_(nullptr), | |||||
| input_output_addr_(nullptr) { | |||||
| if (task_info_ == nullptr) { | |||||
| GELOGW("task_info_ is null!"); | |||||
| return; | |||||
| } | |||||
| auto stream_list = model_context.stream_list(); | |||||
| if (stream_list.size() == 1) { | |||||
| stream_ = stream_list[0]; | |||||
| } else if (stream_list.size() > task_info->stream_id()) { | |||||
| stream_ = stream_list[task_info->stream_id()]; | |||||
| } else { | |||||
| GELOGW("index: %u >= stream_list.size(): %zu.", task_info->stream_id(), stream_list.size()); | |||||
| } | |||||
| } | |||||
| AicpuTask::~AicpuTask() { | |||||
| ReleaseRtMem(&args_); | |||||
| ReleaseRtMem(&ext_info_); | |||||
| } | |||||
| bool AicpuTask::Distribute() { | |||||
| GELOGI("InitAicpuTask start."); | |||||
| vector<void *> io_addrs; | |||||
| io_addrs.insert(io_addrs.end(), task_info_->input_data_addrs().begin(), task_info_->input_data_addrs().end()); | |||||
| io_addrs.insert(io_addrs.end(), task_info_->output_data_addrs().begin(), task_info_->output_data_addrs().end()); | |||||
| auto io_addrs_num = static_cast<uint32_t>(io_addrs.size()); | |||||
| auto io_addrs_size = static_cast<uint32_t>(io_addrs_num * sizeof(void *)); | |||||
| constexpr uint32_t io_addr_offset = sizeof(aicpu::AicpuParamHead); | |||||
| uint32_t node_def_len_offset = io_addr_offset + io_addrs_size; | |||||
| uint32_t node_def_addr_offset = node_def_len_offset + sizeof(uint32_t); | |||||
| uint32_t args_size = sizeof(aicpu::AicpuParamHead) + io_addrs_size + | |||||
| static_cast<uint32_t>(task_info_->node_def().size()) + sizeof(uint32_t); | |||||
| aicpu::AicpuParamHead aicpu_param_head; | |||||
| aicpu_param_head.length = args_size; | |||||
| aicpu_param_head.ioAddrNum = io_addrs_num; | |||||
| auto ext_info = task_info_->ext_info(); | |||||
| uint32_t ext_size = ext_info.size(); | |||||
| if (ext_info.empty()) { | |||||
| aicpu_param_head.extInfoLength = 0; | |||||
| aicpu_param_head.extInfoAddr = 0; | |||||
| } else { | |||||
| rtError_t flag = rtMalloc(&ext_info_, ext_size, RT_MEMORY_HBM); | |||||
| if (flag != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api(rtMalloc) failed, ret: 0x%X.", flag); | |||||
| return false; | |||||
| } | |||||
| flag = rtMemcpy(ext_info_, ext_size, const_cast<void *>(reinterpret_cast<const void *>(ext_info.data())), ext_size, | |||||
| RT_MEMCPY_HOST_TO_DEVICE); | |||||
| if (flag != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api(rtMemCpy) failed, ret: 0x%X.", flag); | |||||
| return false; | |||||
| } | |||||
| GELOGI("ext info size: %u", ext_size); | |||||
| aicpu_param_head.extInfoLength = ext_size; | |||||
| aicpu_param_head.extInfoAddr = reinterpret_cast<uintptr_t>(ext_info_); | |||||
| } | |||||
| // Malloc device memory for args | |||||
| rtError_t rt_ret = rtMalloc(&args_, args_size, RT_MEMORY_HBM); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api(rtMalloc) failed, ret: 0x%X.", rt_ret); | |||||
| return false; | |||||
| } | |||||
| GE_PRINT_DYNAMIC_MEMORY(rtMalloc, "task args data.", args_size) | |||||
| // Memcpy AicpuParamHead | |||||
| rt_ret = rtMemcpy(args_, sizeof(aicpu::AicpuParamHead), reinterpret_cast<void *>(&aicpu_param_head), | |||||
| sizeof(aicpu::AicpuParamHead), RT_MEMCPY_HOST_TO_DEVICE); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api(rtMemcpy) failed, ret: 0x%X.", rt_ret); | |||||
| return false; | |||||
| } | |||||
| // Memcpy io addrs | |||||
| if (io_addrs_num != 0) { | |||||
| rt_ret = rtMemcpy(reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(args_) + io_addr_offset), io_addrs_size, | |||||
| reinterpret_cast<void *>(io_addrs.data()), io_addrs_size, RT_MEMCPY_HOST_TO_DEVICE); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api(rtMemcpy) failed, ret: 0x%X.", rt_ret); | |||||
| return false; | |||||
| } | |||||
| } | |||||
| // Memcpy node def | |||||
| auto size = task_info_->node_def().size(); | |||||
| rt_ret = | |||||
| rtMemcpy(reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(args_) + node_def_len_offset), sizeof(uint32_t), | |||||
| reinterpret_cast<const void *>(&size), sizeof(uint32_t), RT_MEMCPY_HOST_TO_DEVICE); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api(rtMemcpy) failed, ret: 0x%X.", rt_ret); | |||||
| return false; | |||||
| } | |||||
| // Memcpy node def | |||||
| rt_ret = rtMemcpy(reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(args_) + node_def_addr_offset), | |||||
| task_info_->node_def().size(), reinterpret_cast<const void *>(task_info_->node_def().data()), | |||||
| task_info_->node_def().size(), RT_MEMCPY_HOST_TO_DEVICE); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api(rtMemcpy) failed, ret: 0x%X.", rt_ret); | |||||
| return false; | |||||
| } | |||||
| input_output_addr_ = reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(args_) + io_addr_offset); | |||||
| auto dump_flag = task_info_->dump_flag() ? RT_KERNEL_DUMPFLAG : RT_KERNEL_DEFAULT; | |||||
| GELOGI( | |||||
| "Distribute AicpuTask start, args_size = %u, io_addrs_num = %u, so_name = %s, kernel_name = %s, dump_flag = %d.", | |||||
| args_size, io_addrs_num, task_info_->so_name().data(), task_info_->kernel_name().data(), dump_flag); | |||||
| rt_ret = rtCpuKernelLaunchWithFlag(reinterpret_cast<const void *>(task_info_->so_name().data()), | |||||
| reinterpret_cast<const void *>(task_info_->kernel_name().data()), 1, args_, | |||||
| args_size, nullptr, stream_, dump_flag); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
| return false; | |||||
| } | |||||
| GELOGI("Distribute AicpuTask end."); | |||||
| return true; | |||||
| } | |||||
| void AicpuTask::ReleaseRtMem(void **ptr) noexcept { | |||||
| if (ptr == nullptr || *ptr == nullptr) { | |||||
| return; | |||||
| } | |||||
| rtError_t rt_ret = rtFree(*ptr); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "ReleaseRtMem failed, ret: 0x%X", rt_ret); | |||||
| return; | |||||
| } | |||||
| *ptr = nullptr; | |||||
| } | |||||
| REGISTER_TASK(TaskInfoType::AICPU, AicpuTask, AicpuTaskInfo); | |||||
| } // namespace model_runner | |||||
| } // namespace ge | |||||
| @@ -1,50 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_GE_RUNTIME_TASK_AICPU_TASK_H_ | |||||
| #define GE_GE_RUNTIME_TASK_AICPU_TASK_H_ | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include "ge_runtime/task/task.h" | |||||
| namespace ge { | |||||
| namespace model_runner { | |||||
| class AicpuTask : public TaskRepeater<AicpuTaskInfo> { | |||||
| public: | |||||
| AicpuTask(const ModelContext &model_context, const std::shared_ptr<AicpuTaskInfo> &task_info); | |||||
| ~AicpuTask() override; | |||||
| bool Distribute() override; | |||||
| void *Args() override { return input_output_addr_; } | |||||
| std::string task_name() const override { return task_info_->op_name(); } | |||||
| private: | |||||
| static void ReleaseRtMem(void **ptr) noexcept; | |||||
| std::shared_ptr<AicpuTaskInfo> task_info_; | |||||
| void *stream_; | |||||
| void *args_; | |||||
| void *ext_info_; | |||||
| void *input_output_addr_; | |||||
| }; | |||||
| } // namespace model_runner | |||||
| } // namespace ge | |||||
| #endif // GE_GE_RUNTIME_TASK_AICPU_TASK_H_ | |||||
| @@ -1,160 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ge_runtime/task/cce_task.h" | |||||
| #include "ge_runtime/task/task_factory.h" | |||||
| namespace ge { | |||||
| namespace model_runner { | |||||
| CceTask::CceTask(const ModelContext &model_context, const std::shared_ptr<CceTaskInfo> &task_info) | |||||
| : TaskRepeater<CceTaskInfo>(model_context, task_info), | |||||
| task_info_(task_info), | |||||
| stream_(nullptr), | |||||
| stub_func_(nullptr), | |||||
| args_(nullptr), | |||||
| sm_desc_(nullptr), | |||||
| flowtable_(nullptr), | |||||
| is_flowtable_(false) { | |||||
| if (task_info_ == nullptr) { | |||||
| GELOGW("task_info_ is null!"); | |||||
| return; | |||||
| } | |||||
| auto stream_list = model_context.stream_list(); | |||||
| if (stream_list.size() == 1) { | |||||
| stream_ = stream_list[0]; | |||||
| } else if (stream_list.size() > task_info->stream_id()) { | |||||
| stream_ = stream_list[task_info->stream_id()]; | |||||
| } else { | |||||
| GELOGW("index: %u >= stream_list.size(): %zu.", task_info->stream_id(), stream_list.size()); | |||||
| } | |||||
| } | |||||
| CceTask::~CceTask() { | |||||
| FreeRtMem(&args_); | |||||
| FreeRtMem(&flowtable_); | |||||
| rtError_t ret = (sm_desc_ != nullptr) ? rtMemFreeManaged(sm_desc_) : RT_ERROR_NONE; | |||||
| if (ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", ret); | |||||
| } | |||||
| sm_desc_ = nullptr; | |||||
| } | |||||
| void CceTask::FreeRtMem(void **ptr) noexcept { | |||||
| if (ptr == nullptr || *ptr == nullptr) { | |||||
| return; | |||||
| } | |||||
| rtError_t ret = rtFree(*ptr); | |||||
| if (ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", ret); | |||||
| } | |||||
| *ptr = nullptr; | |||||
| } | |||||
| bool CceTask::Distribute() { | |||||
| GELOGI("Distribute CCETask start."); | |||||
| if (stream_ == nullptr) { | |||||
| GELOGE(PARAM_INVALID, "stream_ is null!"); | |||||
| return false; | |||||
| } | |||||
| // Get stub_func | |||||
| if (task_info_->stub_func().empty()) { | |||||
| GELOGE(PARAM_INVALID, "kernel_info->stub_func is empty!"); | |||||
| return false; | |||||
| } | |||||
| rtError_t rt_ret = rtGetFunctionByName(const_cast<char *>(task_info_->stub_func().c_str()), &stub_func_); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "rtGetFunctionByName failed, ret: 0x%X", rt_ret); | |||||
| stub_func_ = nullptr; | |||||
| return false; | |||||
| } | |||||
| GELOGI("CCETask: stub_func = %s [%p].", task_info_->stub_func().c_str(), stub_func_); | |||||
| // Flowtable | |||||
| if (is_flowtable_) { | |||||
| rt_ret = rtMalloc(&flowtable_, task_info_->flow_table().size(), RT_MEMORY_HBM); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
| return false; | |||||
| } | |||||
| GE_PRINT_DYNAMIC_MEMORY(rtMalloc, "task information.", task_info_->flow_table().size()) | |||||
| rt_ret = rtMemcpy(flowtable_, task_info_->flow_table().size(), task_info_->flow_table().data(), | |||||
| task_info_->flow_table().size(), RT_MEMCPY_HOST_TO_DEVICE); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
| return false; | |||||
| } | |||||
| // Modify flowtable addr in args | |||||
| auto args = const_cast<uint8_t *>(task_info_->args().data()); | |||||
| auto task_offset = reinterpret_cast<uint16_t *>(const_cast<uint8_t *>(task_info_->args_offset().data())); | |||||
| if (task_info_->args().size() < (task_offset[0] + sizeof(uint64_t))) { | |||||
| GELOGE(FAILED, "(context.args_offset().data()))[0]:%u + sizeof(uint64_t):%zu > kernelDef.args().size():%zu", | |||||
| static_cast<uint32_t>(task_offset[0]), sizeof(uint64_t), task_info_->args().size()); | |||||
| return false; | |||||
| } | |||||
| *(reinterpret_cast<uintptr_t *>(args + task_offset[0])) = reinterpret_cast<uintptr_t>(flowtable_); | |||||
| } | |||||
| // Args | |||||
| rt_ret = rtMalloc(&args_, task_info_->args_size(), RT_MEMORY_HBM); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
| return false; | |||||
| } | |||||
| GE_PRINT_DYNAMIC_MEMORY(rtMalloc, "task information.", task_info_->args_size()) | |||||
| rt_ret = rtMemcpy(args_, task_info_->args_size(), task_info_->args().data(), task_info_->args_size(), | |||||
| RT_MEMCPY_HOST_TO_DEVICE); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
| return false; | |||||
| } | |||||
| // L2 sm_desc | |||||
| if (!task_info_->sm_desc().empty()) { | |||||
| rt_ret = rtMemAllocManaged(&sm_desc_, task_info_->sm_desc().size(), RT_MEMORY_SPM); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
| return false; | |||||
| } | |||||
| rt_ret = rtMemcpy(sm_desc_, task_info_->sm_desc().size(), task_info_->sm_desc().data(), | |||||
| task_info_->sm_desc().size(), RT_MEMCPY_HOST_TO_DEVICE); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
| return false; | |||||
| } | |||||
| } | |||||
| // Kernel launch | |||||
| rt_ret = rtKernelLaunch(stub_func_, task_info_->block_dim(), args_, task_info_->args_size(), | |||||
| static_cast<rtSmDesc_t *>(sm_desc_), stream_); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| REGISTER_TASK(TaskInfoType::CCE, CceTask, CceTaskInfo); | |||||
| } // namespace model_runner | |||||
| } // namespace ge | |||||
| @@ -1,47 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_GE_RUNTIME_TASK_CCE_TASK_H_ | |||||
| #define GE_GE_RUNTIME_TASK_CCE_TASK_H_ | |||||
| #include <memory> | |||||
| #include "ge_runtime/task/task.h" | |||||
| namespace ge { | |||||
| namespace model_runner { | |||||
| class CceTask : public TaskRepeater<CceTaskInfo> { | |||||
| public: | |||||
| CceTask(const ModelContext &model_context, const std::shared_ptr<CceTaskInfo> &task_info); | |||||
| ~CceTask() override; | |||||
| bool Distribute() override; | |||||
| static void FreeRtMem(void **ptr) noexcept; | |||||
| private: | |||||
| std::shared_ptr<CceTaskInfo> task_info_; | |||||
| void *stream_; | |||||
| void *stub_func_; | |||||
| void *args_; | |||||
| void *sm_desc_; | |||||
| void *flowtable_; | |||||
| bool is_flowtable_; | |||||
| }; | |||||
| } // namespace model_runner | |||||
| } // namespace ge | |||||
| #endif // GE_GE_RUNTIME_TASK_CCE_TASK_H_ | |||||
| @@ -1,61 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ge_runtime/task/event_record_task.h" | |||||
| #include "ge_runtime/task/task_factory.h" | |||||
| namespace ge { | |||||
| namespace model_runner { | |||||
| EventRecordTask::EventRecordTask(const ModelContext &model_context, | |||||
| const std::shared_ptr<EventRecordTaskInfo> &task_info) | |||||
| : TaskRepeater<EventRecordTaskInfo>(model_context, task_info), | |||||
| task_info_(task_info), | |||||
| stream_(nullptr), | |||||
| event_(nullptr) { | |||||
| if (task_info_ == nullptr) { | |||||
| GELOGW("task_info_ is null!"); | |||||
| return; | |||||
| } | |||||
| auto stream_list = model_context.stream_list(); | |||||
| auto event_list = model_context.event_list(); | |||||
| uint32_t stream_id = task_info->stream_id(); | |||||
| uint32_t event_id = task_info->event_id(); | |||||
| if (stream_id >= stream_list.size() || event_id >= event_list.size()) { | |||||
| GELOGW("stream_list size:%zu, stream_id:%u, event_list size:%zu, event_id:%u", stream_list.size(), stream_id, | |||||
| event_list.size(), event_id); | |||||
| return; | |||||
| } | |||||
| stream_ = stream_list[stream_id]; | |||||
| event_ = event_list[event_id]; | |||||
| } | |||||
| EventRecordTask::~EventRecordTask() {} | |||||
| bool EventRecordTask::Distribute() { | |||||
| GELOGI("EventRecordTask Distribute start, stream: %p, event: %p, stream_id: %u, event_id: %u.", stream_, event_, | |||||
| task_info_->stream_id(), task_info_->event_id()); | |||||
| rtError_t rt_ret = rtEventRecord(event_, stream_); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
| return false; | |||||
| } | |||||
| GELOGI("Distribute end."); | |||||
| return true; | |||||
| } | |||||
| REGISTER_TASK(TaskInfoType::EVENT_RECORD, EventRecordTask, EventRecordTaskInfo); /* NOTE(review): presumably registers EventRecordTask as the task class for EVENT_RECORD task infos; the macro is not defined in this file - confirm in task_factory.h. */ | |||||
| } // namespace model_runner | |||||
| } // namespace ge | |||||
| @@ -1,41 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_GE_RUNTIME_TASK_EVENT_RECORD_TASK_H_ | |||||
| #define GE_GE_RUNTIME_TASK_EVENT_RECORD_TASK_H_ | |||||
| #include <memory> | |||||
| #include "ge_runtime/task/task.h" | |||||
| namespace ge { | |||||
| namespace model_runner { | |||||
| class EventRecordTask : public TaskRepeater<EventRecordTaskInfo> { | |||||
| public: | |||||
| EventRecordTask(const ModelContext &model_context, const std::shared_ptr<EventRecordTaskInfo> &task_info); | |||||
| ~EventRecordTask() override; | |||||
| bool Distribute() override; | |||||
| private: | |||||
| std::shared_ptr<EventRecordTaskInfo> task_info_; | |||||
| rtStream_t stream_; | |||||
| rtEvent_t event_; | |||||
| }; | |||||
| } // namespace model_runner | |||||
| } // namespace ge | |||||
| #endif // GE_GE_RUNTIME_TASK_EVENT_RECORD_TASK_H_ | |||||
| @@ -1,67 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ge_runtime/task/event_wait_task.h" | |||||
| #include "ge_runtime/task/task_factory.h" | |||||
| namespace ge { | |||||
| namespace model_runner { | |||||
| EventWaitTask::EventWaitTask(const ModelContext &model_context, const std::shared_ptr<EventWaitTaskInfo> &task_info) | |||||
| : TaskRepeater<EventWaitTaskInfo>(model_context, task_info), | |||||
| task_info_(task_info), | |||||
| stream_(nullptr), | |||||
| event_(nullptr) { | |||||
| if (task_info_ == nullptr) { | |||||
| GELOGW("task_info_ is null!"); | |||||
| return; | |||||
| } | |||||
| auto stream_list = model_context.stream_list(); | |||||
| auto event_list = model_context.event_list(); | |||||
| uint32_t stream_id = task_info->stream_id(); | |||||
| uint32_t event_id = task_info->event_id(); | |||||
| if (stream_id >= stream_list.size() || event_id >= event_list.size()) { | |||||
| GELOGW("stream_list size:%zu, stream_id:%u, event_list size:%zu, event_id:%u", stream_list.size(), stream_id, | |||||
| event_list.size(), event_id); | |||||
| return; | |||||
| } | |||||
| stream_ = stream_list[stream_id]; | |||||
| event_ = event_list[event_id]; | |||||
| } | |||||
| EventWaitTask::~EventWaitTask() {} | |||||
| bool EventWaitTask::Distribute() { | |||||
| GELOGI("EventWaitTask Distribute start, stream: %p, event: %p, stream_id: %u, event_id: %u.", stream_, event_, | |||||
| task_info_->stream_id(), task_info_->event_id()); | |||||
| rtError_t rt_ret = rtStreamWaitEvent(stream_, event_); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api rtStreamWaitEvent failed, ret: 0x%X", rt_ret); | |||||
| return false; | |||||
| } | |||||
| rt_ret = rtEventReset(event_, stream_); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api rtEventReset failed, ret: 0x%X", rt_ret); | |||||
| return false; | |||||
| } | |||||
| GELOGI("Distribute end."); | |||||
| return true; | |||||
| } | |||||
| REGISTER_TASK(TaskInfoType::EVENT_WAIT, EventWaitTask, EventWaitTaskInfo); /* NOTE(review): presumably registers EventWaitTask as the task class for EVENT_WAIT task infos; the macro is not defined in this file - confirm in task_factory.h. */ | |||||
| } // namespace model_runner | |||||
| } // namespace ge | |||||
| @@ -1,41 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_GE_RUNTIME_TASK_EVENT_WAIT_TASK_H_ | |||||
| #define GE_GE_RUNTIME_TASK_EVENT_WAIT_TASK_H_ | |||||
| #include <memory> | |||||
| #include "ge_runtime/task/task.h" | |||||
| namespace ge { | |||||
| namespace model_runner { | |||||
| class EventWaitTask : public TaskRepeater<EventWaitTaskInfo> { | |||||
| public: | |||||
| EventWaitTask(const ModelContext &model_context, const std::shared_ptr<EventWaitTaskInfo> &task_info); | |||||
| ~EventWaitTask() override; | |||||
| bool Distribute() override; | |||||
| private: | |||||
| std::shared_ptr<EventWaitTaskInfo> task_info_; | |||||
| rtStream_t stream_; | |||||
| rtEvent_t event_; | |||||
| }; | |||||
| } // namespace model_runner | |||||
| } // namespace ge | |||||
| #endif // GE_GE_RUNTIME_TASK_EVENT_WAIT_TASK_H_ | |||||
| @@ -1,268 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ge_runtime/task/hccl_task.h" | |||||
| #include <algorithm> | |||||
| #include "ge_runtime/task/task_factory.h" | |||||
| #include "common/opskernel/ops_kernel_info_store.h" | |||||
| #include "common/opskernel/ge_task_info.h" | |||||
| namespace ge { | |||||
| namespace model_runner { | |||||
| std::map<rtModel_t, std::map<uint32_t, std::vector<std::weak_ptr<HcclTask::StreamGuard>>>> | |||||
| HcclTask::model_stream_mapping_; /* Static registry shared by all HcclTask objects: model handle -> primary stream id -> weak references to the secondary streams created for it. */ | |||||
| std::mutex HcclTask::model_stream_mapping_mutex_; /* Locked by SetSecondaryStream() before the registry above is read or modified. */ | |||||
| HcclTask::HcclTask(const ModelContext &model_context, const std::shared_ptr<HcclTaskInfo> &task_info) | |||||
| : TaskRepeater<HcclTaskInfo>(model_context, task_info), | |||||
| task_info_(task_info), | |||||
| stream_(nullptr), | |||||
| workspace_mem_(nullptr), | |||||
| rt_model_handle_(nullptr), | |||||
| priority_(0), | |||||
| secondary_stream_list_() { | |||||
| if (task_info_ == nullptr) { | |||||
| GELOGW("task_info_ is null!"); | |||||
| return; | |||||
| } | |||||
| priority_ = model_context.priority(); | |||||
| rt_model_handle_ = model_context.rt_model_handle(); | |||||
| auto stream_list = model_context.stream_list(); | |||||
| if (stream_list.size() == 1) { | |||||
| stream_ = stream_list[0]; | |||||
| } else if (stream_list.size() > task_info->stream_id()) { | |||||
| stream_ = stream_list[task_info->stream_id()]; | |||||
| } else { | |||||
| GELOGW("Index: %u >= stream_list.size(): %zu.", task_info->stream_id(), stream_list.size()); | |||||
| } | |||||
| } | |||||
| HcclTask::~HcclTask() { | |||||
| if (workspace_mem_ != nullptr) { | |||||
| rtError_t rt_ret = rtFree(workspace_mem_); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "rtFree workspace_mem_ failed! ret: 0x%X.", rt_ret); | |||||
| } | |||||
| workspace_mem_ = nullptr; | |||||
| } | |||||
| } | |||||
| bool HcclTask::Distribute() { | |||||
| // Ops kernel info store | |||||
| // Get privateDef and opsKernelStorePtr | |||||
| GELOGI("Get custom info in modelTaskDef"); | |||||
| void *ops_kernel_store = task_info_->ops_kernel_store(); | |||||
| OpsKernelInfoStore *ops_kernel_info_store = reinterpret_cast<OpsKernelInfoStore *>(ops_kernel_store); | |||||
| if (ops_kernel_store == nullptr) { | |||||
| GELOGE(PARAM_INVALID, "No hcom distribute function ptr and no ops kernel store."); | |||||
| return false; | |||||
| } | |||||
| char *private_def = reinterpret_cast<char *>(const_cast<char unsigned *>(task_info_->private_def().data())); | |||||
| auto private_def_len = static_cast<uint32_t>(task_info_->private_def().size()); | |||||
| GELOGI("The first address of the custom info, privateDef=%p", private_def); | |||||
| SetSecondaryStream(); | |||||
| if (task_info_->workspace_size() > 0) { | |||||
| rtError_t rt_ret = rtMalloc(&workspace_mem_, task_info_->workspace_size(), RT_MEMORYINFO_HBM); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
| return false; | |||||
| } | |||||
| } | |||||
| GELOGI("HcclTaskInfo Distribute Start. begin to call function LoadTask in hccl."); | |||||
| GETaskInfo ge_task; | |||||
| ge_task.id = 0; | |||||
| ge_task.type = static_cast<uint16_t>(RT_MODEL_TASK_HCCL); | |||||
| ge_task.stream = stream_; | |||||
| ge_task.kernelHcclInfo = std::vector<GETaskKernelHcclInfo>(1); | |||||
| ge_task.kernelHcclInfo[0].hccl_type = task_info_->hccl_type(); | |||||
| ge_task.kernelHcclInfo[0].inputDataAddr = task_info_->input_data_addr(); | |||||
| ge_task.kernelHcclInfo[0].outputDataAddr = task_info_->output_data_addr(); | |||||
| ge_task.kernelHcclInfo[0].workSpaceAddr = workspace_mem_; | |||||
| ge_task.kernelHcclInfo[0].workSpaceMemSize = task_info_->workspace_size(); | |||||
| ge_task.kernelHcclInfo[0].count = task_info_->count(); | |||||
| ge_task.kernelHcclInfo[0].dataType = static_cast<int32_t>(task_info_->data_type()); | |||||
| ge_task.kernelHcclInfo[0].opType = static_cast<int32_t>(task_info_->op_type()); | |||||
| ge_task.kernelHcclInfo[0].rootId = task_info_->root_id(); | |||||
| std::vector<rtStream_t> secondary_stream_list; | |||||
| std::transform(secondary_stream_list_.begin(), secondary_stream_list_.end(), | |||||
| std::back_inserter(secondary_stream_list), | |||||
| [](const std::shared_ptr<StreamGuard> &stream) -> rtStream_t { return stream->GetStream(); }); | |||||
| ge_task.kernelHcclInfo[0].hcclStreamList = secondary_stream_list; | |||||
| ge_task.privateDef = private_def; | |||||
| ge_task.privateDefLen = private_def_len; | |||||
| ge_task.opsKernelStorePtr = ops_kernel_store; | |||||
| auto result = ops_kernel_info_store->LoadTask(ge_task); | |||||
| // tagHcclResult::HCCL_SUCCESS is 0 | |||||
| if (result != 0) { | |||||
| GELOGE(INTERNAL_ERROR, "davinci_model : load task fail, return ret: %u", result); | |||||
| return false; | |||||
| } | |||||
| GELOGI("Call function LoadTask end."); | |||||
| return true; | |||||
| } | |||||
| bool HcclTask::SetSecondaryStream() { | |||||
| const uint32_t master_stream_id = task_info_->stream_id(); | |||||
| const int64_t hccl_secondary_stream_num = task_info_->hccl_stream_num(); | |||||
| Status ret; | |||||
| std::lock_guard<std::mutex> lock(model_stream_mapping_mutex_); | |||||
| if (model_stream_mapping_.find(rt_model_handle_) == model_stream_mapping_.end()) { | |||||
| GELOGI("Need to create map for rt_model_handle_:%p with new mainstream %u.", rt_model_handle_, master_stream_id); | |||||
| ret = CreateStream(hccl_secondary_stream_num, master_stream_id); | |||||
| if (!ret) { | |||||
| GELOGE(RT_FAILED, "Create hccl stream failed."); | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| std::map<uint32_t, std::vector<std::weak_ptr<StreamGuard>>> &master_secondary_stream_map = | |||||
| model_stream_mapping_.at(rt_model_handle_); | |||||
| auto iter = master_secondary_stream_map.find(master_stream_id); | |||||
| if (iter != master_secondary_stream_map.end()) { | |||||
| std::vector<std::weak_ptr<StreamGuard>> &secondary_stream_vec = iter->second; | |||||
| auto lock_weak_ptr = [&secondary_stream_vec, this](int64_t index) -> bool { | |||||
| auto stream = secondary_stream_vec[index].lock(); | |||||
| if (stream == nullptr) { | |||||
| rtStream_t new_stream = nullptr; | |||||
| bool ret = CreateStream(rt_model_handle_, &new_stream); | |||||
| if (!ret) { | |||||
| GELOGE(FAILED, "CreateStream failed."); | |||||
| return false; | |||||
| } | |||||
| stream = std::make_shared<HcclTask::StreamGuard>(rt_model_handle_, new_stream); | |||||
| GE_RT_FALSE_CHECK_NOTNULL(stream); | |||||
| secondary_stream_vec[index] = stream; | |||||
| } | |||||
| secondary_stream_list_.push_back(stream); | |||||
| return true; | |||||
| }; | |||||
| if (static_cast<size_t>(hccl_secondary_stream_num) <= secondary_stream_vec.size()) { | |||||
| GELOGI("Number of secondary stream is enough to be reused."); | |||||
| for (int64_t i = 0; i < hccl_secondary_stream_num; ++i) { | |||||
| if (!lock_weak_ptr(i)) { | |||||
| GELOGE(FAILED, "Lock weak ptr failed."); | |||||
| return false; | |||||
| } | |||||
| } | |||||
| } else { | |||||
| GELOGI("Need to reuse secondary stream and create new secondary stream."); | |||||
| size_t created_stream_num = secondary_stream_vec.size(); | |||||
| for (size_t i = 0; i < secondary_stream_vec.size(); ++i) { | |||||
| if (!lock_weak_ptr(i)) { | |||||
| GELOGE(FAILED, "Lock weak ptr failed."); | |||||
| return false; | |||||
| } | |||||
| } | |||||
| ret = CreateStream(hccl_secondary_stream_num - created_stream_num, master_stream_id); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(RT_FAILED, "Create hccl stream failed."); | |||||
| return false; | |||||
| } | |||||
| } | |||||
| GELOGI("Initialize hccl secondary stream success, hccl_secondary_stream_num =%ld", hccl_secondary_stream_num); | |||||
| } else { | |||||
| GELOGI("Need to create secondary stream for %s with new mainstream %u.", task_info_->op_name().c_str(), | |||||
| master_stream_id); | |||||
| ret = CreateStream(hccl_secondary_stream_num, master_stream_id); | |||||
| if (!ret) { | |||||
| GELOGE(RT_FAILED, "Create hccl stream failed."); | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool HcclTask::CreateStream(int64_t stream_num, int64_t master_stream_id) { | |||||
| GELOGI("Start to create %ld hccl secondary stream.", stream_num); | |||||
| for (int64_t i = 0; i < stream_num; ++i) { | |||||
| rtStream_t stream = nullptr; | |||||
| bool ret = CreateStream(rt_model_handle_, &stream); | |||||
| if (!ret) { | |||||
| GELOGE(FAILED, "CreateStream failed."); | |||||
| return false; | |||||
| } | |||||
| GELOGD("hccl_stream addr is=%p", stream); | |||||
| auto shared_stream = std::make_shared<StreamGuard>(rt_model_handle_, stream); | |||||
| if (shared_stream == nullptr) { | |||||
| GELOGE(FAILED, "MakeShared failed."); | |||||
| return false; | |||||
| } | |||||
| SaveHcclSecondaryStream(master_stream_id, shared_stream); | |||||
| secondary_stream_list_.push_back(shared_stream); | |||||
| } | |||||
| GELOGI("CreateStream success."); | |||||
| return true; | |||||
| } | |||||
| bool HcclTask::CreateStream(rtModel_t model, rtStream_t *stream) const { | |||||
| if (stream == nullptr) { | |||||
| GELOGE(FAILED, "Output param stream is null."); | |||||
| return false; | |||||
| } | |||||
| rtError_t rt_ret = rtStreamCreateWithFlags(stream, priority_, RT_STREAM_PERSISTENT | RT_STREAM_FORCE_COPY); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
| return false; | |||||
| } | |||||
| // Create secondary stream, inactive by default, activated by hccl | |||||
| rt_ret = rtModelBindStream(model, *stream, RT_MODEL_WAIT_ACTIVE_STREAM); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| void HcclTask::SaveHcclSecondaryStream(int64_t master_stream_id, const std::shared_ptr<StreamGuard> &stream) { | |||||
| if (model_stream_mapping_.find(rt_model_handle_) == model_stream_mapping_.end()) { | |||||
| model_stream_mapping_.emplace(rt_model_handle_, std::map<uint32_t, std::vector<std::weak_ptr<StreamGuard>>>()); | |||||
| } | |||||
| std::map<uint32_t, std::vector<std::weak_ptr<StreamGuard>>> &master_secondary_stream_map = | |||||
| model_stream_mapping_.at(rt_model_handle_); | |||||
| master_secondary_stream_map[master_stream_id].emplace_back(stream); | |||||
| } | |||||
| HcclTask::StreamGuard::~StreamGuard() { | |||||
| rtError_t rt_ret = rtModelUnbindStream(model_, stream_); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Unbind stream from model failed!"); | |||||
| return; | |||||
| } | |||||
| rt_ret = rtStreamDestroy(stream_); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Destroy stream failed!"); | |||||
| return; | |||||
| } | |||||
| } | |||||
| REGISTER_TASK(TaskInfoType::HCCL, HcclTask, HcclTaskInfo); /* NOTE(review): presumably registers HcclTask as the task class for HCCL task infos; the macro is not defined in this file - confirm in task_factory.h. */ | |||||
| } // namespace model_runner | |||||
| } // namespace ge | |||||
| @@ -1,69 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_GE_RUNTIME_TASK_HCCL_TASK_H_ | |||||
| #define GE_GE_RUNTIME_TASK_HCCL_TASK_H_ | |||||
| #include <memory> | |||||
| #include <set> | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include <mutex> | |||||
| #include "ge_runtime/task/task.h" | |||||
| namespace ge { | |||||
| namespace model_runner { | |||||
| class HcclTask : public TaskRepeater<HcclTaskInfo> { | |||||
| public: | |||||
| HcclTask(const ModelContext &model_context, const std::shared_ptr<HcclTaskInfo> &task_info); | |||||
| ~HcclTask() override; | |||||
| bool Distribute() override; | |||||
| private: | |||||
| class StreamGuard; | |||||
| bool SetSecondaryStream(); | |||||
| bool CreateStream(int64_t stream_num, int64_t master_stream_id); | |||||
| bool CreateStream(rtModel_t model, rtStream_t *stream) const; | |||||
| void SaveHcclSecondaryStream(int64_t master_stream_id, const std::shared_ptr<StreamGuard> &stream); | |||||
| std::shared_ptr<HcclTaskInfo> task_info_; | |||||
| void *stream_; | |||||
| void *workspace_mem_; | |||||
| rtModel_t rt_model_handle_; | |||||
| int32_t priority_; | |||||
| std::vector<std::shared_ptr<StreamGuard>> secondary_stream_list_; | |||||
| // map<key: model pointer, value: map<key: primary stream id, value: vector<secondary stream pointer>>> | |||||
| static std::map<rtModel_t, std::map<uint32_t, std::vector<std::weak_ptr<StreamGuard>>>> model_stream_mapping_; | |||||
| static std::mutex model_stream_mapping_mutex_; | |||||
| }; | |||||
| class HcclTask::StreamGuard { | |||||
| public: | |||||
| StreamGuard(rtModel_t model, rtStream_t stream) : model_(model), stream_(stream) {} | |||||
| ~StreamGuard(); | |||||
| rtStream_t GetStream() const { return stream_; } | |||||
| private: | |||||
| rtModel_t model_; | |||||
| rtStream_t stream_; | |||||
| }; | |||||
| } // namespace model_runner | |||||
| } // namespace ge | |||||
| #endif // GE_GE_RUNTIME_TASK_HCCL_TASK_H_ | |||||
| @@ -1,117 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ge_runtime/task/label_goto_task.h" | |||||
| #include "ge_runtime/task/task_factory.h" | |||||
| #include "framework/common/util.h" | |||||
| namespace ge { | |||||
| namespace model_runner { | |||||
| LabelGotoTask::LabelGotoTask(const ModelContext &model_context, const std::shared_ptr<LabelGotoTaskInfo> &task_info) | |||||
| : TaskRepeater<LabelGotoTaskInfo>(model_context, task_info), task_info_(task_info) { | |||||
| if (task_info_ == nullptr) { | |||||
| GELOGW("task_info_ is null!"); | |||||
| return; | |||||
| } | |||||
| auto stream_list = model_context.stream_list(); | |||||
| auto label_list = model_context.label_list(); | |||||
| uint32_t stream_id = task_info->stream_id(); | |||||
| uint32_t label_id = task_info->label_id(); | |||||
| GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); | |||||
| GELOGI("Label list size:%zu, label id:%u.", label_list.size(), label_id); | |||||
| if (stream_id >= stream_list.size() || label_id >= label_list.size()) { | |||||
| GELOGW("Stream/Label id invalid."); | |||||
| return; | |||||
| } | |||||
| stream_ = stream_list[stream_id]; | |||||
| label_ = label_list[label_id]; | |||||
| } | |||||
| LabelGotoTask::~LabelGotoTask() { /* Releases the two buffers that Distribute() obtains through rtMalloc. */ | |||||
| GE_FREE_RT_LOG(label_info_); /* NOTE(review): macro presumably frees the pointer and logs on failure - confirm in framework/common/util.h. */ | |||||
| GE_FREE_RT_LOG(index_value_); | |||||
| } | |||||
| bool LabelGotoTask::Distribute() { | |||||
| GELOGI("LabelGotoTask Distribute start."); | |||||
| if (!CheckParamValid()) { | |||||
| return false; | |||||
| } | |||||
| const std::vector<void *> label_list = { label_ }; | |||||
| rtError_t rt_ret = rtMalloc(&index_value_, sizeof(uint64_t), RT_MEMORY_HBM); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); | |||||
| return false; | |||||
| } | |||||
| uint64_t branch_index = 0; | |||||
| rt_ret = rtMemcpy(index_value_, sizeof(uint64_t), &branch_index, sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); | |||||
| return false; | |||||
| } | |||||
| uint32_t label_info_size = sizeof(rtLabelDevInfo) * label_list.size(); | |||||
| rt_ret = rtMalloc(&label_info_, label_info_size, RT_MEMORY_HBM); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); | |||||
| return false; | |||||
| } | |||||
| rt_ret = rtLabelListCpy((void**)label_list.data(), label_list.size(), label_info_, label_info_size); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); | |||||
| return false; | |||||
| } | |||||
| rt_ret = rtLabelSwitchByIndex(index_value_, label_list.size(), label_info_, stream_); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); | |||||
| return false; | |||||
| } | |||||
| GELOGI("DistributeTask end."); | |||||
| return true; | |||||
| } | |||||
| bool LabelGotoTask::CheckParamValid() { | |||||
| if (stream_ == nullptr) { | |||||
| GELOGE(PARAM_INVALID, "stream is null!"); | |||||
| return false; | |||||
| } | |||||
| if (label_ == nullptr) { | |||||
| GELOGE(PARAM_INVALID, "label is null!"); | |||||
| return false; | |||||
| } | |||||
| if (label_info_ != nullptr) { | |||||
| GELOGE(PARAM_INVALID, "label_info_ has dirty data."); | |||||
| return false; | |||||
| } | |||||
| if (index_value_ != nullptr) { | |||||
| GELOGE(PARAM_INVALID, "index_value_ has dirty data."); | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| REGISTER_TASK(TaskInfoType::LABEL_GOTO, LabelGotoTask, LabelGotoTaskInfo); /* NOTE(review): presumably registers LabelGotoTask as the task class for LABEL_GOTO task infos; the macro is not defined in this file - confirm in task_factory.h. */ | |||||
| } // namespace model_runner | |||||
| } // namespace ge | |||||
| @@ -1,45 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ | |||||
| #define GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ | |||||
| #include <memory> | |||||
| #include "ge_runtime/task/task.h" | |||||
| namespace ge { | |||||
| namespace model_runner { | |||||
| class LabelGotoTask : public TaskRepeater<LabelGotoTaskInfo> { | |||||
| public: | |||||
| LabelGotoTask(const ModelContext &model_context, const std::shared_ptr<LabelGotoTaskInfo> &task_info); | |||||
| ~LabelGotoTask() override; | |||||
| bool Distribute() override; | |||||
| private: | |||||
| bool CheckParamValid(); | |||||
| std::shared_ptr<LabelGotoTaskInfo> task_info_; | |||||
| void *stream_{nullptr}; | |||||
| void *label_{nullptr}; | |||||
| void *label_info_{nullptr}; | |||||
| void *index_value_{nullptr}; | |||||
| }; | |||||
| } // namespace model_runner | |||||
| } // namespace ge | |||||
| #endif // GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ | |||||
| @@ -1,70 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ge_runtime/task/label_set_task.h" | |||||
| #include "ge_runtime/task/task_factory.h" | |||||
| namespace ge { | |||||
| namespace model_runner { | |||||
| LabelSetTask::LabelSetTask(const ModelContext &model_context, const std::shared_ptr<LabelSetTaskInfo> &task_info) | |||||
| : TaskRepeater<LabelSetTaskInfo>(model_context, task_info), | |||||
| task_info_(task_info), | |||||
| stream_(nullptr), | |||||
| label_(nullptr) { | |||||
| if (task_info_ == nullptr) { | |||||
| GELOGW("task_info_ is null!"); | |||||
| return; | |||||
| } | |||||
| auto stream_list = model_context.stream_list(); | |||||
| auto label_list = model_context.label_list(); | |||||
| uint32_t stream_id = task_info->stream_id(); | |||||
| uint32_t label_id = task_info->label_id(); | |||||
| GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); | |||||
| GELOGI("Label list size:%zu, label id:%u.", label_list.size(), label_id); | |||||
| if (stream_id >= stream_list.size() || label_id >= label_list.size()) { | |||||
| GELOGW("Stream/Label id invalid."); | |||||
| return; | |||||
| } | |||||
| stream_ = stream_list[stream_id]; | |||||
| label_ = label_list[label_id]; | |||||
| } | |||||
| LabelSetTask::~LabelSetTask() {} | |||||
| bool LabelSetTask::Distribute() { | |||||
| GELOGI("LabelSetTask Distribute start."); | |||||
| if (stream_ == nullptr) { | |||||
| GELOGE(PARAM_INVALID, "stream is null!"); | |||||
| return false; | |||||
| } | |||||
| if (label_ == nullptr) { | |||||
| GELOGE(PARAM_INVALID, "label is null!"); | |||||
| return false; | |||||
| } | |||||
| rtError_t rt_ret = rtLabelSet(label_, stream_); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
| return false; | |||||
| } | |||||
| GELOGI("DistributeTask end."); | |||||
| return true; | |||||
| } | |||||
| REGISTER_TASK(TaskInfoType::LABEL_SET, LabelSetTask, LabelSetTaskInfo); /* NOTE(review): presumably registers LabelSetTask as the task class for LABEL_SET task infos; the macro is not defined in this file - confirm in task_factory.h. */ | |||||
| } // namespace model_runner | |||||
| } // namespace ge | |||||
| @@ -1,41 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_ | |||||
| #define GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_ | |||||
| #include <memory> | |||||
| #include "ge_runtime/task/task.h" | |||||
| namespace ge { | |||||
| namespace model_runner { | |||||
// Task that calls rtLabelSet for one label on one stream (see Distribute()).
class LabelSetTask : public TaskRepeater<LabelSetTaskInfo> {
 public:
  // Resolves the stream and label handles from model_context using the ids in task_info.
  LabelSetTask(const ModelContext &model_context, const std::shared_ptr<LabelSetTaskInfo> &task_info);

  ~LabelSetTask() override;

  // Returns true on success, false when a handle is null or the runtime call fails.
  bool Distribute() override;

 private:
  std::shared_ptr<LabelSetTaskInfo> task_info_;  // task description handed to the constructor
  void *stream_;                                 // stream handle, null if it could not be resolved
  void *label_;                                  // label handle, null if it could not be resolved
};
| } // namespace model_runner | |||||
| } // namespace ge | |||||
| #endif // GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_ | |||||
| @@ -1,131 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ge_runtime/task/label_switch_task.h" | |||||
| #include "ge_runtime/task/task_factory.h" | |||||
| namespace ge { | |||||
| namespace model_runner { | |||||
| LabelSwitchTask::LabelSwitchTask(const ModelContext &model_context, | |||||
| const std::shared_ptr<LabelSwitchTaskInfo> &task_info) | |||||
| : TaskRepeater<LabelSwitchTaskInfo>(model_context, task_info), | |||||
| task_info_(task_info), | |||||
| stream_(nullptr), | |||||
| all_label_resource_(), | |||||
| label_info_(nullptr) { | |||||
| if (task_info_ == nullptr) { | |||||
| GELOGW("task_info_ is null!"); | |||||
| return; | |||||
| } | |||||
| all_label_resource_ = model_context.label_list(); | |||||
| auto stream_list = model_context.stream_list(); | |||||
| uint32_t stream_id = task_info->stream_id(); | |||||
| GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); | |||||
| if (stream_id >= stream_list.size()) { | |||||
| GELOGW("Stream id invalid."); | |||||
| return; | |||||
| } | |||||
| stream_ = stream_list[stream_id]; | |||||
| } | |||||
| LabelSwitchTask::~LabelSwitchTask() { | |||||
| if (label_info_ != nullptr) { | |||||
| rtError_t rt_ret = rtFree(label_info_); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "rtFree fwkOpBuf failed! ret: 0x%X.", rt_ret); | |||||
| } | |||||
| label_info_ = nullptr; | |||||
| } | |||||
| } | |||||
| bool LabelSwitchTask::Distribute() { | |||||
| GELOGI("LabelSwitchTask Distribute start."); | |||||
| if (!CheckParamValid()) { | |||||
| return false; | |||||
| } | |||||
| const std::vector<uint32_t> &label_index_list = task_info_->label_list(); | |||||
| std::vector<void *> label_list(task_info_->label_size(), nullptr); | |||||
| for (size_t i = 0; i < task_info_->label_size(); ++i) { | |||||
| uint32_t label_index = label_index_list[i]; | |||||
| if (label_index >= all_label_resource_.size()) { | |||||
| GELOGE(PARAM_INVALID, "label %zu index is %u, but there are %zu labels in total.", i, label_index, | |||||
| all_label_resource_.size()); | |||||
| return false; | |||||
| } | |||||
| label_list[i] = all_label_resource_[label_index]; | |||||
| GELOGI("Case %zu: label id %zu.", i, (size_t)label_index); | |||||
| } | |||||
| uint32_t label_info_size = sizeof(rtLabelDevInfo) * task_info_->label_size(); | |||||
| rtError_t rt_ret = rtMalloc(&label_info_, label_info_size, RT_MEMORY_HBM); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
| return false; | |||||
| } | |||||
| rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info_, label_info_size); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
| return false; | |||||
| } | |||||
| rt_ret = rtLabelSwitchByIndex(task_info_->cond(), label_list.size(), label_info_, stream_); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
| return false; | |||||
| } | |||||
| GELOGI("DistributeTask end."); | |||||
| return true; | |||||
| } | |||||
| bool LabelSwitchTask::CheckParamValid() { | |||||
| if (stream_ == nullptr) { | |||||
| GELOGE(PARAM_INVALID, "stream is null!"); | |||||
| return false; | |||||
| } | |||||
| if (task_info_->label_list().empty()) { | |||||
| GELOGE(PARAM_INVALID, "label_list is empty."); | |||||
| return false; | |||||
| } | |||||
| if (task_info_->label_size() != task_info_->label_list().size()) { | |||||
| GELOGE(PARAM_INVALID, "label_list size %zu but label_size is %u.", task_info_->label_list().size(), | |||||
| task_info_->label_size()); | |||||
| return false; | |||||
| } | |||||
| if (task_info_->label_size() >= UINT32_MAX / sizeof(rtLabelDevInfo)) { | |||||
| GELOGE(PARAM_INVALID, "label_size %u will overflow.", task_info_->label_size()); | |||||
| return false; | |||||
| } | |||||
| if (label_info_ != nullptr) { | |||||
| GELOGE(PARAM_INVALID, "label_info_ has dirty data."); | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
// Registers a creator for TaskInfoType::LABEL_SWITCH that builds a
// LabelSwitchTask from a LabelSwitchTaskInfo.
REGISTER_TASK(TaskInfoType::LABEL_SWITCH, LabelSwitchTask, LabelSwitchTaskInfo);
| } // namespace model_runner | |||||
| } // namespace ge | |||||
| @@ -1,44 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_ | |||||
| #define GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_ | |||||
| #include <memory> | |||||
| #include "ge_runtime/task/task.h" | |||||
| namespace ge { | |||||
| namespace model_runner { | |||||
// Task that calls rtLabelSwitchByIndex over a list of labels (see Distribute()).
class LabelSwitchTask : public TaskRepeater<LabelSwitchTaskInfo> {
 public:
  // Copies the context's label list and resolves the stream from task_info's stream id.
  LabelSwitchTask(const ModelContext &model_context, const std::shared_ptr<LabelSwitchTaskInfo> &task_info);

  // Frees label_info_ if Distribute() allocated it.
  ~LabelSwitchTask() override;

  // Returns true on success, false on invalid parameters or a failed runtime call.
  bool Distribute() override;

 private:
  // Precondition checks for Distribute(); logs and returns false on the first failure.
  bool CheckParamValid();

  std::shared_ptr<LabelSwitchTaskInfo> task_info_;  // task description handed to the constructor
  void *stream_;                                    // stream handle, null if it could not be resolved
  std::vector<void *> all_label_resource_;          // copy of model_context.label_list()
  void *label_info_;                                // buffer from rtMalloc in Distribute(), else null
};
| } // namespace model_runner | |||||
| } // namespace ge | |||||
| #endif // GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_ | |||||
| @@ -1,57 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ge_runtime/task/memcpy_async_task.h" | |||||
| #include "ge_runtime/task/task_factory.h" | |||||
| namespace ge { | |||||
| namespace model_runner { | |||||
| MemcpyAsyncTask::MemcpyAsyncTask(const ModelContext &model_context, | |||||
| const std::shared_ptr<MemcpyAsyncTaskInfo> &task_info) | |||||
| : TaskRepeater<MemcpyAsyncTaskInfo>(model_context, task_info), task_info_(task_info), stream_(nullptr) { | |||||
| if (task_info_ == nullptr) { | |||||
| GELOGW("task_info_ is null!"); | |||||
| return; | |||||
| } | |||||
| auto stream_list = model_context.stream_list(); | |||||
| uint32_t stream_id = task_info->stream_id(); | |||||
| GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); | |||||
| if (stream_id >= stream_list.size()) { | |||||
| GELOGW("Stream id invalid"); | |||||
| return; | |||||
| } | |||||
| stream_ = stream_list[stream_id]; | |||||
| } | |||||
| MemcpyAsyncTask::~MemcpyAsyncTask() {} | |||||
| bool MemcpyAsyncTask::Distribute() { | |||||
| GELOGI("MemcpyAsyncTask Distribute start."); | |||||
| GELOGI("dst_max:%lu, count:%lu, kind:%u.", task_info_->dst_max(), task_info_->count(), task_info_->kind()); | |||||
| rtError_t rt_ret = rtMemcpyAsync(task_info_->dst(), task_info_->dst_max(), task_info_->src(), task_info_->count(), | |||||
| static_cast<rtMemcpyKind_t>(task_info_->kind()), stream_); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
| return false; | |||||
| } | |||||
| GELOGI("DistributeTask end"); | |||||
| return true; | |||||
| } | |||||
// Registers a creator for TaskInfoType::MEMCPY_ASYNC that builds a
// MemcpyAsyncTask from a MemcpyAsyncTaskInfo.
REGISTER_TASK(TaskInfoType::MEMCPY_ASYNC, MemcpyAsyncTask, MemcpyAsyncTaskInfo);
| } // namespace model_runner | |||||
| } // namespace ge | |||||
| @@ -1,40 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_GE_RUNTIME_TASK_MEMCPY_ASYNC_TASK_H_ | |||||
| #define GE_GE_RUNTIME_TASK_MEMCPY_ASYNC_TASK_H_ | |||||
| #include <memory> | |||||
| #include "ge_runtime/task/task.h" | |||||
| namespace ge { | |||||
| namespace model_runner { | |||||
// Task that calls rtMemcpyAsync with the parameters held in its task info (see Distribute()).
class MemcpyAsyncTask : public TaskRepeater<MemcpyAsyncTaskInfo> {
 public:
  // Resolves the stream from model_context using task_info's stream id.
  MemcpyAsyncTask(const ModelContext &model_context, const std::shared_ptr<MemcpyAsyncTaskInfo> &task_info);

  ~MemcpyAsyncTask() override;

  // Returns true when the runtime call succeeds, false otherwise.
  bool Distribute() override;

 private:
  std::shared_ptr<MemcpyAsyncTaskInfo> task_info_;  // task description handed to the constructor
  rtStream_t stream_;                               // stream handle, null if it could not be resolved
};
| } // namespace model_runner | |||||
| } // namespace ge | |||||
| #endif // GE_GE_RUNTIME_TASK_MEMCPY_ASYNC_TASK_H_ | |||||
| @@ -1,55 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ge_runtime/task/profiler_task.h" | |||||
| #include "ge_runtime/task/task_factory.h" | |||||
| namespace ge { | |||||
| namespace model_runner { | |||||
| ProfilerTask::ProfilerTask(const ModelContext &model_context, const std::shared_ptr<ProfilerTraceTaskInfo> &task_info) | |||||
| : TaskRepeater<ProfilerTraceTaskInfo>(model_context, task_info), task_info_(task_info), stream_(nullptr) { | |||||
| if (task_info_ == nullptr) { | |||||
| GELOGW("task_info_ is null!"); | |||||
| return; | |||||
| } | |||||
| auto stream_list = model_context.stream_list(); | |||||
| uint32_t stream_id = task_info->stream_id(); | |||||
| GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); | |||||
| if (stream_id >= stream_list.size()) { | |||||
| GELOGW("Stream id invalid"); | |||||
| return; | |||||
| } | |||||
| stream_ = stream_list[stream_id]; | |||||
| } | |||||
| ProfilerTask::~ProfilerTask() {} | |||||
| bool ProfilerTask::Distribute() { | |||||
| GELOGI("ProfilerTask Distribute start."); | |||||
| GELOGI("logid = %lu, notify = %d, flat = %u.", task_info_->log_id(), task_info_->notify(), task_info_->flat()); | |||||
| rtError_t rt_ret = rtProfilerTrace(task_info_->log_id(), task_info_->notify(), task_info_->flat(), stream_); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
| return false; | |||||
| } | |||||
| GELOGI("DistributeTask end"); | |||||
| return true; | |||||
| } | |||||
// Registers a creator for TaskInfoType::PROFILER_TRACE that builds a
// ProfilerTask from a ProfilerTraceTaskInfo.
REGISTER_TASK(TaskInfoType::PROFILER_TRACE, ProfilerTask, ProfilerTraceTaskInfo);
| } // namespace model_runner | |||||
| } // namespace ge | |||||
| @@ -1,40 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_GE_RUNTIME_TASK_PROFILER_TASK_H_ | |||||
| #define GE_GE_RUNTIME_TASK_PROFILER_TASK_H_ | |||||
| #include <memory> | |||||
| #include "ge_runtime/task/task.h" | |||||
| namespace ge { | |||||
| namespace model_runner { | |||||
// Task that calls rtProfilerTrace with the parameters held in its task info (see Distribute()).
class ProfilerTask : public TaskRepeater<ProfilerTraceTaskInfo> {
 public:
  // Resolves the stream from model_context using task_info's stream id.
  ProfilerTask(const ModelContext &model_context, const std::shared_ptr<ProfilerTraceTaskInfo> &task_info);

  ~ProfilerTask() override;

  // Returns true when the runtime call succeeds, false otherwise.
  bool Distribute() override;

 private:
  std::shared_ptr<ProfilerTraceTaskInfo> task_info_;  // task description handed to the constructor
  rtStream_t stream_;                                 // stream handle, null if it could not be resolved
};
| } // namespace model_runner | |||||
| } // namespace ge | |||||
| #endif // GE_GE_RUNTIME_TASK_PROFILER_TASK_H_ | |||||
| @@ -1,60 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ge_runtime/task/stream_active_task.h" | |||||
| #include "ge_runtime/task/task_factory.h" | |||||
| namespace ge { | |||||
| namespace model_runner { | |||||
| StreamActiveTask::StreamActiveTask(const ModelContext &model_context, | |||||
| const std::shared_ptr<StreamActiveTaskInfo> &task_info) | |||||
| : TaskRepeater<StreamActiveTaskInfo>(model_context, task_info), | |||||
| task_info_(task_info), | |||||
| stream_(nullptr), | |||||
| active_stream_(nullptr) { | |||||
| if (task_info_ == nullptr) { | |||||
| GELOGW("task_info_ is null!"); | |||||
| return; | |||||
| } | |||||
| auto stream_list = model_context.stream_list(); | |||||
| uint32_t stream_id = task_info->stream_id(); | |||||
| uint32_t active_stream_id = task_info->active_stream_id(); | |||||
| GELOGI("Stream list size:%zu, stream id:%u, active stream id:%u", stream_list.size(), stream_id, active_stream_id); | |||||
| if (stream_id >= stream_list.size() || active_stream_id >= stream_list.size()) { | |||||
| GELOGW("Stream id invalid"); | |||||
| return; | |||||
| } | |||||
| stream_ = stream_list[stream_id]; | |||||
| active_stream_ = stream_list[active_stream_id]; | |||||
| } | |||||
| StreamActiveTask::~StreamActiveTask() {} | |||||
| bool StreamActiveTask::Distribute() { | |||||
| GELOGI("Distribute start"); | |||||
| GELOGI("Stream %u active %u.", task_info_->stream_id(), task_info_->active_stream_id()); | |||||
| rtError_t rt_ret = rtStreamActive(active_stream_, stream_); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
| return false; | |||||
| } | |||||
| GELOGI("DistributeTask end"); | |||||
| return true; | |||||
| } | |||||
// Registers a creator for TaskInfoType::STREAM_ACTIVE that builds a
// StreamActiveTask from a StreamActiveTaskInfo.
REGISTER_TASK(TaskInfoType::STREAM_ACTIVE, StreamActiveTask, StreamActiveTaskInfo);
| } // namespace model_runner | |||||
| } // namespace ge | |||||
| @@ -1,41 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_GE_RUNTIME_TASK_STREAM_ACTIVE_TASK_H_ | |||||
| #define GE_GE_RUNTIME_TASK_STREAM_ACTIVE_TASK_H_ | |||||
| #include <memory> | |||||
| #include "ge_runtime/task/task.h" | |||||
| namespace ge { | |||||
| namespace model_runner { | |||||
// Task that calls rtStreamActive for a pair of streams (see Distribute()).
class StreamActiveTask : public TaskRepeater<StreamActiveTaskInfo> {
 public:
  // Resolves both stream handles from model_context using the ids in task_info.
  StreamActiveTask(const ModelContext &model_context, const std::shared_ptr<StreamActiveTaskInfo> &task_info);

  ~StreamActiveTask() override;

  // Returns true when the runtime call succeeds, false otherwise.
  bool Distribute() override;

 private:
  std::shared_ptr<StreamActiveTaskInfo> task_info_;  // task description handed to the constructor
  rtStream_t stream_;                                // stream selected by task_info's stream_id()
  rtStream_t active_stream_;                         // stream selected by task_info's active_stream_id()
};
| } // namespace model_runner | |||||
| } // namespace ge | |||||
| #endif // GE_GE_RUNTIME_TASK_STREAM_ACTIVE_TASK_H_ | |||||
| @@ -1,82 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ge_runtime/task/stream_switch_task.h" | |||||
| #include "ge_runtime/task/task_factory.h" | |||||
| namespace ge { | |||||
| namespace model_runner { | |||||
| StreamSwitchTask::StreamSwitchTask(const ModelContext &model_context, | |||||
| const std::shared_ptr<StreamSwitchTaskInfo> &task_info) | |||||
| : TaskRepeater<StreamSwitchTaskInfo>(model_context, task_info), | |||||
| task_info_(task_info), | |||||
| stream_(nullptr), | |||||
| stream_list_() { | |||||
| if (task_info_ == nullptr) { | |||||
| GELOGW("task_info_ is null!"); | |||||
| return; | |||||
| } | |||||
| stream_list_ = model_context.stream_list(); | |||||
| if (stream_list_.size() == 1) { | |||||
| stream_ = stream_list_[0]; | |||||
| } else if (stream_list_.size() > task_info->stream_id()) { | |||||
| stream_ = stream_list_[task_info->stream_id()]; | |||||
| } else { | |||||
| GELOGW("Index: %u >= stream_list.size(): %zu.", task_info->stream_id(), stream_list_.size()); | |||||
| } | |||||
| } | |||||
| StreamSwitchTask::~StreamSwitchTask() {} | |||||
| bool StreamSwitchTask::Distribute() { | |||||
| GELOGI("Init StreamSwitchTask start."); | |||||
| GELOGI("Stream %u active %ld.", task_info_->stream_id(), task_info_->true_stream_id()); | |||||
| if (stream_ == nullptr) { | |||||
| GELOGE(PARAM_INVALID, "stream_ is null!"); | |||||
| return false; | |||||
| } | |||||
| if (static_cast<uint64_t>(task_info_->true_stream_id()) >= stream_list_.size()) { | |||||
| GELOGE(PARAM_INVALID, "true_stream_id %ld must be less than stream_list_ size %zu!", task_info_->true_stream_id(), | |||||
| stream_list_.size()); | |||||
| return false; | |||||
| } | |||||
| void *input = reinterpret_cast<void *>(task_info_->input_addr()); | |||||
| rtCondition_t cond = static_cast<rtCondition_t>(task_info_->cond()); | |||||
| void *value = reinterpret_cast<void *>(task_info_->value_addr()); | |||||
| rtStream_t true_stream = stream_list_[task_info_->true_stream_id()]; | |||||
| rtSwitchDataType_t data_type = static_cast<rtSwitchDataType_t>(task_info_->data_type()); | |||||
| GELOGI("InitStreamSwitchTask, cond:%d, trueStream:%p, trueStreamID:%ld, datatype:%ld.", cond, true_stream, | |||||
| task_info_->true_stream_id(), task_info_->data_type()); | |||||
| GELOGI("StreamSwitchTask Distribute Start."); | |||||
| rtError_t rt_ret = rtStreamSwitchEx(input, cond, value, true_stream, stream_, data_type); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
| return false; | |||||
| } | |||||
| GELOGI("Distribute StreamSwitch, cond:%d, trueStream:%p, datatype:%ld.", cond, true_stream, task_info_->data_type()); | |||||
| return true; | |||||
| } | |||||
// Registers a creator for TaskInfoType::STREAM_SWITCH that builds a
// StreamSwitchTask from a StreamSwitchTaskInfo.
REGISTER_TASK(TaskInfoType::STREAM_SWITCH, StreamSwitchTask, StreamSwitchTaskInfo);
| } // namespace model_runner | |||||
| } // namespace ge | |||||
| @@ -1,43 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_GE_RUNTIME_TASK_STREAM_SWITCH_TASK_H_ | |||||
| #define GE_GE_RUNTIME_TASK_STREAM_SWITCH_TASK_H_ | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "ge_runtime/task/task.h" | |||||
| namespace ge { | |||||
| namespace model_runner { | |||||
// Task that calls rtStreamSwitchEx with the parameters held in its task info (see Distribute()).
class StreamSwitchTask : public TaskRepeater<StreamSwitchTaskInfo> {
 public:
  // Copies the context's stream list and picks the stream this task runs on.
  StreamSwitchTask(const ModelContext &model_context, const std::shared_ptr<StreamSwitchTaskInfo> &task_info);

  ~StreamSwitchTask() override;

  // Returns true on success, false on invalid parameters or a failed runtime call.
  bool Distribute() override;

 private:
  std::shared_ptr<StreamSwitchTaskInfo> task_info_;  // task description handed to the constructor
  void *stream_;                                     // stream handle, null if it could not be resolved
  std::vector<rtStream_t> stream_list_;              // copy of model_context.stream_list()
};
| } // namespace model_runner | |||||
| } // namespace ge | |||||
| #endif // GE_GE_RUNTIME_TASK_STREAM_SWITCH_TASK_H_ | |||||
| @@ -1,58 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_GE_RUNTIME_TASK_TASK_H_ | |||||
| #define GE_GE_RUNTIME_TASK_TASK_H_ | |||||
| #include <memory> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include "runtime/rt_model.h" | |||||
| #include "ge_runtime/model_context.h" | |||||
| #include "ge_runtime/task_info.h" | |||||
| #include "external/runtime/rt_error_codes.h" | |||||
| namespace ge { | |||||
| namespace model_runner { | |||||
// Abstract base class of every task in this directory.
class Task {
 public:
  Task() {}

  virtual ~Task() {}

  // Performs the task's work; implementations in this directory return true on success and false on failure.
  virtual bool Distribute() = 0;

  // Returns nullptr unless a subclass overrides it.
  virtual void *Args() { return nullptr; }

  // Returns an empty string unless a subclass overrides it.
  virtual std::string task_name() const { return ""; }
};
// Intermediate base class that ties a Task to one concrete TaskInfo subtype T.
// Instantiating it with a T that does not derive from TaskInfo fails to compile.
template <class T>
class TaskRepeater : public Task {
  // NOTE(review): std::is_base_of is declared in <type_traits>, which this header does not
  // include directly — presumably it arrives through another include; confirm.
  static_assert(std::is_base_of<TaskInfo, T>(), "Wrong TaskInfo Type!");

 public:
  // Both arguments are ignored here; subclasses keep whatever they need.
  TaskRepeater(const ModelContext &model_context, std::shared_ptr<T> task_info) {}

  virtual ~TaskRepeater() {}

  virtual bool Distribute() = 0;
};
| } // namespace model_runner | |||||
| } // namespace ge | |||||
| #endif // GE_GE_RUNTIME_TASK_TASK_H_ | |||||
| @@ -1,87 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_GE_RUNTIME_TASK_TASK_FACTORY_H_ | |||||
| #define GE_GE_RUNTIME_TASK_TASK_FACTORY_H_ | |||||
| #include <functional> | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <unordered_map> | |||||
| #include "common/ge_inner_error_codes.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "ge_runtime/task_info.h" | |||||
| namespace ge { | |||||
| namespace model_runner { | |||||
// Forward declarations: this header only refers to these types by reference or smart pointer.
class Task;
class ModelContext;
// Signature of a creator: builds a Task from the model context and a TaskInfo.
using TASK_CREATOR_FUN = std::function<std::shared_ptr<Task>(const ModelContext &, std::shared_ptr<TaskInfo>)>;
| class TaskFactory { | |||||
| private: | |||||
| TaskFactory() {} | |||||
| ~TaskFactory() {} | |||||
| void RegisterCreator(const TaskInfoType &type, const TASK_CREATOR_FUN &func) { | |||||
| if (creator_map_.find(type) != creator_map_.end()) { | |||||
| GELOGW("Creator type %d already exist", static_cast<int32_t>(type)); | |||||
| } | |||||
| creator_map_[type] = func; | |||||
| } | |||||
| std::map<TaskInfoType, TASK_CREATOR_FUN> creator_map_; | |||||
| public: | |||||
| static TaskFactory &GetInstance() { | |||||
| static TaskFactory instance; | |||||
| return instance; | |||||
| } | |||||
| std::shared_ptr<Task> Create(const ModelContext &model_context, std::shared_ptr<TaskInfo> &task_info) const { | |||||
| if (task_info == nullptr) { | |||||
| GELOGE(FAILED, "task_info is null."); | |||||
| return nullptr; | |||||
| } | |||||
| auto iter = creator_map_.find(task_info->type()); | |||||
| if (iter == creator_map_.end()) { | |||||
| GELOGE(FAILED, "Unknow task type %d", static_cast<int32_t>(task_info->type())); | |||||
| return nullptr; | |||||
| } | |||||
| return iter->second(model_context, task_info); | |||||
| } | |||||
| class Register { | |||||
| public: | |||||
| Register(const TaskInfoType &type, const TASK_CREATOR_FUN &func) { | |||||
| GELOGI("regist type %d", static_cast<int32_t>(type)); | |||||
| TaskFactory::GetInstance().RegisterCreator(type, func); | |||||
| } | |||||
| ~Register() {} | |||||
| }; | |||||
| }; | |||||
// Defines a TaskFactory::Register object named g_<task_clazz>_register whose
// creator lambda casts the generic TaskInfo pointer to task_info_clazz with
// std::static_pointer_cast and builds a task_clazz from it with std::make_shared.
// The cast is unchecked, so the registered type must match the TaskInfo
// objects that carry that type value.
#define REGISTER_TASK(type, task_clazz, task_info_clazz)                                                                \
  TaskFactory::Register g_##task_clazz##_register(                                                                      \
    type, [](const ModelContext &model_context, const std::shared_ptr<TaskInfo> &task_info) -> std::shared_ptr<Task> { \
      std::shared_ptr<task_info_clazz> concrete_task_info = std::static_pointer_cast<task_info_clazz>(task_info);       \
      return std::make_shared<task_clazz>(model_context, concrete_task_info);                                           \
    });
| } // namespace model_runner | |||||
| } // namespace ge | |||||
| #endif // GE_GE_RUNTIME_TASK_TASK_FACTORY_H_ | |||||
| @@ -1,112 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ge_runtime/task/tbe_task.h" | |||||
| #include <vector> | |||||
| #include "ge_runtime/task/task_factory.h" | |||||
| namespace ge { | |||||
| namespace model_runner { | |||||
| TbeTask::TbeTask(const ModelContext &model_context, const std::shared_ptr<TbeTaskInfo> &task_info) | |||||
| : TaskRepeater<TbeTaskInfo>(model_context, task_info), | |||||
| task_info_(task_info), | |||||
| stream_(nullptr), | |||||
| stub_func_(nullptr), | |||||
| args_(nullptr) { | |||||
| if (task_info_ == nullptr) { | |||||
| GELOGW("task_info_ is null!"); | |||||
| return; | |||||
| } | |||||
| auto stream_list = model_context.stream_list(); | |||||
| if (stream_list.size() == 1) { | |||||
| stream_ = stream_list[0]; | |||||
| } else if (stream_list.size() > task_info->stream_id()) { | |||||
| stream_ = stream_list[task_info->stream_id()]; | |||||
| } else { | |||||
| GELOGE(PARAM_INVALID, "Index: %u >= stream_list.size(): %zu.", task_info->stream_id(), stream_list.size()); | |||||
| return; | |||||
| } | |||||
| } | |||||
| TbeTask::~TbeTask() { | |||||
| if (args_ != nullptr) { | |||||
| rtError_t rt_ret = rtFree(args_); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "rtFree fwkOpBuf failed! ret: 0x%X.", rt_ret); | |||||
| } | |||||
| args_ = nullptr; | |||||
| } | |||||
| } | |||||
| bool TbeTask::Distribute() { | |||||
| GELOGI("InitTbeTask start."); | |||||
| if (stream_ == nullptr) { | |||||
| GELOGE(PARAM_INVALID, "stream_ is null!"); | |||||
| return false; | |||||
| } | |||||
| // Get stub_func | |||||
| if (task_info_->stub_func().empty()) { | |||||
| GELOGE(PARAM_INVALID, "kernel_info->stub_func is empty!"); | |||||
| return false; | |||||
| } | |||||
| rtError_t rt_ret = rtGetFunctionByName(const_cast<char *>(task_info_->stub_func().c_str()), &stub_func_); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "rtGetFunctionByName failed, ret: %d", static_cast<int32_t>(rt_ret)); | |||||
| stub_func_ = nullptr; | |||||
| return false; | |||||
| } | |||||
| GELOGI("TbeTask: stub_func = %s [%p].", task_info_->stub_func().c_str(), stub_func_); | |||||
| // Get args | |||||
| std::vector<void *> tensor_device_addrs; | |||||
| tensor_device_addrs.insert(tensor_device_addrs.end(), task_info_->input_data_addrs().begin(), | |||||
| task_info_->input_data_addrs().end()); | |||||
| tensor_device_addrs.insert(tensor_device_addrs.end(), task_info_->output_data_addrs().begin(), | |||||
| task_info_->output_data_addrs().end()); | |||||
| tensor_device_addrs.insert(tensor_device_addrs.end(), task_info_->workspace_addrs().begin(), | |||||
| task_info_->workspace_addrs().end()); | |||||
| auto args_size = static_cast<uint32_t>(tensor_device_addrs.size() * sizeof(void *)); | |||||
| rt_ret = rtMalloc(&args_, args_size, RT_MEMORY_HBM); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "rtMalloc failed, ret: %d", static_cast<int32_t>(rt_ret)); | |||||
| return false; | |||||
| } | |||||
| GE_PRINT_DYNAMIC_MEMORY(rtMalloc, "task args data.", args_size) | |||||
| rt_ret = rtMemcpy(args_, args_size, reinterpret_cast<void *>(tensor_device_addrs.data()), args_size, | |||||
| RT_MEMCPY_HOST_TO_DEVICE); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "rtMemcpy fail, ret 0x%X.", rt_ret); | |||||
| return false; | |||||
| } | |||||
| GELOGI("DistributeTbeTask start."); | |||||
| auto dump_flag = task_info_->dump_flag() ? RT_KERNEL_DUMPFLAG : RT_KERNEL_DEFAULT; | |||||
| rt_ret = rtKernelLaunchWithFlag(stub_func_, task_info_->block_dim(), args_, args_size, nullptr, stream_, dump_flag); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api rtKernelLaunch failed, ret: 0x%X", rt_ret); | |||||
| return false; | |||||
| } | |||||
| GELOGI("[DataDump] task name:%s, dump_flag:%d", task_info_->op_name().c_str(), dump_flag); | |||||
| return true; | |||||
| } | |||||
// Registers TbeTask with the task factory as the creator for TaskInfoType::TBE
// tasks; the generic TaskInfo pointer is downcast to TbeTaskInfo by the macro.
| REGISTER_TASK(TaskInfoType::TBE, TbeTask, TbeTaskInfo); | |||||
| } // namespace model_runner | |||||
| } // namespace ge | |||||
| @@ -1,46 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_GE_RUNTIME_TASK_TBE_TASK_H_ | |||||
| #define GE_GE_RUNTIME_TASK_TBE_TASK_H_ | |||||
| #include <memory> | |||||
| #include "ge_runtime/task/task.h" | |||||
| namespace ge { | |||||
| namespace model_runner { | |||||
| class TbeTask : public TaskRepeater<TbeTaskInfo> { | |||||
| public: | |||||
| TbeTask(const ModelContext &model_context, const std::shared_ptr<TbeTaskInfo> &task_info); | |||||
| ~TbeTask() override; | |||||
| bool Distribute() override; | |||||
| void *Args() override { return args_; } | |||||
| std::string task_name() const override { return task_info_->op_name(); } | |||||
| private: | |||||
| std::shared_ptr<TbeTaskInfo> task_info_; | |||||
| void *stream_; | |||||
| void *stub_func_; | |||||
| void *args_; | |||||
| }; | |||||
| } // namespace model_runner | |||||
| } // namespace ge | |||||
| #endif // GE_GE_RUNTIME_TASK_TBE_TASK_H_ | |||||
| @@ -1,113 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_FRAMEWORK_GE_RUNTIME_DAVINCI_MODEL_H_ | |||||
| #define INC_FRAMEWORK_GE_RUNTIME_DAVINCI_MODEL_H_ | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "ge_runtime/op_info.h" | |||||
| #include "ge_runtime/task_info.h" | |||||
| namespace ge { | |||||
| namespace model_runner { | |||||
| class DavinciModel { | |||||
| public: | |||||
| DavinciModel(const std::vector<std::shared_ptr<TaskInfo>> &task_info_list, | |||||
| const std::vector<std::shared_ptr<OpInfo>> &data_info_list, | |||||
| const std::vector<std::shared_ptr<OpInfo>> &output_info_list, | |||||
| const std::vector<std::shared_ptr<OpInfo>> &constant_info_list, | |||||
| const std::vector<model_runner::OpInfoPtr> &variable_info_list, | |||||
| const std::vector<uint32_t> &wait_active_stream_list, | |||||
| const std::vector<uint32_t> &force_copy_stream_list, uint64_t mem_size = 0, uint64_t weight_size = 0, | |||||
| uint64_t var_size = 0, uintptr_t logic_mem_base = 0, uintptr_t logic_weight_base = 0, | |||||
| uintptr_t logic_var_base = 0, uint32_t stream_num = 0, uint32_t batch_num = 0, uint32_t event_num = 0, | |||||
| int32_t priority = 0) | |||||
| : task_info_list_(task_info_list), | |||||
| data_info_list_(data_info_list), | |||||
| output_info_list_(output_info_list), | |||||
| constant_info_list_(constant_info_list), | |||||
| variable_info_list_(variable_info_list), | |||||
| wait_active_stream_list_(wait_active_stream_list), | |||||
| force_copy_stream_list_(force_copy_stream_list), | |||||
| mem_size_(mem_size), | |||||
| weight_size_(weight_size), | |||||
| var_size_(var_size), | |||||
| logic_mem_base_(logic_mem_base), | |||||
| logic_weight_base_(logic_weight_base), | |||||
| logic_var_base_(logic_var_base), | |||||
| stream_num_(stream_num), | |||||
| batch_num_(batch_num), | |||||
| event_num_(event_num), | |||||
| priority_(priority) {} | |||||
| ~DavinciModel() {} | |||||
| uint64_t GetMemSize() const { return mem_size_; } | |||||
| uint64_t GetWeightSize() const { return weight_size_; } | |||||
| uint64_t GetVarSize() const { return var_size_; } | |||||
| uintptr_t GetLogicMemBase() const { return logic_mem_base_; } | |||||
| uintptr_t GetLogicWeightBase() const { return logic_weight_base_; } | |||||
| uintptr_t GetLogicVarBase() const { return logic_var_base_; } | |||||
| uint32_t GetStreamNum() const { return stream_num_; } | |||||
| uint32_t GetBatchNum() const { return batch_num_; } | |||||
| uint32_t GetEventNum() const { return event_num_; } | |||||
| const std::vector<uint32_t> &GetWaitActiveStreams() const { return wait_active_stream_list_; } | |||||
| const std::vector<uint32_t> &GetForceCopyStreams() const { return force_copy_stream_list_; } | |||||
| int32_t GetPriority() const { return priority_; } | |||||
| const std::vector<std::shared_ptr<TaskInfo>> &GetTaskInfoList() const { return task_info_list_; } | |||||
| const std::vector<std::shared_ptr<OpInfo>> &GetDataInfoList() const { return data_info_list_; } | |||||
| const std::vector<std::shared_ptr<OpInfo>> &GetOutputInfoList() const { return output_info_list_; } | |||||
| const std::vector<std::shared_ptr<OpInfo>> &GetConstantInfoList() const { return output_info_list_; } | |||||
| const std::vector<model_runner::OpInfoPtr> &GetVariableInfoList() const { return variable_info_list_; } | |||||
| private: | |||||
| std::vector<std::shared_ptr<TaskInfo>> task_info_list_; | |||||
| std::vector<std::shared_ptr<OpInfo>> data_info_list_; | |||||
| std::vector<std::shared_ptr<OpInfo>> output_info_list_; | |||||
| std::vector<std::shared_ptr<OpInfo>> constant_info_list_; | |||||
| std::vector<model_runner::OpInfoPtr> variable_info_list_; | |||||
| std::vector<uint32_t> wait_active_stream_list_; | |||||
| std::vector<uint32_t> force_copy_stream_list_; | |||||
| uint64_t mem_size_; | |||||
| uint64_t weight_size_; | |||||
| uint64_t var_size_; | |||||
| uintptr_t logic_mem_base_; | |||||
| uintptr_t logic_weight_base_; | |||||
| uintptr_t logic_var_base_; | |||||
| uint32_t stream_num_; | |||||
| uint32_t batch_num_; | |||||
| uint32_t event_num_; | |||||
| int32_t priority_; | |||||
| // Disable to copy constructor and assignment operator | |||||
| DavinciModel &operator=(const DavinciModel &) = delete; | |||||
| DavinciModel(const DavinciModel &) = delete; | |||||
| }; | |||||
| } // namespace model_runner | |||||
| } // namespace ge | |||||
| #endif // INC_FRAMEWORK_GE_RUNTIME_DAVINCI_MODEL_H_ | |||||
| @@ -1,68 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_FRAMEWORK_GE_RUNTIME_MODEL_RUNNER_H_ | |||||
| #define INC_FRAMEWORK_GE_RUNTIME_MODEL_RUNNER_H_ | |||||
| #include <memory> | |||||
| #include <unordered_map> | |||||
| #include <vector> | |||||
| #include "common/ge_inner_error_codes.h" | |||||
| #include "common/ge_types.h" | |||||
| #include "ge_runtime/davinci_model.h" | |||||
| namespace ge { | |||||
| namespace model_runner { | |||||
| class RuntimeModel; | |||||
| using RuntimeInfo = std::tuple<uint32_t, uint32_t, void *>; | |||||
| class ModelRunner { | |||||
| public: | |||||
| static ModelRunner &Instance(); | |||||
| bool LoadDavinciModel(uint32_t device_id, uint64_t session_id, uint32_t model_id, | |||||
| std::shared_ptr<DavinciModel> davinci_model, std::shared_ptr<ModelListener> listener); | |||||
| bool DistributeTask(uint32_t model_id); | |||||
| bool LoadModelComplete(uint32_t model_id); | |||||
| const std::vector<uint32_t> &GetTaskIdList(uint32_t model_id) const; | |||||
| const std::vector<uint32_t> &GetStreamIdList(uint32_t model_id) const; | |||||
| const std::map<std::string, std::shared_ptr<RuntimeInfo>> &GetRuntimeInfoMap(uint32_t model_id) const; | |||||
| void *GetModelHandle(uint32_t model_id) const; | |||||
| bool UnloadModel(uint32_t model_id); | |||||
| bool RunModel(uint32_t model_id, const InputData &input_data, OutputData *output_data); | |||||
| bool GetInputOutputDescInfo(uint32_t model_id, bool zero_copy, std::vector<InputOutputDescInfo> *input_desc, | |||||
| std::vector<InputOutputDescInfo> *output_desc, std::vector<uint32_t> *input_format, | |||||
| std::vector<uint32_t> *output_format); | |||||
| private: | |||||
| ModelRunner() = default; | |||||
| ~ModelRunner() = default; | |||||
| std::unordered_map<uint32_t, std::shared_ptr<RuntimeModel>> runtime_models_; | |||||
| }; | |||||
| } // namespace model_runner | |||||
| } // namespace ge | |||||
| #endif // INC_FRAMEWORK_GE_RUNTIME_MODEL_RUNNER_H_ | |||||
| @@ -1,72 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_FRAMEWORK_GE_RUNTIME_OP_INFO_H_ | |||||
| #define INC_FRAMEWORK_GE_RUNTIME_OP_INFO_H_ | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| namespace ge { | |||||
| namespace model_runner { | |||||
// Plain description of one tensor: its dimensions plus a few numeric
// attributes. The scalar members have no default initializers, so callers
// must set them before reading.
struct TensorInfo {
  // Returns the product of all dimensions, or 0 when `dims` is empty.
  int64_t GetShapeSize() const {
    if (dims.empty()) {
      return 0;
    }
    int64_t res = 1;
    for (const auto dim : dims) {
      res *= dim;
    }
    return res;
  }

  // Returns dims[index], or 0 when index is out of range. Marked const so it
  // can be called on const TensorInfo objects and through const references.
  int64_t GetDim(uint32_t index) const {
    if (index >= dims.size()) {
      return 0;
    }
    return dims[index];
  }

  std::vector<int64_t> dims;
  uint32_t datatype;
  uint32_t format;
  uint32_t real_dim_cnt;
  uint32_t size;
  bool is_output;
};
| struct OpInfo { | |||||
| uint32_t index; | |||||
| std::string name; | |||||
| std::string type; | |||||
| bool var_is_broadcast; | |||||
| std::vector<uintptr_t> input_addrs; | |||||
| std::vector<uintptr_t> output_addrs; | |||||
| std::vector<TensorInfo> input_tensors; | |||||
| std::vector<TensorInfo> output_tensors; | |||||
| std::vector<TensorInfo> weight_tensors; | |||||
| std::vector<std::string> src_name; | |||||
| std::vector<int64_t> src_index; | |||||
| std::string weight_data; | |||||
| }; | |||||
| using TensorInfoPtr = std::shared_ptr<TensorInfo>; | |||||
| using OpInfoPtr = std::shared_ptr<OpInfo>; | |||||
| } // namespace model_runner | |||||
| } // namespace ge | |||||
| #endif // INC_FRAMEWORK_GE_RUNTIME_OP_INFO_H_ | |||||
| @@ -1,405 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_FRAMEWORK_GE_RUNTIME_TASK_INFO_H_ | |||||
| #define INC_FRAMEWORK_GE_RUNTIME_TASK_INFO_H_ | |||||
| #include <stdint.h> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "cce/taskdown_api.h" | |||||
| namespace ge { | |||||
| namespace model_runner { | |||||
// Kinds of task a TaskInfo can describe. Values are spelled out so the
// numbering stays visible; they match the implicit numbering (0..14) that an
// unannotated list would produce.
enum TaskInfoType {
  CCE = 0,
  TBE = 1,
  AICPU = 2,
  LABEL_SET = 3,
  LABEL_SWITCH = 4,
  LABEL_GOTO = 5,
  EVENT_RECORD = 6,
  EVENT_WAIT = 7,
  FUSION_START = 8,
  FUSION_END = 9,
  HCCL = 10,
  PROFILER_TRACE = 11,
  MEMCPY_ASYNC = 12,
  STREAM_SWITCH = 13,
  STREAM_ACTIVE = 14,
  // Insert new task type here
  REVSERVED = 23
};
| class TaskInfo { | |||||
| public: | |||||
| virtual ~TaskInfo() {} | |||||
| uint32_t stream_id() const { return stream_id_; } | |||||
| TaskInfoType type() const { return type_; } | |||||
| std::string op_name() const { return op_name_; } | |||||
| bool dump_flag() const { return dump_flag_; } | |||||
| protected: | |||||
| TaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, bool dump_flag) | |||||
| : op_name_(op_name), stream_id_(stream_id), type_(type), dump_flag_(dump_flag) {} | |||||
| private: | |||||
| std::string op_name_; | |||||
| uint32_t stream_id_; | |||||
| TaskInfoType type_; | |||||
| bool dump_flag_; | |||||
| }; | |||||
| class CceTaskInfo : public TaskInfo { | |||||
| public: | |||||
| CceTaskInfo(const std::string &op_name, uint32_t stream_id, const cce::ccOpContext &ctx, const std::string &stub_func, | |||||
| uint32_t block_dim, const std::vector<uint8_t> &args, uint32_t args_size, | |||||
| const std::vector<uint8_t> &sm_desc, const std::vector<uint8_t> &flow_table, | |||||
| const std::vector<uint8_t> &args_offset, bool is_flowtable) | |||||
| : TaskInfo(op_name, stream_id, TaskInfoType::CCE, false), | |||||
| ctx_(ctx), | |||||
| stub_func_(stub_func), | |||||
| block_dim_(block_dim), | |||||
| args_(args), | |||||
| args_size_(args_size), | |||||
| sm_desc_(sm_desc), | |||||
| flow_table_(flow_table), | |||||
| args_offset_(args_offset), | |||||
| is_flowtable_(is_flowtable) {} | |||||
| ~CceTaskInfo() override {} | |||||
| cce::ccOpContext cc_context() const { return ctx_; } | |||||
| std::string stub_func() const { return stub_func_; } | |||||
| uint32_t block_dim() const { return block_dim_; } | |||||
| const std::vector<uint8_t> &args() const { return args_; } | |||||
| uint32_t args_size() const { return args_size_; } | |||||
| const std::vector<uint8_t> &sm_desc() const { return sm_desc_; } | |||||
| const std::vector<uint8_t> &flow_table() const { return flow_table_; } | |||||
| const std::vector<uint8_t> &args_offset() const { return args_offset_; } | |||||
| bool is_flowtable() const { return is_flowtable_; } | |||||
| private: | |||||
| cce::ccOpContext ctx_; | |||||
| std::string stub_func_; | |||||
| uint32_t block_dim_; | |||||
| std::vector<uint8_t> args_; | |||||
| uint32_t args_size_; | |||||
| std::vector<uint8_t> sm_desc_; | |||||
| std::vector<uint8_t> flow_table_; | |||||
| std::vector<uint8_t> args_offset_; | |||||
| bool is_flowtable_; | |||||
| }; | |||||
| class TbeTaskInfo : public TaskInfo { | |||||
| public: | |||||
| TbeTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string &stub_func, uint32_t block_dim, | |||||
| const std::vector<uint8_t> &args, uint32_t args_size, const std::vector<uint8_t> &sm_desc, void *binary, | |||||
| uint32_t binary_size, const std::vector<uint8_t> &meta_data, const std::vector<void *> &input_data_addrs, | |||||
| const std::vector<void *> &output_data_addrs, const std::vector<void *> &workspace_addrs, bool dump_flag) | |||||
| : TaskInfo(op_name, stream_id, TaskInfoType::TBE, dump_flag), | |||||
| stub_func_(stub_func), | |||||
| block_dim_(block_dim), | |||||
| args_(args), | |||||
| args_size_(args_size), | |||||
| sm_desc_(sm_desc), | |||||
| binary_(binary), | |||||
| binary_size_(binary_size), | |||||
| meta_data_(meta_data), | |||||
| input_data_addrs_(input_data_addrs), | |||||
| output_data_addrs_(output_data_addrs), | |||||
| workspace_addrs_(workspace_addrs) {} | |||||
| ~TbeTaskInfo() override {} | |||||
| const std::string &stub_func() const { return stub_func_; } | |||||
| uint32_t block_dim() const { return block_dim_; } | |||||
| const std::vector<uint8_t> &args() const { return args_; } | |||||
| uint32_t args_size() const { return args_size_; } | |||||
| const std::vector<uint8_t> &sm_desc() const { return sm_desc_; } | |||||
| void *binary() const { return binary_; } | |||||
| uint32_t binary_size() const { return binary_size_; } | |||||
| const std::vector<uint8_t> &meta_data() const { return meta_data_; } | |||||
| const std::vector<void *> &input_data_addrs() const { return input_data_addrs_; } | |||||
| const std::vector<void *> &output_data_addrs() const { return output_data_addrs_; } | |||||
| const std::vector<void *> &workspace_addrs() const { return workspace_addrs_; } | |||||
| void SetBinary(void *binary, uint32_t binary_size) { | |||||
| binary_ = binary; | |||||
| binary_size_ = binary_size; | |||||
| } | |||||
| private: | |||||
| std::string stub_func_; | |||||
| uint32_t block_dim_; | |||||
| std::vector<uint8_t> args_; | |||||
| uint32_t args_size_; | |||||
| std::vector<uint8_t> sm_desc_; | |||||
| void *binary_; | |||||
| uint32_t binary_size_; | |||||
| std::vector<uint8_t> meta_data_; | |||||
| std::vector<void *> input_data_addrs_; | |||||
| std::vector<void *> output_data_addrs_; | |||||
| std::vector<void *> workspace_addrs_; | |||||
| }; | |||||
| class AicpuTaskInfo : public TaskInfo { | |||||
| public: | |||||
| AicpuTaskInfo(const std::string &op_name, uint32_t stream_id, const string &so_name, const std::string &kernel_name, | |||||
| const std::string &node_def, const std::string &ext_info, const std::vector<void *> &input_data_addrs, | |||||
| const std::vector<void *> &output_data_addrs, bool dump_flag) | |||||
| : TaskInfo(op_name, stream_id, TaskInfoType::AICPU, dump_flag), | |||||
| so_name_(so_name), | |||||
| kernel_name_(kernel_name), | |||||
| node_def_(node_def), | |||||
| ext_info_(ext_info), | |||||
| input_data_addrs_(input_data_addrs), | |||||
| output_data_addrs_(output_data_addrs) {} | |||||
| ~AicpuTaskInfo() override {} | |||||
| const std::string &so_name() const { return so_name_; } | |||||
| const std::string &kernel_name() const { return kernel_name_; } | |||||
| const std::string &node_def() const { return node_def_; } | |||||
| const std::vector<void *> &input_data_addrs() const { return input_data_addrs_; } | |||||
| const std::vector<void *> &output_data_addrs() const { return output_data_addrs_; } | |||||
| const std::string &ext_info() const { return ext_info_; } | |||||
| private: | |||||
| std::string so_name_; | |||||
| std::string kernel_name_; | |||||
| std::string node_def_; | |||||
| std::string ext_info_; | |||||
| std::vector<void *> input_data_addrs_; | |||||
| std::vector<void *> output_data_addrs_; | |||||
| }; | |||||
| class LabelSetTaskInfo : public TaskInfo { | |||||
| public: | |||||
| LabelSetTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_id) | |||||
| : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_SET, false), label_id_(label_id) {} | |||||
| ~LabelSetTaskInfo() override {} | |||||
| uint32_t label_id() const { return label_id_; } | |||||
| private: | |||||
| uint32_t label_id_; | |||||
| }; | |||||
| class LabelGotoTaskInfo : public TaskInfo { | |||||
| public: | |||||
| LabelGotoTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_id) | |||||
| : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_GOTO, false), label_id_(label_id) {} | |||||
| ~LabelGotoTaskInfo() override {} | |||||
| uint32_t label_id() const { return label_id_; } | |||||
| private: | |||||
| uint32_t label_id_; | |||||
| }; | |||||
| class LabelSwitchTaskInfo : public TaskInfo { | |||||
| public: | |||||
| LabelSwitchTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_size, | |||||
| const std::vector<uint32_t> &label_list, void *cond) | |||||
| : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_SWITCH, false), | |||||
| label_size_(label_size), | |||||
| label_list_(label_list), | |||||
| cond_(cond) {} | |||||
| ~LabelSwitchTaskInfo() override {} | |||||
| uint32_t label_size() const { return label_size_; } | |||||
| const std::vector<uint32_t> &label_list() const { return label_list_; } | |||||
| void *cond() const { return cond_; } | |||||
| private: | |||||
| uint32_t label_size_; | |||||
| std::vector<uint32_t> label_list_; | |||||
| void *cond_; | |||||
| }; | |||||
| class EventTaskInfo : public TaskInfo { | |||||
| public: | |||||
| uint32_t event_id() const { return event_id_; } | |||||
| protected: | |||||
| EventTaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, uint32_t event_id) | |||||
| : TaskInfo(op_name, stream_id, type, false), event_id_(event_id) {} | |||||
| ~EventTaskInfo() override {} | |||||
| uint32_t event_id_; | |||||
| }; | |||||
| class EventRecordTaskInfo : public EventTaskInfo { | |||||
| public: | |||||
| EventRecordTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t event_id) | |||||
| : EventTaskInfo(op_name, stream_id, TaskInfoType::EVENT_RECORD, event_id) {} | |||||
| ~EventRecordTaskInfo() override {} | |||||
| }; | |||||
| class EventWaitTaskInfo : public EventTaskInfo { | |||||
| public: | |||||
| EventWaitTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t event_id) | |||||
| : EventTaskInfo(op_name, stream_id, TaskInfoType::EVENT_WAIT, event_id) {} | |||||
| ~EventWaitTaskInfo() override {} | |||||
| }; | |||||
| class FusionStartTaskInfo : public TaskInfo { | |||||
| public: | |||||
| explicit FusionStartTaskInfo(const std::string &op_name, uint32_t stream_id) | |||||
| : TaskInfo(op_name, stream_id, TaskInfoType::FUSION_START, false) {} | |||||
| ~FusionStartTaskInfo() override {} | |||||
| }; | |||||
| class FusionEndTaskInfo : public TaskInfo { | |||||
| public: | |||||
| explicit FusionEndTaskInfo(const std::string &op_name, uint32_t stream_id) | |||||
| : TaskInfo(op_name, stream_id, TaskInfoType::FUSION_END, false) {} | |||||
| ~FusionEndTaskInfo() override {} | |||||
| }; | |||||
| class HcclTaskInfo : public TaskInfo { | |||||
| public: | |||||
| HcclTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string hccl_type, void *input_data_addr, | |||||
| void *output_data_addr, int64_t workspace_size, int64_t hccl_stream_num, | |||||
| const std::vector<uint8_t> &private_def, void *ops_kernel_store, int32_t count, int64_t root_id, | |||||
| int64_t op_type, int64_t data_type, const std::string &group, bool dump_flag) | |||||
| : TaskInfo(op_name, stream_id, TaskInfoType::HCCL, dump_flag), | |||||
| hccl_type_(hccl_type), | |||||
| input_data_addr_(input_data_addr), | |||||
| output_data_addr_(output_data_addr), | |||||
| workspace_size_(workspace_size), | |||||
| hccl_stream_num_(hccl_stream_num), | |||||
| private_def_(private_def), | |||||
| ops_kernel_store_(ops_kernel_store), | |||||
| count_(count), | |||||
| root_id_(root_id), | |||||
| op_type_(op_type), | |||||
| data_type_(data_type), | |||||
| group_(group) {} | |||||
| ~HcclTaskInfo() override {} | |||||
| const std::string &hccl_type() const { return hccl_type_; } | |||||
| void *input_data_addr() const { return input_data_addr_; } | |||||
| void *output_data_addr() const { return output_data_addr_; } | |||||
| int64_t workspace_size() const { return workspace_size_; } | |||||
| int64_t hccl_stream_num() const { return hccl_stream_num_; } | |||||
| const std::vector<uint8_t> &private_def() const { return private_def_; } | |||||
| void *ops_kernel_store() const { return ops_kernel_store_; } | |||||
| int32_t count() const { return count_; } | |||||
| int64_t root_id() const { return root_id_; } | |||||
| int64_t op_type() const { return op_type_; } | |||||
| int64_t data_type() const { return data_type_; } | |||||
| const std::string &group() const { return group_; } | |||||
| private: | |||||
| std::string hccl_type_; | |||||
| void *input_data_addr_; | |||||
| void *output_data_addr_; | |||||
| int64_t workspace_size_; | |||||
| int64_t hccl_stream_num_; | |||||
| std::vector<uint8_t> private_def_; | |||||
| void *ops_kernel_store_; | |||||
| int32_t count_; | |||||
| int64_t root_id_; | |||||
| int64_t op_type_; | |||||
| int64_t data_type_; | |||||
| std::string group_; | |||||
| }; | |||||
| class ProfilerTraceTaskInfo : public TaskInfo { | |||||
| public: | |||||
| ProfilerTraceTaskInfo(const std::string &op_name, uint32_t stream_id, uint64_t log_id, bool notify, uint32_t flat) | |||||
| : TaskInfo(op_name, stream_id, TaskInfoType::PROFILER_TRACE, false), | |||||
| log_id_(log_id), | |||||
| notify_(notify), | |||||
| flat_(flat) {} | |||||
| ~ProfilerTraceTaskInfo() override {} | |||||
| uint64_t log_id() const { return log_id_; } | |||||
| bool notify() const { return notify_; } | |||||
| uint32_t flat() const { return flat_; } | |||||
| private: | |||||
| uint64_t log_id_; | |||||
| bool notify_; | |||||
| uint32_t flat_; | |||||
| }; | |||||
| class MemcpyAsyncTaskInfo : public TaskInfo { | |||||
| public: | |||||
| MemcpyAsyncTaskInfo(const std::string &op_name, uint32_t stream_id, void *dst, uint64_t dst_max, void *src, | |||||
| uint64_t count, uint32_t kind, bool dump_flag) | |||||
| : TaskInfo(op_name, stream_id, TaskInfoType::MEMCPY_ASYNC, dump_flag), | |||||
| dst_(dst), | |||||
| dst_max_(dst_max), | |||||
| src_(src), | |||||
| count_(count), | |||||
| kind_(kind) {} | |||||
| ~MemcpyAsyncTaskInfo() override {} | |||||
| void *dst() const { return dst_; } | |||||
| uint64_t dst_max() const { return dst_max_; } | |||||
| void *src() const { return src_; } | |||||
| uint64_t count() const { return count_; } | |||||
| uint32_t kind() const { return kind_; } | |||||
| private: | |||||
| void *dst_; | |||||
| uint64_t dst_max_; | |||||
| void *src_; | |||||
| uint64_t count_; | |||||
| int32_t kind_; | |||||
| }; | |||||
| class StreamSwitchTaskInfo : public TaskInfo { | |||||
| public: | |||||
| StreamSwitchTaskInfo(const std::string &op_name, uint32_t stream_id, int64_t true_stream_id, void *input_addr, | |||||
| void *value_addr, int64_t cond, int64_t data_type) | |||||
| : TaskInfo(op_name, stream_id, TaskInfoType::STREAM_SWITCH, false), | |||||
| true_stream_id_(true_stream_id), | |||||
| input_addr_(input_addr), | |||||
| value_addr_(value_addr), | |||||
| cond_(cond), | |||||
| data_type_(data_type) {} | |||||
| ~StreamSwitchTaskInfo() override {} | |||||
| int64_t true_stream_id() const { return true_stream_id_; } | |||||
| void *input_addr() const { return input_addr_; } | |||||
| void *value_addr() const { return value_addr_; } | |||||
| int64_t cond() const { return cond_; } | |||||
| int64_t data_type() const { return data_type_; } | |||||
| private: | |||||
| int64_t true_stream_id_; | |||||
| void *input_addr_; | |||||
| void *value_addr_; | |||||
| int64_t cond_; | |||||
| int64_t data_type_; | |||||
| }; | |||||
| class StreamActiveTaskInfo : public TaskInfo { | |||||
| public: | |||||
| StreamActiveTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t active_stream_id) | |||||
| : TaskInfo(op_name, stream_id, TaskInfoType::STREAM_ACTIVE, false), active_stream_id_(active_stream_id) {} | |||||
| ~StreamActiveTaskInfo() override {} | |||||
| uint32_t active_stream_id() const { return active_stream_id_; } | |||||
| private: | |||||
| uint32_t active_stream_id_; | |||||
| }; | |||||
| } // namespace model_runner | |||||
| } // namespace ge | |||||
| #endif // INC_FRAMEWORK_GE_RUNTIME_TASK_INFO_H_ | |||||