| @@ -6,7 +6,6 @@ if (NOT ENABLE_D AND NOT ENABLE_ACL AND NOT ENABLE_MS_TESTCASES) | |||
| add_subdirectory(offline) | |||
| elseif (ENABLE_D) | |||
| add_subdirectory(common) | |||
| add_subdirectory(ge_runtime) | |||
| endif () | |||
| set(GRAPHENGINE_PROTO_LIST | |||
| @@ -1,78 +0,0 @@ | |||
| ############ libge_runtime.so ############ | |||
| set(GE_SRC_LIST | |||
| "model_runner.cc" | |||
| "runtime_model.cc" | |||
| "output.cc" | |||
| "task/aicpu_task.cc" | |||
| "task/cce_task.cc" | |||
| "task/tbe_task.cc" | |||
| "task/event_record_task.cc" | |||
| "task/event_wait_task.cc" | |||
| "task/stream_active_task.cc" | |||
| "task/stream_switch_task.cc" | |||
| "task/hccl_task.cc" | |||
| "task/memcpy_async_task.cc" | |||
| "task/profiler_task.cc" | |||
| "task/label_goto_task.cc" | |||
| "task/label_set_task.cc" | |||
| "task/label_switch_task.cc" | |||
| ) | |||
| add_library(ge_runtime SHARED ${GE_SRC_LIST}) | |||
| target_compile_options(ge_runtime PRIVATE | |||
| -Werror | |||
| -O2 | |||
| -Wno-deprecated-declarations | |||
| -fno-common | |||
| ) | |||
| target_compile_definitions(ge_runtime PRIVATE | |||
| PROTOBUF_INLINE_NOT_IN_HEADERS=0 | |||
| LOG_CPP | |||
| ) | |||
| target_include_directories(ge_runtime PRIVATE | |||
| ${CMAKE_CURRENT_LIST_DIR} | |||
| ${GE_CODE_DIR} | |||
| ${GE_CODE_DIR}/ge | |||
| ${GE_CODE_DIR}/inc | |||
| ${GE_CODE_DIR}/inc/graph | |||
| ${GE_CODE_DIR}/inc/external | |||
| ${GE_CODE_DIR}/inc/framework | |||
| ${GE_CODE_DIR}/inc/framework/common | |||
| ${GE_CODE_DIR}/inc/framework/ge_runtime | |||
| ${GE_CODE_DIR}/inc/cce | |||
| ${GE_CODE_DIR}/third_party/fwkacllib/inc | |||
| ${METADEF_DIR} | |||
| ${METADEF_DIR}/inc | |||
| ${METADEF_DIR}/inc/external/graph | |||
| ${METADEF_DIR}/inc/external | |||
| ${METADEF_DIR}/inc/graph | |||
| ${CMAKE_BINARY_DIR} | |||
| ${CMAKE_BINARY_DIR}/proto/ge | |||
| ) | |||
| target_link_options(ge_runtime PRIVATE | |||
| -Wl,-Bsymbolic | |||
| ) | |||
| target_link_libraries(ge_runtime PRIVATE | |||
| $<BUILD_INTERFACE:intf_pub> | |||
| -Wl,--no-as-needed | |||
| slog | |||
| runtime | |||
| c_sec | |||
| graph | |||
| -Wl,--as-needed | |||
| -lrt | |||
| -ldl | |||
| ) | |||
| ############ install ############ | |||
| set(INSTALL_BASE_DIR "") | |||
| set(INSTALL_LIBRARY_DIR lib) | |||
| install(TARGETS ge_runtime OPTIONAL | |||
| LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR} | |||
| ) | |||
| @@ -1,62 +0,0 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef GE_GE_RUNTIME_MODEL_CONTEXT_H_ | |||
| #define GE_GE_RUNTIME_MODEL_CONTEXT_H_ | |||
| #include <vector> | |||
| #include "runtime/rt_model.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| class ModelContext { | |||
| public: | |||
| ModelContext(uint32_t device_id, uint64_t session_id, int32_t priority, rtModel_t rt_model_handle, | |||
| rtStream_t rt_model_stream, const std::vector<rtStream_t> &stream_list, | |||
| const std::vector<rtLabel_t> &label_list, const std::vector<rtEvent_t> &event_list) | |||
| : device_id_(device_id), | |||
| session_id_(session_id), | |||
| priority_(priority), | |||
| rt_model_handle_(rt_model_handle), | |||
| rt_model_stream_(rt_model_stream), | |||
| stream_list_(stream_list), | |||
| label_list_(label_list), | |||
| event_list_(event_list) {} | |||
| ~ModelContext() {} | |||
| uint64_t device_id() const { return device_id_; } | |||
| uint64_t session_id() const { return session_id_; } | |||
| int32_t priority() const { return priority_; } | |||
| const rtModel_t &rt_model_handle() const { return rt_model_handle_; } | |||
| const rtStream_t &rt_model_stream() const { return rt_model_stream_; } | |||
| const std::vector<rtStream_t> &stream_list() const { return stream_list_; } | |||
| const std::vector<rtLabel_t> &label_list() const { return label_list_; } | |||
| const std::vector<rtEvent_t> &event_list() const { return event_list_; } | |||
| private: | |||
| uint32_t device_id_; | |||
| uint64_t session_id_; | |||
| int32_t priority_; | |||
| rtModel_t rt_model_handle_; | |||
| rtStream_t rt_model_stream_; | |||
| std::vector<rtStream_t> stream_list_; | |||
| std::vector<rtLabel_t> label_list_; | |||
| std::vector<rtEvent_t> event_list_; | |||
| }; | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| #endif // GE_GE_RUNTIME_MODEL_CONTEXT_H_ | |||
| @@ -1,173 +0,0 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ge_runtime/model_runner.h" | |||
| #include "./runtime_model.h" | |||
| #include "framework/common/debug/ge_log.h" | |||
| #include "common/ge_inner_error_codes.h" | |||
| #include "common/ge/ge_util.h" | |||
| #include "ge_runtime/davinci_model.h" | |||
| #include "graph/op_desc.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| using RuntimeModelPtr = std::shared_ptr<RuntimeModel>; | |||
| using DavinciModelPtr = std::shared_ptr<DavinciModel>; | |||
| ModelRunner &ModelRunner::Instance() { | |||
| static ModelRunner instance; // Guaranteed to be destroyed. | |||
| return instance; | |||
| } | |||
| bool ModelRunner::LoadDavinciModel(uint32_t device_id, uint64_t session_id, uint32_t model_id, | |||
| std::shared_ptr<DavinciModel> davinci_model, | |||
| std::shared_ptr<ModelListener> listener) { | |||
| std::shared_ptr<RuntimeModel> model = MakeShared<RuntimeModel>(); | |||
| if (model == nullptr) { | |||
| return false; | |||
| } | |||
| bool status = model->Load(device_id, session_id, davinci_model); | |||
| if (!status) { | |||
| return false; | |||
| } | |||
| runtime_models_[model_id] = model; | |||
| return true; | |||
| } | |||
| bool ModelRunner::DistributeTask(uint32_t model_id) { | |||
| auto model_iter = runtime_models_.find(model_id); | |||
| if (model_iter == runtime_models_.end()) { | |||
| GELOGE(PARAM_INVALID, "Model id %u not found.", model_id); | |||
| return false; | |||
| } | |||
| return model_iter->second->DistributeTask(); | |||
| } | |||
| bool ModelRunner::LoadModelComplete(uint32_t model_id) { | |||
| auto model_iter = runtime_models_.find(model_id); | |||
| if (model_iter == runtime_models_.end()) { | |||
| GELOGE(PARAM_INVALID, "Model id %u not found.", model_id); | |||
| return false; | |||
| } | |||
| return model_iter->second->LoadComplete(); | |||
| } | |||
| const std::vector<uint32_t> &ModelRunner::GetTaskIdList(uint32_t model_id) const { | |||
| auto model_iter = runtime_models_.find(model_id); | |||
| if (model_iter == runtime_models_.end()) { | |||
| GELOGE(PARAM_INVALID, "Model id %u not found.", model_id); | |||
| static const std::vector<uint32_t> empty_ret; | |||
| return empty_ret; | |||
| } | |||
| return model_iter->second->GetTaskIdList(); | |||
| } | |||
| const std::vector<uint32_t> &ModelRunner::GetStreamIdList(uint32_t model_id) const { | |||
| auto model_iter = runtime_models_.find(model_id); | |||
| if (model_iter == runtime_models_.end()) { | |||
| GELOGE(PARAM_INVALID, "Model id %u not found.", model_id); | |||
| static const std::vector<uint32_t> empty_ret; | |||
| return empty_ret; | |||
| } | |||
| return model_iter->second->GetStreamIdList(); | |||
| } | |||
| const std::map<std::string, std::shared_ptr<RuntimeInfo>> &ModelRunner::GetRuntimeInfoMap(uint32_t model_id) const { | |||
| auto model_iter = runtime_models_.find(model_id); | |||
| if (model_iter == runtime_models_.end()) { | |||
| GELOGW("Model id %u not found.", model_id); | |||
| static const std::map<std::string, std::shared_ptr<RuntimeInfo>> empty_ret; | |||
| return empty_ret; | |||
| } | |||
| return model_iter->second->GetRuntimeInfoMap(); | |||
| } | |||
| void *ModelRunner::GetModelHandle(uint32_t model_id) const { | |||
| auto model_iter = runtime_models_.find(model_id); | |||
| if (model_iter == runtime_models_.end()) { | |||
| GELOGW("Model id %u not found.", model_id); | |||
| return nullptr; | |||
| } | |||
| return model_iter->second->GetModelHandle(); | |||
| } | |||
| bool ModelRunner::UnloadModel(uint32_t model_id) { | |||
| auto iter = runtime_models_.find(model_id); | |||
| if (iter != runtime_models_.end()) { | |||
| (void)runtime_models_.erase(iter); | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| bool ModelRunner::RunModel(uint32_t model_id, const InputData &input_data, OutputData *output_data) { | |||
| if (output_data == nullptr) { | |||
| GELOGW("Output data point is null."); | |||
| } | |||
| auto model_iter = runtime_models_.find(model_id); | |||
| if (model_iter == runtime_models_.end()) { | |||
| GELOGE(PARAM_INVALID, "Model id %u not found.", model_id); | |||
| return false; | |||
| } | |||
| bool status = model_iter->second->CopyInputData(input_data); | |||
| if (!status) { | |||
| GELOGE(FAILED, "Copy input data fail."); | |||
| return false; | |||
| } | |||
| status = model_iter->second->Run(); | |||
| if (!status) { | |||
| GELOGE(FAILED, "Run model fail."); | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| bool ModelRunner::GetInputOutputDescInfo(uint32_t model_id, bool zero_copy, | |||
| std::vector<InputOutputDescInfo> *input_desc, | |||
| std::vector<InputOutputDescInfo> *output_desc, | |||
| std::vector<uint32_t> *input_format, std::vector<uint32_t> *output_format) { | |||
| if (runtime_models_.find(model_id) == runtime_models_.end()) { | |||
| GELOGE(PARAM_INVALID, "Model id %u not found.", model_id); | |||
| return false; | |||
| } | |||
| auto model = runtime_models_[model_id]; | |||
| if (input_desc == nullptr || output_desc == nullptr) { | |||
| GELOGE(PARAM_INVALID, "input_desc or output_desc is null."); | |||
| return false; | |||
| } | |||
| bool status = model->GetInputOutputDescInfo(zero_copy, input_desc, output_desc, input_format, output_format); | |||
| if (!status) { | |||
| GELOGE(FAILED, "Get input output desc info fail."); | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| @@ -1,66 +0,0 @@ | |||
| LOCAL_PATH := $(call my-dir) | |||
| # task.proto is old task, add it for ops_kernel_info_store | |||
| local_ge_runtime_src_files := \ | |||
| model_runner.cc \ | |||
| runtime_model.cc \ | |||
| output.cc \ | |||
| task/aicpu_task.cc \ | |||
| task/cce_task.cc \ | |||
| task/tbe_task.cc \ | |||
| task/event_record_task.cc \ | |||
| task/event_wait_task.cc \ | |||
| task/stream_active_task.cc \ | |||
| task/stream_switch_task.cc \ | |||
| task/hccl_task.cc \ | |||
| task/memcpy_async_task.cc \ | |||
| task/profiler_task.cc \ | |||
| local_ge_runtime_include := \ | |||
| $(LOCAL_PATH)/ \ | |||
| $(TOPDIR)libc_sec/include \ | |||
| $(TOPDIR)inc/external \ | |||
| $(TOPDIR)inc/external/graph \ | |||
| $(TOPDIR)inc/framework \ | |||
| $(TOPDIR)inc/graph \ | |||
| $(TOPDIR)inc \ | |||
| $(LOCAL_PATH)/../ \ | |||
| third_party/protobuf/include | |||
| local_ge_runtime_shared_library := \ | |||
| libruntime \ | |||
| libslog \ | |||
| libc_sec | |||
| local_ge_runtime_ldflags := -lrt -ldl | |||
| # compile device libge_runtime | |||
| include $(CLEAR_VARS) | |||
| LOCAL_MODULE := libge_runtime | |||
| LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -O2 | |||
| LOCAL_CFLAGS += -Werror | |||
| LOCAL_SRC_FILES := $(local_ge_runtime_src_files) | |||
| LOCAL_C_INCLUDES := $(local_ge_runtime_include) | |||
| LOCAL_SHARED_LIBRARIES := $(local_ge_runtime_shared_library) | |||
| LOCAL_LDFLAGS += $(local_ge_runtime_ldflags) | |||
| include $(BUILD_SHARED_LIBRARY) | |||
| # compile host libge_runtime | |||
| include $(CLEAR_VARS) | |||
| LOCAL_MODULE := libge_runtime | |||
| LOCAL_CFLAGS += -Werror | |||
| LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 | |||
| ifeq ($(DEBUG), 1) | |||
| LOCAL_CFLAGS += -g -O0 | |||
| else | |||
| LOCAL_CFLAGS += -O2 | |||
| endif | |||
| LOCAL_SRC_FILES := $(local_ge_runtime_src_files) | |||
| LOCAL_C_INCLUDES := $(local_ge_runtime_include) | |||
| LOCAL_SHARED_LIBRARIES := $(local_ge_runtime_shared_library) | |||
| LOCAL_LDFLAGS += $(local_ge_runtime_ldflags) | |||
| include $(BUILD_HOST_SHARED_LIBRARY) | |||
| @@ -1,94 +0,0 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ge_runtime/output.h" | |||
| #include "common/ge_inner_error_codes.h" | |||
| #include "common/util.h" | |||
| #include "framework/common/debug/ge_log.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| Output::Output(const OpInfoPtr &op_info, const std::shared_ptr<DavinciModel> &model) | |||
| : model_(model), op_info_(op_info), input_num_(0) {} | |||
| Output::~Output() {} | |||
| bool Output::Init() { | |||
| if (op_info_ == nullptr || model_ == nullptr) { | |||
| GELOGE(INTERNAL_ERROR, "The op_desc_ or model_ is nullptr."); | |||
| return false; | |||
| } | |||
| input_num_ = op_info_->input_tensors.size(); | |||
| v_input_size_.clear(); | |||
| v_input_data_addr_.clear(); | |||
| auto input_vector = op_info_->input_addrs; | |||
| if (input_num_ != input_vector.size()) { | |||
| GELOGE(INTERNAL_ERROR, "The input desc size: %zu != input addr size: %zu.", input_num_, input_vector.size()); | |||
| return false; | |||
| } | |||
| for (size_t i = 0; i < input_num_; i++) { | |||
| uint32_t tensorSize = 0; | |||
| const auto &input_info = op_info_->input_tensors.at(i); | |||
| tensorSize = input_info.size; | |||
| v_input_size_.push_back(tensorSize); | |||
| v_input_data_addr_.push_back(reinterpret_cast<uint8_t *>(input_vector.at(i))); | |||
| } | |||
| GELOGI("Init output:%zu, %zu, %zu", input_num_, v_input_size_.size(), v_input_data_addr_.size()); | |||
| return true; | |||
| } | |||
| /// | |||
| /// @ingroup domi_ome | |||
| /// @brief Copy Op Output to user space. | |||
| /// @brief when model running, Add one DataOp as input node, Add one Output Op as output node. | |||
| /// @return Status | |||
| /// | |||
| bool Output::CopyRslt(OutputData *rslt, uint32_t data_begin, uint32_t &data_index, bool support_mem_share) { | |||
| if (rslt == nullptr) { | |||
| GELOGE(FAILED, "OutputData is null."); | |||
| return false; | |||
| } | |||
| uint32_t data_count = 0; | |||
| if (v_input_size_.empty() || v_input_data_addr_.empty()) { | |||
| GELOGE(INTERNAL_ERROR, "v_output_size_ or v_output_data_addr_ is empty!"); | |||
| return false; | |||
| } | |||
| for (size_t i = 0; i < input_num_; i++) { | |||
| DataBuffer data_buf = rslt->blobs[data_begin + data_count]; | |||
| bool ret = SetDataBuf(data_buf, data_begin, data_count, i, support_mem_share); | |||
| if (!ret) { | |||
| GELOGE(FAILED, "Copy data to host failed. index: %lu, addr: %p", i, v_input_data_addr_[i]); | |||
| return ret; | |||
| } | |||
| data_index = data_begin + data_count; | |||
| } | |||
| return true; | |||
| } | |||
| bool Output::SetDataBuf(DataBuffer &data_buf, uint32_t data_begin, uint32_t &data_count, size_t i, | |||
| bool support_mem_share) { | |||
| return true; | |||
| } | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| @@ -1,53 +0,0 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef GE_GE_RUNTIME_OUTPUT_H_ | |||
| #define GE_GE_RUNTIME_OUTPUT_H_ | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "ge_runtime/davinci_model.h" | |||
| #include "common/ge_types.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| class Output { | |||
| public: | |||
| Output(const OpInfoPtr &op_info, const std::shared_ptr<DavinciModel> &model); | |||
| virtual ~Output(); | |||
| bool Init(); | |||
| bool CopyRslt(OutputData *rslt, uint32_t data_begin, uint32_t &data_index, bool support_mem_share); | |||
| bool SetDataBuf(DataBuffer &data_buf, uint32_t data_begin, uint32_t &data_count, size_t i, bool support_mem_share); | |||
| // Copy assignment operator and copy constructor are deleted | |||
| Output &operator=(const Output &output) = delete; | |||
| Output(const Output &output) = delete; | |||
| protected: | |||
| std::shared_ptr<DavinciModel> model_; | |||
| OpInfoPtr op_info_; | |||
| // Input descriptions | |||
| size_t input_num_; | |||
| vector<void *> v_input_data_addr_; // Init as:buf_base + op_def_->input(i)); | |||
| vector<uint32_t> v_input_size_; | |||
| }; | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| #endif // GE_GE_RUNTIME_OUTPUT_H_ | |||
| @@ -1,27 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| // Generated by the protocol buffer compiler. DO NOT EDIT! | |||
| // source: task.proto | |||
| #ifndef STUB_TASK_PROTO_H | |||
| #define STUB_TASK_PROTO_H | |||
| namespace domi { | |||
| class TaskDef; | |||
| } | |||
| #endif // STUB_TASK_PROTO_H | |||
| @@ -1,547 +0,0 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ge_runtime/runtime_model.h" | |||
| #include <set> | |||
| #include "./model_context.h" | |||
| #include "./task/task.h" | |||
| #include "common/ge_inner_error_codes.h" | |||
| #include "common/types.h" | |||
| #include "common/util.h" | |||
| #include "framework/common/debug/ge_log.h" | |||
| #include "framework/common/op/op_parser_util.h" | |||
| #include "graph/types.h" | |||
| #include "task/task_factory.h" | |||
| #include "ge/common/math/math_util.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| namespace { | |||
| const int kOffsetUnit = 8; | |||
| const uint32_t kStringHeadElems = 2; | |||
| } // namespace | |||
| RuntimeModel::~RuntimeModel() { | |||
| GELOGI("RuntimeModel destructor start"); | |||
| // Unbind rtModel from all task related streams | |||
| RtModelUnbindStream(); | |||
| // Release task first, hccl task hold stream | |||
| task_list_.clear(); | |||
| // Release all task related streams | |||
| RtStreamDestory(); | |||
| // Release rtlabel resource | |||
| RtLabelDestory(); | |||
| // Release rtEvent resourece | |||
| RtEventDestory(); | |||
| GELOGI("Do RtModelDestory"); | |||
| // Release all rt_model | |||
| RtModelDestory(); | |||
| } | |||
| bool RuntimeModel::InitStream(std::shared_ptr<DavinciModel> &davinci_model) { | |||
| if (davinci_model == nullptr) { | |||
| GELOGE(PARAM_INVALID, "Davinci model is null."); | |||
| return false; | |||
| } | |||
| std::set<int64_t> wait_active_streams; | |||
| std::set<int64_t> force_copy_streams; | |||
| for (const auto &stream_id : davinci_model->GetWaitActiveStreams()) { | |||
| GELOGI("stream id %u is wait active stream.", stream_id); | |||
| (void)wait_active_streams.insert(stream_id); | |||
| } | |||
| for (const auto &stream_id : davinci_model->GetForceCopyStreams()) { | |||
| GELOGI("stream id %u is force copy stream.", stream_id); | |||
| (void)force_copy_streams.insert(stream_id); | |||
| } | |||
| GELOGI("stream number:%u", davinci_model->GetStreamNum()); | |||
| for (uint32_t i = 0; i < davinci_model->GetStreamNum(); ++i) { | |||
| rtStream_t stream = nullptr; | |||
| uint32_t flag = (force_copy_streams.find(i) != force_copy_streams.end()) | |||
| ? (RT_STREAM_PERSISTENT | RT_STREAM_FORCE_COPY) | |||
| : (RT_STREAM_PERSISTENT); | |||
| rtError_t rt_ret = rtStreamCreateWithFlags(&stream, davinci_model->GetPriority(), flag); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api rtStreamCreate failed, ret: 0x%X", rt_ret); | |||
| return false; | |||
| } | |||
| GELOGI("rtStreamCreateWithFlags end."); | |||
| stream_list_.emplace_back(stream); | |||
| // Bind rt_model_handle_ to all task related streams | |||
| flag = (wait_active_streams.find(i) != wait_active_streams.end()) ? (static_cast<uint32_t>(RT_INVALID_FLAG)) | |||
| : (static_cast<uint32_t>(RT_HEAD_STREAM)); | |||
| rt_ret = rtModelBindStream(rt_model_handle_, stream, flag); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api rtModelBindStream failed, ret: 0x%X", rt_ret); | |||
| return false; | |||
| } | |||
| GELOGI("stream index:%u, stream:%p.", i, stream); | |||
| } | |||
| return true; | |||
| } | |||
| bool RuntimeModel::InitEvent(uint32_t event_num) { | |||
| GELOGI("event number:%u.", event_num); | |||
| for (uint32_t i = 0; i < event_num; ++i) { | |||
| rtEvent_t rt_event; | |||
| rtError_t rt_ret = rtEventCreate(&rt_event); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api rtEventCreate failed, i; %u; ret: 0x%X", i, rt_ret); | |||
| return false; | |||
| } | |||
| event_list_.push_back(rt_event); | |||
| } | |||
| return true; | |||
| } | |||
| bool RuntimeModel::InitLabel(std::shared_ptr<DavinciModel> &davinci_model) { | |||
| GELOGI("batch number:%u.", davinci_model->GetBatchNum()); | |||
| label_list_.resize(davinci_model->GetBatchNum()); | |||
| for (auto &task_info : davinci_model->GetTaskInfoList()) { | |||
| if (task_info == nullptr) { | |||
| GELOGE(PARAM_INVALID, "task_info is null."); | |||
| continue; | |||
| } | |||
| if (task_info->type() != TaskInfoType::LABEL_SET) { | |||
| continue; | |||
| } | |||
| auto label_set_task_info = std::static_pointer_cast<LabelSetTaskInfo>(task_info); | |||
| if (label_set_task_info->stream_id() >= stream_list_.size()) { | |||
| GELOGE(PARAM_INVALID, "Invalid stream id."); | |||
| return false; | |||
| } | |||
| rtLabel_t rt_label = nullptr; | |||
| rtError_t rt_ret = rtLabelCreateEx(&rt_label, stream_list_[label_set_task_info->stream_id()]); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api rtLabelCreate failed, ret: 0x%X", rt_ret); | |||
| return false; | |||
| } | |||
| label_list_[label_set_task_info->label_id()] = rt_label; | |||
| } | |||
| return true; | |||
| } | |||
| bool RuntimeModel::InitResource(std::shared_ptr<DavinciModel> &davinci_model) { | |||
| GELOGI("InitResource start"); | |||
| if (davinci_model == nullptr) { | |||
| GELOGE(PARAM_INVALID, "davinci model is null"); | |||
| return false; | |||
| } | |||
| rtError_t rt_ret = rtModelCreate(&rt_model_handle_, 0); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api rtModelCreate failed, ret: 0x%X", rt_ret); | |||
| return false; | |||
| } | |||
| // Create rtStream for rt_model_handle_ | |||
| rt_ret = rtStreamCreate(&rt_model_stream_, davinci_model->GetPriority()); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api rtStreamCreate failed, ret: 0x%X", rt_ret); | |||
| return false; | |||
| } | |||
| GELOGI("rtStreamCreate end"); | |||
| if (!InitStream(davinci_model)) { | |||
| return false; | |||
| } | |||
| if (!InitEvent(davinci_model->GetEventNum())) { | |||
| return false; | |||
| } | |||
| if (!InitLabel(davinci_model)) { | |||
| return false; | |||
| } | |||
| GELOGI("InitResource succ"); | |||
| return true; | |||
| } | |||
| void RuntimeModel::GenerateTask(uint32_t device_id, uint64_t session_id, std::shared_ptr<DavinciModel> &davinci_model) { | |||
| GELOGI("GenerateTask start."); | |||
| if (davinci_model == nullptr) { | |||
| GELOGE(PARAM_INVALID, "davinci model is null"); | |||
| return; | |||
| } | |||
| auto task_infos = davinci_model->GetTaskInfoList(); | |||
| ModelContext model_context(device_id, session_id, davinci_model->GetPriority(), rt_model_handle_, rt_model_stream_, | |||
| stream_list_, label_list_, event_list_); | |||
| for (auto &task_info : task_infos) { | |||
| auto task = TaskFactory::GetInstance().Create(model_context, task_info); | |||
| task_list_.push_back(task); | |||
| } | |||
| GELOGI("GenerateTask succ."); | |||
| } | |||
| bool RuntimeModel::LoadTask() { | |||
| GELOGI("LoadTask start."); | |||
| for (auto &task : task_list_) { | |||
| if (task == nullptr) { | |||
| GELOGE(PARAM_INVALID, "task is null."); | |||
| continue; | |||
| } | |||
| bool ret = task->Distribute(); | |||
| if (!ret) { | |||
| GELOGE(FAILED, "task distribute fail."); | |||
| return false; | |||
| } | |||
| uint32_t task_id = 0; | |||
| uint32_t stream_id = 0; | |||
| rtError_t rt_ret = rtModelGetTaskId(rt_model_handle_, &task_id, &stream_id); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X.", rt_ret); | |||
| return false; | |||
| } | |||
| task_id_list_.push_back(task_id); | |||
| stream_id_list_.push_back(stream_id); | |||
| if (task->Args() != nullptr) { | |||
| std::shared_ptr<RuntimeInfo> runtime_tuple = nullptr; | |||
| GE_MAKE_SHARED(runtime_tuple = std::make_shared<RuntimeInfo>(task_id, stream_id, task->Args()), return false); | |||
| auto emplace_ret = runtime_info_map_.emplace(task->task_name(), runtime_tuple); | |||
| if (!emplace_ret.second) { | |||
| GELOGW("Task name exist:%s", task->task_name().c_str()); | |||
| } | |||
| } | |||
| } | |||
| if (task_list_.empty()) { | |||
| GELOGE(FAILED, "Task list is empty"); | |||
| return false; | |||
| } | |||
| GELOGI("LoadTask succ."); | |||
| return true; | |||
| } | |||
| bool RuntimeModel::LoadComplete() { | |||
| uint32_t task_id = 0; | |||
| uint32_t stream_id = 0; | |||
| auto rt_ret = rtModelGetTaskId(rt_model_handle_, &task_id, &stream_id); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rtModelGetTaskId failed, ret:0x%X", rt_ret); | |||
| return RT_FAILED; | |||
| } | |||
| task_id_list_.push_back(task_id); | |||
| stream_id_list_.push_back(stream_id); | |||
| rt_ret = rtModelLoadComplete(rt_model_handle_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api rtModelLoadComplete failed, ret: 0x%X.", rt_ret); | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| bool RuntimeModel::Load(uint32_t device_id, uint64_t session_id, std::shared_ptr<DavinciModel> &davinci_model) { | |||
| bool status = InitResource(davinci_model); | |||
| if (!status) { | |||
| GELOGE(FAILED, "InitResource failed."); | |||
| return status; | |||
| } | |||
| status = InitDataInfo(davinci_model); | |||
| if (!status) { | |||
| GELOGE(FAILED, "InitDataInfo failed."); | |||
| return status; | |||
| } | |||
| status = InitOutputInfo(davinci_model); | |||
| if (!status) { | |||
| GELOGE(FAILED, "InitOutputInfo failed."); | |||
| return status; | |||
| } | |||
| status = InitConstantInfo(davinci_model); | |||
| if (!status) { | |||
| GELOGE(FAILED, "InitConstantInfo failed."); | |||
| return status; | |||
| } | |||
| GenerateTask(device_id, session_id, davinci_model); | |||
| return status; | |||
| } | |||
| bool RuntimeModel::DistributeTask() { | |||
| bool status = LoadTask(); | |||
| if (!status) { | |||
| GELOGE(FAILED, "DistributeTask failed"); | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| bool RuntimeModel::Run() { | |||
| GELOGI("Davinci task run start"); | |||
| rtError_t ret = rtModelExecute(rt_model_handle_, rt_model_stream_, 0); | |||
| if (ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Model execute failed, ret = 0x%X", ret); | |||
| return false; | |||
| } | |||
| GELOGI("Run rtModelExecute success, ret = 0x%X", ret); | |||
| ret = rtStreamSynchronize(rt_model_stream_); | |||
| if (ret != RT_ERROR_NONE) { | |||
| if (ret == ACL_ERROR_RT_END_OF_SEQUENCE) { | |||
| GELOGI("Model stream ACL_ERROR_RT_END_OF_SEQUENCE signal received, ret = 0x%X", ret); | |||
| return true; | |||
| } | |||
| GELOGE(RT_FAILED, "Model stream sync failed, ret = 0x%X", ret); | |||
| return false; | |||
| } | |||
| GELOGI("Davinci task run succ."); | |||
| return true; | |||
| } | |||
| void RuntimeModel::RtModelUnbindStream() noexcept { | |||
| for (size_t i = 0; i < stream_list_.size(); i++) { | |||
| if (rtModelUnbindStream(rt_model_handle_, stream_list_[i]) != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Unbind stream from model failed! Index: %zu", i); | |||
| return; | |||
| } | |||
| } | |||
| } | |||
| void RuntimeModel::RtStreamDestory() noexcept { | |||
| if (rtStreamDestroy(rt_model_stream_) != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Destroy stream for rt_model failed!"); | |||
| return; | |||
| } | |||
| for (size_t i = 0; i < stream_list_.size(); i++) { | |||
| if (rtStreamDestroy(stream_list_[i]) != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Destroy stream failed! Index: %zu", i); | |||
| return; | |||
| } | |||
| } | |||
| } | |||
| void RuntimeModel::RtLabelDestory() noexcept { | |||
| for (size_t i = 0; i < label_list_.size(); i++) { | |||
| if (label_list_[i] == nullptr) { | |||
| continue; | |||
| } | |||
| if (rtLabelDestroy(label_list_[i]) != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Destroy label failed! Index: %zu.", i); | |||
| return; | |||
| } | |||
| } | |||
| } | |||
| void RuntimeModel::RtModelDestory() noexcept { | |||
| rtError_t ret = rtModelDestroy(rt_model_handle_); | |||
| if (ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", ret); | |||
| return; | |||
| } | |||
| } | |||
| void RuntimeModel::RtEventDestory() noexcept { | |||
| for (size_t i = 0; i < event_list_.size(); i++) { | |||
| if (rtEventDestroy(event_list_[i]) != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Destroy event failed! Index: %zu", i); | |||
| return; | |||
| } | |||
| } | |||
| } | |||
| bool RuntimeModel::InitDataInfo(std::shared_ptr<DavinciModel> &davinci_model) { return true; } | |||
| bool RuntimeModel::InitOutputInfo(std::shared_ptr<DavinciModel> &davinci_model) { | |||
| if (davinci_model == nullptr) { | |||
| GELOGE(PARAM_INVALID, "davinci model is null"); | |||
| return false; | |||
| } | |||
| output_info_list_ = davinci_model->GetOutputInfoList(); | |||
| return true; | |||
| } | |||
| bool RuntimeModel::CopyInputData(const InputData &input_data) { | |||
| if (input_data.blobs.size() != data_info_list_.size()) { | |||
| GELOGE(PARAM_INVALID, "The input data list size (%zu) does not match the model input list size (%zu)", | |||
| input_data.blobs.size(), data_info_list_.size()); | |||
| return false; | |||
| } | |||
| for (const auto &data_info : data_info_list_) { | |||
| if (data_info == nullptr) { | |||
| GELOGE(PARAM_INVALID, "data info is null."); | |||
| return false; | |||
| } | |||
| bool ret = CopyInputDataToModel(input_data.blobs, data_info); | |||
| if (!ret) { | |||
| GELOGE(FAILED, "Copy input data to model ret fail, data_info: %s, model id: %u", data_info->name.c_str(), | |||
| input_data.model_id); | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| bool RuntimeModel::CopyInputDataToModel(const std::vector<DataBuffer> &data, const std::shared_ptr<OpInfo> &data_info) { | |||
| return true; | |||
| } | |||
| bool RuntimeModel::CopyHostData(const std::vector<DataBuffer> &data, const std::shared_ptr<OpInfo> &data_info) const { | |||
| GELOGI("Start CopyHostData."); | |||
| if (data.empty()) { | |||
| GELOGE(PARAM_INVALID, "data buffer is empty."); | |||
| return false; | |||
| } | |||
| if (data_info == nullptr) { | |||
| GELOGE(PARAM_INVALID, "data info is null."); | |||
| return false; | |||
| } | |||
| void *host_data_addr = data[data_info->index].data; | |||
| uint32_t copy_size = data[data_info->index].length; | |||
| GELOGD("data output tensor is aipp tensor,copy data only."); | |||
| const std::vector<uintptr_t> &outputs = data_info->output_addrs; | |||
| if (outputs.empty()) { | |||
| GELOGE(PARAM_INVALID, "Output addrs is empty."); | |||
| return false; | |||
| } | |||
| // Copy input data to data nodes | |||
| void *data_out_addr = reinterpret_cast<void *>(outputs[0]); | |||
| rtError_t rt_ret = rtMemcpy(data_out_addr, copy_size, host_data_addr, copy_size, RT_MEMCPY_HOST_TO_DEVICE); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| bool RuntimeModel::CopyTransData(const std::vector<DataBuffer> &data, const std::shared_ptr<OpInfo> &data_info) { | |||
| return true; | |||
| } | |||
| bool RuntimeModel::InitConstantInfo(std::shared_ptr<DavinciModel> &davinci_model) { | |||
| // Const no input, only 1 output, and this output has no data | |||
| // weight data copy to output mem | |||
| if (davinci_model == nullptr) { | |||
| GELOGE(PARAM_INVALID, "Davinci model is null."); | |||
| return false; | |||
| } | |||
| constant_info_list_ = davinci_model->GetConstantInfoList(); | |||
| for (const auto &constant : constant_info_list_) { | |||
| if (constant == nullptr) { | |||
| GELOGE(PARAM_INVALID, "constant is null"); | |||
| continue; | |||
| } | |||
| if (constant->output_tensors.empty()) { | |||
| GELOGE(PARAM_INVALID, "Output tensors is empty"); | |||
| return false; | |||
| } | |||
| if (constant->weight_tensors.empty()) { | |||
| GELOGE(PARAM_INVALID, "Weight tensors is empty"); | |||
| return false; | |||
| } | |||
| if (constant->output_tensors[0].size < constant->weight_data.size()) { | |||
| GELOGE(PARAM_INVALID, "Output size:%u is less than weight data size:%zu", constant->output_tensors[0].size, | |||
| constant->weight_data.size()); | |||
| return false; | |||
| } | |||
| if (constant->weight_data.empty()) { | |||
| GELOGW("Const op:%s has no weight data.", constant->name.c_str()); | |||
| continue; | |||
| } | |||
| if (constant->weight_tensors[0].datatype == DT_STRING) { | |||
| /// If tensor is a scaler, it's shape size if zero, according ge_tensor.cc. | |||
| /// The logic of GetShapeSize is wrong, the scaler tensor's GetShapeSize is zero | |||
| /// and that of unknown shape is zero too. | |||
| /// Unknown shape will not appear here, so we can use zero judge a tensor is scaler or not. | |||
| int64_t elem_num = | |||
| (constant->weight_tensors[0].GetShapeSize() == 0) ? 1 : constant->weight_tensors[0].GetShapeSize(); | |||
| if (constant->weight_data.size() < sizeof(uint64_t)) { | |||
| GELOGE(FAILED, "weight_data size is smaller than sizeof(uint64_t)"); | |||
| return false; | |||
| } | |||
| uint64_t *buff = reinterpret_cast<uint64_t *>(const_cast<char *>(constant->weight_data.data())); | |||
| uint32_t head_len = kOffsetUnit * kStringHeadElems; | |||
| if (CheckInt64Uint32MulOverflow(elem_num, head_len) != SUCCESS) { | |||
| GELOGE(FAILED, "Shape size is invalid"); | |||
| return false; | |||
| } | |||
| int64_t offset = elem_num * head_len; | |||
| uintptr_t hbm_raw_data_base_addr = reinterpret_cast<uintptr_t>(constant->output_addrs[0]) + offset; | |||
| for (int64_t i = elem_num - 1; i >= 0; --i) { | |||
| buff[i * kStringHeadElems] = hbm_raw_data_base_addr + (buff[i * kStringHeadElems] - buff[0]); | |||
| } | |||
| } | |||
| rtError_t rt_ret = rtMemcpy(reinterpret_cast<void *>(constant->output_addrs[0]), constant->output_tensors[0].size, | |||
| constant->weight_data.data(), constant->weight_data.size(), RT_MEMCPY_HOST_TO_DEVICE); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "rtGetFunctionByName failed, ret: 0x%X", rt_ret); | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| bool RuntimeModel::GetInputOutputDescInfo(bool zero_copy, std::vector<InputOutputDescInfo> *input_desc, | |||
| std::vector<InputOutputDescInfo> *output_desc, | |||
| std::vector<uint32_t> *input_format, std::vector<uint32_t> *output_format) { | |||
| return true; | |||
| } | |||
| bool RuntimeModel::GetInputDescInfo(std::vector<InputOutputDescInfo> *input_desc, std::vector<uint32_t> *formats) { | |||
| return true; | |||
| } | |||
| bool RuntimeModel::GetOutputDescInfo(std::vector<InputOutputDescInfo> *output_desc, std::vector<uint32_t> *formats) { | |||
| return true; | |||
| } | |||
// Stub implementation: output description construction is intentionally empty here.
void RuntimeModel::CreateOutput(uint32_t index, const OpInfo &op_info, InputOutputDescInfo *output,
                                uint32_t *format_result) {}
| const std::vector<uint32_t> &RuntimeModel::GetTaskIdList() const { return task_id_list_; } | |||
| const std::vector<uint32_t> &RuntimeModel::GetStreamIdList() const { return stream_id_list_; } | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| @@ -1,92 +0,0 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef GE_GE_RUNTIME_RUNTIME_MODEL_H_ | |||
| #define GE_GE_RUNTIME_RUNTIME_MODEL_H_ | |||
| #include <map> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "ge_runtime/davinci_model.h" | |||
| #include "common/ge_types.h" | |||
| #include "runtime/base.h" | |||
| #include "runtime/rt_model.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| using RuntimeInfo = std::tuple<uint32_t, uint32_t, void *>; | |||
| class Task; | |||
// Runtime-side representation of a loaded Davinci model. Owns the rt model
// handle, its streams/labels/events, the distributed task objects, and op
// descriptions (data/output/constant) cached from the DavinciModel.
class RuntimeModel {
 public:
  RuntimeModel() = default;
  ~RuntimeModel();
  // Load the given model for the device/session (see runtime_model.cc).
  bool Load(uint32_t device_id, uint64_t session_id, std::shared_ptr<DavinciModel> &davinci_model);
  // Distribute the generated tasks to the device.
  bool DistributeTask();
  bool LoadComplete();
  const std::vector<uint32_t> &GetTaskIdList() const;
  const std::vector<uint32_t> &GetStreamIdList() const;
  const std::map<std::string, std::shared_ptr<RuntimeInfo>> &GetRuntimeInfoMap() const { return runtime_info_map_; }
  rtModel_t GetModelHandle() const { return rt_model_handle_; }
  // Execute the model once and synchronize its stream.
  bool Run();
  // Copy host input blobs to the model's data nodes.
  bool CopyInputData(const InputData &input_data);
  bool GetInputOutputDescInfo(bool zero_copy, std::vector<InputOutputDescInfo> *input_desc,
                              std::vector<InputOutputDescInfo> *output_desc, std::vector<uint32_t> *input_format,
                              std::vector<uint32_t> *output_format);

 private:
  // Initialisation helpers used during Load().
  bool InitResource(std::shared_ptr<DavinciModel> &davinci_model);
  void GenerateTask(uint32_t device_id, uint64_t session_id, std::shared_ptr<DavinciModel> &davinci_model);
  bool LoadTask();
  bool InitStream(std::shared_ptr<DavinciModel> &davinci_model);
  bool InitEvent(uint32_t event_num);
  bool InitLabel(std::shared_ptr<DavinciModel> &davinci_model);
  bool InitDataInfo(std::shared_ptr<DavinciModel> &davinci_model);
  bool InitOutputInfo(std::shared_ptr<DavinciModel> &davinci_model);
  bool InitConstantInfo(std::shared_ptr<DavinciModel> &davinci_model);
  // Teardown helpers; noexcept because they run on destruction paths.
  void RtModelUnbindStream() noexcept;
  void RtStreamDestory() noexcept;
  void RtModelDestory() noexcept;
  void RtLabelDestory() noexcept;
  void RtEventDestory() noexcept;
  // Input-copy helpers.
  bool CopyInputDataToModel(const std::vector<DataBuffer> &data, const std::shared_ptr<OpInfo> &data_info);
  bool CopyHostData(const std::vector<DataBuffer> &data, const std::shared_ptr<OpInfo> &data_info) const;
  bool CopyTransData(const std::vector<DataBuffer> &data, const std::shared_ptr<OpInfo> &data_info);
  bool GetInputDescInfo(std::vector<InputOutputDescInfo> *input_desc, std::vector<uint32_t> *formats);
  bool GetOutputDescInfo(std::vector<InputOutputDescInfo> *output_desc, std::vector<uint32_t> *formats);
  void CreateOutput(uint32_t index, const OpInfo &op_info, InputOutputDescInfo *output, uint32_t *format);

  rtModel_t rt_model_handle_{};
  rtStream_t rt_model_stream_{};  // stream the model executes on
  std::vector<rtStream_t> stream_list_{};
  std::vector<rtLabel_t> label_list_{};
  std::vector<rtEvent_t> event_list_{};
  std::vector<std::shared_ptr<Task>> task_list_{};
  std::vector<std::shared_ptr<OpInfo>> data_info_list_{};
  std::vector<std::shared_ptr<OpInfo>> output_info_list_{};
  std::vector<std::shared_ptr<OpInfo>> constant_info_list_{};
  std::vector<uint32_t> task_id_list_{};
  std::vector<uint32_t> stream_id_list_{};
  std::map<std::string, std::shared_ptr<RuntimeInfo>> runtime_info_map_;
};
| } // namespace model_runner | |||
| } // namespace ge | |||
| #endif // GE_GE_RUNTIME_RUNTIME_MODEL_H_ | |||
| @@ -1,168 +0,0 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ge_runtime/task/aicpu_task.h" | |||
| #include <vector> | |||
| #include "ge_runtime/task/task_factory.h" | |||
| #include "aicpu/common/aicpu_task_struct.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| AicpuTask::AicpuTask(const ModelContext &model_context, const std::shared_ptr<AicpuTaskInfo> &task_info) | |||
| : TaskRepeater<AicpuTaskInfo>(model_context, task_info), | |||
| task_info_(task_info), | |||
| stream_(nullptr), | |||
| args_(nullptr), | |||
| ext_info_(nullptr), | |||
| input_output_addr_(nullptr) { | |||
| if (task_info_ == nullptr) { | |||
| GELOGW("task_info_ is null!"); | |||
| return; | |||
| } | |||
| auto stream_list = model_context.stream_list(); | |||
| if (stream_list.size() == 1) { | |||
| stream_ = stream_list[0]; | |||
| } else if (stream_list.size() > task_info->stream_id()) { | |||
| stream_ = stream_list[task_info->stream_id()]; | |||
| } else { | |||
| GELOGW("index: %u >= stream_list.size(): %zu.", task_info->stream_id(), stream_list.size()); | |||
| } | |||
| } | |||
// Release the device-side buffers allocated in Distribute() (args and ext info).
AicpuTask::~AicpuTask() {
  ReleaseRtMem(&args_);
  ReleaseRtMem(&ext_info_);
}
// Build the device-side argument blob for the AICPU kernel and launch it.
// Device layout of args_:
//   [AicpuParamHead][io addrs][node_def length (uint32)][serialized node_def]
// Optional ext info is placed in a separate HBM buffer referenced from the head.
// Returns false on any rt API failure; allocated buffers are freed by the destructor.
bool AicpuTask::Distribute() {
  GELOGI("InitAicpuTask start.");
  vector<void *> io_addrs;
  // Inputs first, then outputs: the kernel consumes them as one flat array.
  io_addrs.insert(io_addrs.end(), task_info_->input_data_addrs().begin(), task_info_->input_data_addrs().end());
  io_addrs.insert(io_addrs.end(), task_info_->output_data_addrs().begin(), task_info_->output_data_addrs().end());
  auto io_addrs_num = static_cast<uint32_t>(io_addrs.size());
  auto io_addrs_size = static_cast<uint32_t>(io_addrs_num * sizeof(void *));
  // Byte offsets of each section within the args blob.
  constexpr uint32_t io_addr_offset = sizeof(aicpu::AicpuParamHead);
  uint32_t node_def_len_offset = io_addr_offset + io_addrs_size;
  uint32_t node_def_addr_offset = node_def_len_offset + sizeof(uint32_t);
  uint32_t args_size = sizeof(aicpu::AicpuParamHead) + io_addrs_size +
                       static_cast<uint32_t>(task_info_->node_def().size()) + sizeof(uint32_t);
  aicpu::AicpuParamHead aicpu_param_head;
  aicpu_param_head.length = args_size;
  aicpu_param_head.ioAddrNum = io_addrs_num;
  auto ext_info = task_info_->ext_info();
  uint32_t ext_size = ext_info.size();
  if (ext_info.empty()) {
    aicpu_param_head.extInfoLength = 0;
    aicpu_param_head.extInfoAddr = 0;
  } else {
    // Copy the extended info to its own device buffer and reference it from the head.
    rtError_t flag = rtMalloc(&ext_info_, ext_size, RT_MEMORY_HBM);
    if (flag != RT_ERROR_NONE) {
      GELOGE(RT_FAILED, "Call rt api(rtMalloc) failed, ret: 0x%X.", flag);
      return false;
    }
    flag = rtMemcpy(ext_info_, ext_size, const_cast<void *>(reinterpret_cast<const void *>(ext_info.data())), ext_size,
                    RT_MEMCPY_HOST_TO_DEVICE);
    if (flag != RT_ERROR_NONE) {
      GELOGE(RT_FAILED, "Call rt api(rtMemCpy) failed, ret: 0x%X.", flag);
      return false;
    }
    GELOGI("ext info size: %u", ext_size);
    aicpu_param_head.extInfoLength = ext_size;
    aicpu_param_head.extInfoAddr = reinterpret_cast<uintptr_t>(ext_info_);
  }
  // Malloc device memory for args
  rtError_t rt_ret = rtMalloc(&args_, args_size, RT_MEMORY_HBM);
  if (rt_ret != RT_ERROR_NONE) {
    GELOGE(RT_FAILED, "Call rt api(rtMalloc) failed, ret: 0x%X.", rt_ret);
    return false;
  }
  GE_PRINT_DYNAMIC_MEMORY(rtMalloc, "task args data.", args_size)
  // Memcpy AicpuParamHead
  rt_ret = rtMemcpy(args_, sizeof(aicpu::AicpuParamHead), reinterpret_cast<void *>(&aicpu_param_head),
                    sizeof(aicpu::AicpuParamHead), RT_MEMCPY_HOST_TO_DEVICE);
  if (rt_ret != RT_ERROR_NONE) {
    GELOGE(RT_FAILED, "Call rt api(rtMemcpy) failed, ret: 0x%X.", rt_ret);
    return false;
  }
  // Memcpy io addrs
  if (io_addrs_num != 0) {
    rt_ret = rtMemcpy(reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(args_) + io_addr_offset), io_addrs_size,
                      reinterpret_cast<void *>(io_addrs.data()), io_addrs_size, RT_MEMCPY_HOST_TO_DEVICE);
    if (rt_ret != RT_ERROR_NONE) {
      GELOGE(RT_FAILED, "Call rt api(rtMemcpy) failed, ret: 0x%X.", rt_ret);
      return false;
    }
  }
  // Memcpy node def length (uint32) ahead of the serialized node def itself.
  auto size = task_info_->node_def().size();
  rt_ret =
      rtMemcpy(reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(args_) + node_def_len_offset), sizeof(uint32_t),
               reinterpret_cast<const void *>(&size), sizeof(uint32_t), RT_MEMCPY_HOST_TO_DEVICE);
  if (rt_ret != RT_ERROR_NONE) {
    GELOGE(RT_FAILED, "Call rt api(rtMemcpy) failed, ret: 0x%X.", rt_ret);
    return false;
  }
  // Memcpy node def
  rt_ret = rtMemcpy(reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(args_) + node_def_addr_offset),
                    task_info_->node_def().size(), reinterpret_cast<const void *>(task_info_->node_def().data()),
                    task_info_->node_def().size(), RT_MEMCPY_HOST_TO_DEVICE);
  if (rt_ret != RT_ERROR_NONE) {
    GELOGE(RT_FAILED, "Call rt api(rtMemcpy) failed, ret: 0x%X.", rt_ret);
    return false;
  }
  // Remember where the io-addr section lives; exposed via Args().
  input_output_addr_ = reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(args_) + io_addr_offset);
  auto dump_flag = task_info_->dump_flag() ? RT_KERNEL_DUMPFLAG : RT_KERNEL_DEFAULT;
  GELOGI(
      "Distribute AicpuTask start, args_size = %u, io_addrs_num = %u, so_name = %s, kernel_name = %s, dump_flag = %d.",
      args_size, io_addrs_num, task_info_->so_name().data(), task_info_->kernel_name().data(), dump_flag);
  rt_ret = rtCpuKernelLaunchWithFlag(reinterpret_cast<const void *>(task_info_->so_name().data()),
                                     reinterpret_cast<const void *>(task_info_->kernel_name().data()), 1, args_,
                                     args_size, nullptr, stream_, dump_flag);
  if (rt_ret != RT_ERROR_NONE) {
    GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
    return false;
  }
  GELOGI("Distribute AicpuTask end.");
  return true;
}
| void AicpuTask::ReleaseRtMem(void **ptr) noexcept { | |||
| if (ptr == nullptr || *ptr == nullptr) { | |||
| return; | |||
| } | |||
| rtError_t rt_ret = rtFree(*ptr); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "ReleaseRtMem failed, ret: 0x%X", rt_ret); | |||
| return; | |||
| } | |||
| *ptr = nullptr; | |||
| } | |||
// Register AicpuTask with the task factory so TaskInfoType::AICPU dispatches to it.
REGISTER_TASK(TaskInfoType::AICPU, AicpuTask, AicpuTaskInfo);
| } // namespace model_runner | |||
| } // namespace ge | |||
| @@ -1,50 +0,0 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef GE_GE_RUNTIME_TASK_AICPU_TASK_H_ | |||
| #define GE_GE_RUNTIME_TASK_AICPU_TASK_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| #include "ge_runtime/task/task.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
// Task that launches an AICPU kernel. Distribute() assembles the device-side
// argument blob (param head, io addresses, node def) and launches the kernel
// on the stream resolved from the model context.
class AicpuTask : public TaskRepeater<AicpuTaskInfo> {
 public:
  AicpuTask(const ModelContext &model_context, const std::shared_ptr<AicpuTaskInfo> &task_info);
  ~AicpuTask() override;
  bool Distribute() override;
  // Device address of the io-addr section inside the args blob.
  void *Args() override { return input_output_addr_; }
  std::string task_name() const override { return task_info_->op_name(); }

 private:
  // Frees a device buffer allocated with rtMalloc and nulls the pointer.
  static void ReleaseRtMem(void **ptr) noexcept;

  std::shared_ptr<AicpuTaskInfo> task_info_;
  void *stream_;             // execution stream (not owned)
  void *args_;               // device args blob (owned, freed in dtor)
  void *ext_info_;           // device ext-info buffer (owned, freed in dtor)
  void *input_output_addr_;  // points into args_; not separately freed
};
| } // namespace model_runner | |||
| } // namespace ge | |||
| #endif // GE_GE_RUNTIME_TASK_AICPU_TASK_H_ | |||
| @@ -1,160 +0,0 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ge_runtime/task/cce_task.h" | |||
| #include "ge_runtime/task/task_factory.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| CceTask::CceTask(const ModelContext &model_context, const std::shared_ptr<CceTaskInfo> &task_info) | |||
| : TaskRepeater<CceTaskInfo>(model_context, task_info), | |||
| task_info_(task_info), | |||
| stream_(nullptr), | |||
| stub_func_(nullptr), | |||
| args_(nullptr), | |||
| sm_desc_(nullptr), | |||
| flowtable_(nullptr), | |||
| is_flowtable_(false) { | |||
| if (task_info_ == nullptr) { | |||
| GELOGW("task_info_ is null!"); | |||
| return; | |||
| } | |||
| auto stream_list = model_context.stream_list(); | |||
| if (stream_list.size() == 1) { | |||
| stream_ = stream_list[0]; | |||
| } else if (stream_list.size() > task_info->stream_id()) { | |||
| stream_ = stream_list[task_info->stream_id()]; | |||
| } else { | |||
| GELOGW("index: %u >= stream_list.size(): %zu.", task_info->stream_id(), stream_list.size()); | |||
| } | |||
| } | |||
| CceTask::~CceTask() { | |||
| FreeRtMem(&args_); | |||
| FreeRtMem(&flowtable_); | |||
| rtError_t ret = (sm_desc_ != nullptr) ? rtMemFreeManaged(sm_desc_) : RT_ERROR_NONE; | |||
| if (ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", ret); | |||
| } | |||
| sm_desc_ = nullptr; | |||
| } | |||
| void CceTask::FreeRtMem(void **ptr) noexcept { | |||
| if (ptr == nullptr || *ptr == nullptr) { | |||
| return; | |||
| } | |||
| rtError_t ret = rtFree(*ptr); | |||
| if (ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", ret); | |||
| } | |||
| *ptr = nullptr; | |||
| } | |||
// Prepare device memory (optional flow table, kernel args, optional L2
// sm_desc) and launch the CCE kernel identified by task_info_->stub_func().
// Returns false on any rt API failure; device buffers are released by the
// destructor.
bool CceTask::Distribute() {
  GELOGI("Distribute CCETask start.");
  if (stream_ == nullptr) {
    GELOGE(PARAM_INVALID, "stream_ is null!");
    return false;
  }
  // Get stub_func
  if (task_info_->stub_func().empty()) {
    GELOGE(PARAM_INVALID, "kernel_info->stub_func is empty!");
    return false;
  }
  rtError_t rt_ret = rtGetFunctionByName(const_cast<char *>(task_info_->stub_func().c_str()), &stub_func_);
  if (rt_ret != RT_ERROR_NONE) {
    GELOGE(RT_FAILED, "rtGetFunctionByName failed, ret: 0x%X", rt_ret);
    stub_func_ = nullptr;
    return false;
  }
  GELOGI("CCETask: stub_func = %s [%p].", task_info_->stub_func().c_str(), stub_func_);
  // Flowtable
  // NOTE(review): is_flowtable_ is initialised to false in the constructor and is
  // not assigned anywhere visible in this file, so this branch appears dead from
  // here — confirm whether something else sets it before relying on it.
  if (is_flowtable_) {
    rt_ret = rtMalloc(&flowtable_, task_info_->flow_table().size(), RT_MEMORY_HBM);
    if (rt_ret != RT_ERROR_NONE) {
      GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
      return false;
    }
    GE_PRINT_DYNAMIC_MEMORY(rtMalloc, "task information.", task_info_->flow_table().size())
    rt_ret = rtMemcpy(flowtable_, task_info_->flow_table().size(), task_info_->flow_table().data(),
                      task_info_->flow_table().size(), RT_MEMCPY_HOST_TO_DEVICE);
    if (rt_ret != RT_ERROR_NONE) {
      GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
      return false;
    }
    // Modify flowtable addr in args: patch the device flowtable pointer into the
    // host-side args image at the offset recorded in args_offset.
    auto args = const_cast<uint8_t *>(task_info_->args().data());
    auto task_offset = reinterpret_cast<uint16_t *>(const_cast<uint8_t *>(task_info_->args_offset().data()));
    if (task_info_->args().size() < (task_offset[0] + sizeof(uint64_t))) {
      GELOGE(FAILED, "(context.args_offset().data()))[0]:%u + sizeof(uint64_t):%zu > kernelDef.args().size():%zu",
             static_cast<uint32_t>(task_offset[0]), sizeof(uint64_t), task_info_->args().size());
      return false;
    }
    *(reinterpret_cast<uintptr_t *>(args + task_offset[0])) = reinterpret_cast<uintptr_t>(flowtable_);
  }
  // Args: copy the (possibly patched) args image to device memory.
  rt_ret = rtMalloc(&args_, task_info_->args_size(), RT_MEMORY_HBM);
  if (rt_ret != RT_ERROR_NONE) {
    GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
    return false;
  }
  GE_PRINT_DYNAMIC_MEMORY(rtMalloc, "task information.", task_info_->args_size())
  rt_ret = rtMemcpy(args_, task_info_->args_size(), task_info_->args().data(), task_info_->args_size(),
                    RT_MEMCPY_HOST_TO_DEVICE);
  if (rt_ret != RT_ERROR_NONE) {
    GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
    return false;
  }
  // L2 sm_desc: optional managed buffer describing L2 (SPM) usage.
  if (!task_info_->sm_desc().empty()) {
    rt_ret = rtMemAllocManaged(&sm_desc_, task_info_->sm_desc().size(), RT_MEMORY_SPM);
    if (rt_ret != RT_ERROR_NONE) {
      GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
      return false;
    }
    rt_ret = rtMemcpy(sm_desc_, task_info_->sm_desc().size(), task_info_->sm_desc().data(),
                      task_info_->sm_desc().size(), RT_MEMCPY_HOST_TO_DEVICE);
    if (rt_ret != RT_ERROR_NONE) {
      GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
      return false;
    }
  }
  // Kernel launch
  rt_ret = rtKernelLaunch(stub_func_, task_info_->block_dim(), args_, task_info_->args_size(),
                          static_cast<rtSmDesc_t *>(sm_desc_), stream_);
  if (rt_ret != RT_ERROR_NONE) {
    GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
    return false;
  }
  return true;
}
// Register CceTask with the task factory so TaskInfoType::CCE dispatches to it.
REGISTER_TASK(TaskInfoType::CCE, CceTask, CceTaskInfo);
| } // namespace model_runner | |||
| } // namespace ge | |||
| @@ -1,47 +0,0 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef GE_GE_RUNTIME_TASK_CCE_TASK_H_ | |||
| #define GE_GE_RUNTIME_TASK_CCE_TASK_H_ | |||
| #include <memory> | |||
| #include "ge_runtime/task/task.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
// Task that launches a CCE kernel. Distribute() resolves the kernel stub,
// stages args / flow table / L2 sm_desc in device memory and launches on the
// stream resolved from the model context.
class CceTask : public TaskRepeater<CceTaskInfo> {
 public:
  CceTask(const ModelContext &model_context, const std::shared_ptr<CceTaskInfo> &task_info);
  ~CceTask() override;
  bool Distribute() override;
  // Frees a device buffer allocated with rtMalloc and nulls the pointer.
  static void FreeRtMem(void **ptr) noexcept;

 private:
  std::shared_ptr<CceTaskInfo> task_info_;
  void *stream_;       // execution stream (not owned)
  void *stub_func_;    // kernel handle from rtGetFunctionByName
  void *args_;         // device args buffer (owned, freed in dtor)
  void *sm_desc_;      // managed L2 descriptor (owned, freed in dtor)
  void *flowtable_;    // device flow table (owned, freed in dtor)
  bool is_flowtable_;  // whether a flow table must be staged
};
| } // namespace model_runner | |||
| } // namespace ge | |||
| #endif // GE_GE_RUNTIME_TASK_CCE_TASK_H_ | |||
| @@ -1,61 +0,0 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ge_runtime/task/event_record_task.h" | |||
| #include "ge_runtime/task/task_factory.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| EventRecordTask::EventRecordTask(const ModelContext &model_context, | |||
| const std::shared_ptr<EventRecordTaskInfo> &task_info) | |||
| : TaskRepeater<EventRecordTaskInfo>(model_context, task_info), | |||
| task_info_(task_info), | |||
| stream_(nullptr), | |||
| event_(nullptr) { | |||
| if (task_info_ == nullptr) { | |||
| GELOGW("task_info_ is null!"); | |||
| return; | |||
| } | |||
| auto stream_list = model_context.stream_list(); | |||
| auto event_list = model_context.event_list(); | |||
| uint32_t stream_id = task_info->stream_id(); | |||
| uint32_t event_id = task_info->event_id(); | |||
| if (stream_id >= stream_list.size() || event_id >= event_list.size()) { | |||
| GELOGW("stream_list size:%zu, stream_id:%u, event_list size:%zu, event_id:%u", stream_list.size(), stream_id, | |||
| event_list.size(), event_id); | |||
| return; | |||
| } | |||
| stream_ = stream_list[stream_id]; | |||
| event_ = event_list[event_id]; | |||
| } | |||
// Nothing to release: the stream and event handles come from the model context's lists.
EventRecordTask::~EventRecordTask() {}
| bool EventRecordTask::Distribute() { | |||
| GELOGI("EventRecordTask Distribute start, stream: %p, event: %p, stream_id: %u, event_id: %u.", stream_, event_, | |||
| task_info_->stream_id(), task_info_->event_id()); | |||
| rtError_t rt_ret = rtEventRecord(event_, stream_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
| return false; | |||
| } | |||
| GELOGI("Distribute end."); | |||
| return true; | |||
| } | |||
// Register EventRecordTask with the task factory so TaskInfoType::EVENT_RECORD dispatches to it.
REGISTER_TASK(TaskInfoType::EVENT_RECORD, EventRecordTask, EventRecordTaskInfo);
| } // namespace model_runner | |||
| } // namespace ge | |||
| @@ -1,41 +0,0 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef GE_GE_RUNTIME_TASK_EVENT_RECORD_TASK_H_ | |||
| #define GE_GE_RUNTIME_TASK_EVENT_RECORD_TASK_H_ | |||
| #include <memory> | |||
| #include "ge_runtime/task/task.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| class EventRecordTask : public TaskRepeater<EventRecordTaskInfo> { | |||
| public: | |||
| EventRecordTask(const ModelContext &model_context, const std::shared_ptr<EventRecordTaskInfo> &task_info); | |||
| ~EventRecordTask() override; | |||
| bool Distribute() override; | |||
| private: | |||
| std::shared_ptr<EventRecordTaskInfo> task_info_; | |||
| rtStream_t stream_; | |||
| rtEvent_t event_; | |||
| }; | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| #endif // GE_GE_RUNTIME_TASK_EVENT_RECORD_TASK_H_ | |||
| @@ -1,67 +0,0 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ge_runtime/task/event_wait_task.h" | |||
| #include "ge_runtime/task/task_factory.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| EventWaitTask::EventWaitTask(const ModelContext &model_context, const std::shared_ptr<EventWaitTaskInfo> &task_info) | |||
| : TaskRepeater<EventWaitTaskInfo>(model_context, task_info), | |||
| task_info_(task_info), | |||
| stream_(nullptr), | |||
| event_(nullptr) { | |||
| if (task_info_ == nullptr) { | |||
| GELOGW("task_info_ is null!"); | |||
| return; | |||
| } | |||
| auto stream_list = model_context.stream_list(); | |||
| auto event_list = model_context.event_list(); | |||
| uint32_t stream_id = task_info->stream_id(); | |||
| uint32_t event_id = task_info->event_id(); | |||
| if (stream_id >= stream_list.size() || event_id >= event_list.size()) { | |||
| GELOGW("stream_list size:%zu, stream_id:%u, event_list size:%zu, event_id:%u", stream_list.size(), stream_id, | |||
| event_list.size(), event_id); | |||
| return; | |||
| } | |||
| stream_ = stream_list[stream_id]; | |||
| event_ = event_list[event_id]; | |||
| } | |||
| EventWaitTask::~EventWaitTask() {} | |||
| bool EventWaitTask::Distribute() { | |||
| GELOGI("EventWaitTask Distribute start, stream: %p, event: %p, stream_id: %u, event_id: %u.", stream_, event_, | |||
| task_info_->stream_id(), task_info_->event_id()); | |||
| rtError_t rt_ret = rtStreamWaitEvent(stream_, event_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api rtStreamWaitEvent failed, ret: 0x%X", rt_ret); | |||
| return false; | |||
| } | |||
| rt_ret = rtEventReset(event_, stream_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api rtEventReset failed, ret: 0x%X", rt_ret); | |||
| return false; | |||
| } | |||
| GELOGI("Distribute end."); | |||
| return true; | |||
| } | |||
| REGISTER_TASK(TaskInfoType::EVENT_WAIT, EventWaitTask, EventWaitTaskInfo); | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| @@ -1,41 +0,0 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef GE_GE_RUNTIME_TASK_EVENT_WAIT_TASK_H_ | |||
| #define GE_GE_RUNTIME_TASK_EVENT_WAIT_TASK_H_ | |||
| #include <memory> | |||
| #include "ge_runtime/task/task.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| class EventWaitTask : public TaskRepeater<EventWaitTaskInfo> { | |||
| public: | |||
| EventWaitTask(const ModelContext &model_context, const std::shared_ptr<EventWaitTaskInfo> &task_info); | |||
| ~EventWaitTask() override; | |||
| bool Distribute() override; | |||
| private: | |||
| std::shared_ptr<EventWaitTaskInfo> task_info_; | |||
| rtStream_t stream_; | |||
| rtEvent_t event_; | |||
| }; | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| #endif // GE_GE_RUNTIME_TASK_EVENT_WAIT_TASK_H_ | |||
| @@ -1,268 +0,0 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ge_runtime/task/hccl_task.h" | |||
| #include <algorithm> | |||
| #include "ge_runtime/task/task_factory.h" | |||
| #include "common/opskernel/ops_kernel_info_store.h" | |||
| #include "common/opskernel/ge_task_info.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| std::map<rtModel_t, std::map<uint32_t, std::vector<std::weak_ptr<HcclTask::StreamGuard>>>> | |||
| HcclTask::model_stream_mapping_; | |||
| std::mutex HcclTask::model_stream_mapping_mutex_; | |||
| HcclTask::HcclTask(const ModelContext &model_context, const std::shared_ptr<HcclTaskInfo> &task_info) | |||
| : TaskRepeater<HcclTaskInfo>(model_context, task_info), | |||
| task_info_(task_info), | |||
| stream_(nullptr), | |||
| workspace_mem_(nullptr), | |||
| rt_model_handle_(nullptr), | |||
| priority_(0), | |||
| secondary_stream_list_() { | |||
| if (task_info_ == nullptr) { | |||
| GELOGW("task_info_ is null!"); | |||
| return; | |||
| } | |||
| priority_ = model_context.priority(); | |||
| rt_model_handle_ = model_context.rt_model_handle(); | |||
| auto stream_list = model_context.stream_list(); | |||
| if (stream_list.size() == 1) { | |||
| stream_ = stream_list[0]; | |||
| } else if (stream_list.size() > task_info->stream_id()) { | |||
| stream_ = stream_list[task_info->stream_id()]; | |||
| } else { | |||
| GELOGW("Index: %u >= stream_list.size(): %zu.", task_info->stream_id(), stream_list.size()); | |||
| } | |||
| } | |||
| HcclTask::~HcclTask() { | |||
| if (workspace_mem_ != nullptr) { | |||
| rtError_t rt_ret = rtFree(workspace_mem_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "rtFree workspace_mem_ failed! ret: 0x%X.", rt_ret); | |||
| } | |||
| workspace_mem_ = nullptr; | |||
| } | |||
| } | |||
| bool HcclTask::Distribute() { | |||
| // Ops kernel info store | |||
| // Get privateDef and opsKernelStorePtr | |||
| GELOGI("Get custom info in modelTaskDef"); | |||
| void *ops_kernel_store = task_info_->ops_kernel_store(); | |||
| OpsKernelInfoStore *ops_kernel_info_store = reinterpret_cast<OpsKernelInfoStore *>(ops_kernel_store); | |||
| if (ops_kernel_store == nullptr) { | |||
| GELOGE(PARAM_INVALID, "No hcom distribute function ptr and no ops kernel store."); | |||
| return false; | |||
| } | |||
| char *private_def = reinterpret_cast<char *>(const_cast<char unsigned *>(task_info_->private_def().data())); | |||
| auto private_def_len = static_cast<uint32_t>(task_info_->private_def().size()); | |||
| GELOGI("The first address of the custom info, privateDef=%p", private_def); | |||
| SetSecondaryStream(); | |||
| if (task_info_->workspace_size() > 0) { | |||
| rtError_t rt_ret = rtMalloc(&workspace_mem_, task_info_->workspace_size(), RT_MEMORYINFO_HBM); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
| return false; | |||
| } | |||
| } | |||
| GELOGI("HcclTaskInfo Distribute Start. begin to call function LoadTask in hccl."); | |||
| GETaskInfo ge_task; | |||
| ge_task.id = 0; | |||
| ge_task.type = static_cast<uint16_t>(RT_MODEL_TASK_HCCL); | |||
| ge_task.stream = stream_; | |||
| ge_task.kernelHcclInfo = std::vector<GETaskKernelHcclInfo>(1); | |||
| ge_task.kernelHcclInfo[0].hccl_type = task_info_->hccl_type(); | |||
| ge_task.kernelHcclInfo[0].inputDataAddr = task_info_->input_data_addr(); | |||
| ge_task.kernelHcclInfo[0].outputDataAddr = task_info_->output_data_addr(); | |||
| ge_task.kernelHcclInfo[0].workSpaceAddr = workspace_mem_; | |||
| ge_task.kernelHcclInfo[0].workSpaceMemSize = task_info_->workspace_size(); | |||
| ge_task.kernelHcclInfo[0].count = task_info_->count(); | |||
| ge_task.kernelHcclInfo[0].dataType = static_cast<int32_t>(task_info_->data_type()); | |||
| ge_task.kernelHcclInfo[0].opType = static_cast<int32_t>(task_info_->op_type()); | |||
| ge_task.kernelHcclInfo[0].rootId = task_info_->root_id(); | |||
| std::vector<rtStream_t> secondary_stream_list; | |||
| std::transform(secondary_stream_list_.begin(), secondary_stream_list_.end(), | |||
| std::back_inserter(secondary_stream_list), | |||
| [](const std::shared_ptr<StreamGuard> &stream) -> rtStream_t { return stream->GetStream(); }); | |||
| ge_task.kernelHcclInfo[0].hcclStreamList = secondary_stream_list; | |||
| ge_task.privateDef = private_def; | |||
| ge_task.privateDefLen = private_def_len; | |||
| ge_task.opsKernelStorePtr = ops_kernel_store; | |||
| auto result = ops_kernel_info_store->LoadTask(ge_task); | |||
| // tagHcclResult::HCCL_SUCCESS is 0 | |||
| if (result != 0) { | |||
| GELOGE(INTERNAL_ERROR, "davinci_model : load task fail, return ret: %u", result); | |||
| return false; | |||
| } | |||
| GELOGI("Call function LoadTask end."); | |||
| return true; | |||
| } | |||
| bool HcclTask::SetSecondaryStream() { | |||
| const uint32_t master_stream_id = task_info_->stream_id(); | |||
| const int64_t hccl_secondary_stream_num = task_info_->hccl_stream_num(); | |||
| Status ret; | |||
| std::lock_guard<std::mutex> lock(model_stream_mapping_mutex_); | |||
| if (model_stream_mapping_.find(rt_model_handle_) == model_stream_mapping_.end()) { | |||
| GELOGI("Need to create map for rt_model_handle_:%p with new mainstream %u.", rt_model_handle_, master_stream_id); | |||
| ret = CreateStream(hccl_secondary_stream_num, master_stream_id); | |||
| if (!ret) { | |||
| GELOGE(RT_FAILED, "Create hccl stream failed."); | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| std::map<uint32_t, std::vector<std::weak_ptr<StreamGuard>>> &master_secondary_stream_map = | |||
| model_stream_mapping_.at(rt_model_handle_); | |||
| auto iter = master_secondary_stream_map.find(master_stream_id); | |||
| if (iter != master_secondary_stream_map.end()) { | |||
| std::vector<std::weak_ptr<StreamGuard>> &secondary_stream_vec = iter->second; | |||
| auto lock_weak_ptr = [&secondary_stream_vec, this](int64_t index) -> bool { | |||
| auto stream = secondary_stream_vec[index].lock(); | |||
| if (stream == nullptr) { | |||
| rtStream_t new_stream = nullptr; | |||
| bool ret = CreateStream(rt_model_handle_, &new_stream); | |||
| if (!ret) { | |||
| GELOGE(FAILED, "CreateStream failed."); | |||
| return false; | |||
| } | |||
| stream = std::make_shared<HcclTask::StreamGuard>(rt_model_handle_, new_stream); | |||
| GE_RT_FALSE_CHECK_NOTNULL(stream); | |||
| secondary_stream_vec[index] = stream; | |||
| } | |||
| secondary_stream_list_.push_back(stream); | |||
| return true; | |||
| }; | |||
| if (static_cast<size_t>(hccl_secondary_stream_num) <= secondary_stream_vec.size()) { | |||
| GELOGI("Number of secondary stream is enough to be reused."); | |||
| for (int64_t i = 0; i < hccl_secondary_stream_num; ++i) { | |||
| if (!lock_weak_ptr(i)) { | |||
| GELOGE(FAILED, "Lock weak ptr failed."); | |||
| return false; | |||
| } | |||
| } | |||
| } else { | |||
| GELOGI("Need to reuse secondary stream and create new secondary stream."); | |||
| size_t created_stream_num = secondary_stream_vec.size(); | |||
| for (size_t i = 0; i < secondary_stream_vec.size(); ++i) { | |||
| if (!lock_weak_ptr(i)) { | |||
| GELOGE(FAILED, "Lock weak ptr failed."); | |||
| return false; | |||
| } | |||
| } | |||
| ret = CreateStream(hccl_secondary_stream_num - created_stream_num, master_stream_id); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(RT_FAILED, "Create hccl stream failed."); | |||
| return false; | |||
| } | |||
| } | |||
| GELOGI("Initialize hccl secondary stream success, hccl_secondary_stream_num =%ld", hccl_secondary_stream_num); | |||
| } else { | |||
| GELOGI("Need to create secondary stream for %s with new mainstream %u.", task_info_->op_name().c_str(), | |||
| master_stream_id); | |||
| ret = CreateStream(hccl_secondary_stream_num, master_stream_id); | |||
| if (!ret) { | |||
| GELOGE(RT_FAILED, "Create hccl stream failed."); | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| bool HcclTask::CreateStream(int64_t stream_num, int64_t master_stream_id) { | |||
| GELOGI("Start to create %ld hccl secondary stream.", stream_num); | |||
| for (int64_t i = 0; i < stream_num; ++i) { | |||
| rtStream_t stream = nullptr; | |||
| bool ret = CreateStream(rt_model_handle_, &stream); | |||
| if (!ret) { | |||
| GELOGE(FAILED, "CreateStream failed."); | |||
| return false; | |||
| } | |||
| GELOGD("hccl_stream addr is=%p", stream); | |||
| auto shared_stream = std::make_shared<StreamGuard>(rt_model_handle_, stream); | |||
| if (shared_stream == nullptr) { | |||
| GELOGE(FAILED, "MakeShared failed."); | |||
| return false; | |||
| } | |||
| SaveHcclSecondaryStream(master_stream_id, shared_stream); | |||
| secondary_stream_list_.push_back(shared_stream); | |||
| } | |||
| GELOGI("CreateStream success."); | |||
| return true; | |||
| } | |||
| bool HcclTask::CreateStream(rtModel_t model, rtStream_t *stream) const { | |||
| if (stream == nullptr) { | |||
| GELOGE(FAILED, "Output param stream is null."); | |||
| return false; | |||
| } | |||
| rtError_t rt_ret = rtStreamCreateWithFlags(stream, priority_, RT_STREAM_PERSISTENT | RT_STREAM_FORCE_COPY); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
| return false; | |||
| } | |||
| // Create secondary stream, inactive by default, activated by hccl | |||
| rt_ret = rtModelBindStream(model, *stream, RT_MODEL_WAIT_ACTIVE_STREAM); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| void HcclTask::SaveHcclSecondaryStream(int64_t master_stream_id, const std::shared_ptr<StreamGuard> &stream) { | |||
| if (model_stream_mapping_.find(rt_model_handle_) == model_stream_mapping_.end()) { | |||
| model_stream_mapping_.emplace(rt_model_handle_, std::map<uint32_t, std::vector<std::weak_ptr<StreamGuard>>>()); | |||
| } | |||
| std::map<uint32_t, std::vector<std::weak_ptr<StreamGuard>>> &master_secondary_stream_map = | |||
| model_stream_mapping_.at(rt_model_handle_); | |||
| master_secondary_stream_map[master_stream_id].emplace_back(stream); | |||
| } | |||
| HcclTask::StreamGuard::~StreamGuard() { | |||
| rtError_t rt_ret = rtModelUnbindStream(model_, stream_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Unbind stream from model failed!"); | |||
| return; | |||
| } | |||
| rt_ret = rtStreamDestroy(stream_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Destroy stream failed!"); | |||
| return; | |||
| } | |||
| } | |||
| REGISTER_TASK(TaskInfoType::HCCL, HcclTask, HcclTaskInfo); | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| @@ -1,69 +0,0 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef GE_GE_RUNTIME_TASK_HCCL_TASK_H_ | |||
| #define GE_GE_RUNTIME_TASK_HCCL_TASK_H_ | |||
| #include <memory> | |||
| #include <set> | |||
| #include <map> | |||
| #include <vector> | |||
| #include <mutex> | |||
| #include "ge_runtime/task/task.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| class HcclTask : public TaskRepeater<HcclTaskInfo> { | |||
| public: | |||
| HcclTask(const ModelContext &model_context, const std::shared_ptr<HcclTaskInfo> &task_info); | |||
| ~HcclTask() override; | |||
| bool Distribute() override; | |||
| private: | |||
| class StreamGuard; | |||
| bool SetSecondaryStream(); | |||
| bool CreateStream(int64_t stream_num, int64_t master_stream_id); | |||
| bool CreateStream(rtModel_t model, rtStream_t *stream) const; | |||
| void SaveHcclSecondaryStream(int64_t master_stream_id, const std::shared_ptr<StreamGuard> &stream); | |||
| std::shared_ptr<HcclTaskInfo> task_info_; | |||
| void *stream_; | |||
| void *workspace_mem_; | |||
| rtModel_t rt_model_handle_; | |||
| int32_t priority_; | |||
| std::vector<std::shared_ptr<StreamGuard>> secondary_stream_list_; | |||
| // map<key: model pointer, value: map<key: primary stream id, value: vector<secondary stream pointer>>> | |||
| static std::map<rtModel_t, std::map<uint32_t, std::vector<std::weak_ptr<StreamGuard>>>> model_stream_mapping_; | |||
| static std::mutex model_stream_mapping_mutex_; | |||
| }; | |||
| class HcclTask::StreamGuard { | |||
| public: | |||
| StreamGuard(rtModel_t model, rtStream_t stream) : model_(model), stream_(stream) {} | |||
| ~StreamGuard(); | |||
| rtStream_t GetStream() const { return stream_; } | |||
| private: | |||
| rtModel_t model_; | |||
| rtStream_t stream_; | |||
| }; | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| #endif // GE_GE_RUNTIME_TASK_HCCL_TASK_H_ | |||
| @@ -1,117 +0,0 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ge_runtime/task/label_goto_task.h" | |||
| #include "ge_runtime/task/task_factory.h" | |||
| #include "framework/common/util.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| LabelGotoTask::LabelGotoTask(const ModelContext &model_context, const std::shared_ptr<LabelGotoTaskInfo> &task_info) | |||
| : TaskRepeater<LabelGotoTaskInfo>(model_context, task_info), task_info_(task_info) { | |||
| if (task_info_ == nullptr) { | |||
| GELOGW("task_info_ is null!"); | |||
| return; | |||
| } | |||
| auto stream_list = model_context.stream_list(); | |||
| auto label_list = model_context.label_list(); | |||
| uint32_t stream_id = task_info->stream_id(); | |||
| uint32_t label_id = task_info->label_id(); | |||
| GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); | |||
| GELOGI("Label list size:%zu, label id:%u.", label_list.size(), label_id); | |||
| if (stream_id >= stream_list.size() || label_id >= label_list.size()) { | |||
| GELOGW("Stream/Label id invalid."); | |||
| return; | |||
| } | |||
| stream_ = stream_list[stream_id]; | |||
| label_ = label_list[label_id]; | |||
| } | |||
| LabelGotoTask::~LabelGotoTask() { | |||
| GE_FREE_RT_LOG(label_info_); | |||
| GE_FREE_RT_LOG(index_value_); | |||
| } | |||
| bool LabelGotoTask::Distribute() { | |||
| GELOGI("LabelGotoTask Distribute start."); | |||
| if (!CheckParamValid()) { | |||
| return false; | |||
| } | |||
| const std::vector<void *> label_list = { label_ }; | |||
| rtError_t rt_ret = rtMalloc(&index_value_, sizeof(uint64_t), RT_MEMORY_HBM); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); | |||
| return false; | |||
| } | |||
| uint64_t branch_index = 0; | |||
| rt_ret = rtMemcpy(index_value_, sizeof(uint64_t), &branch_index, sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); | |||
| return false; | |||
| } | |||
| uint32_t label_info_size = sizeof(rtLabelDevInfo) * label_list.size(); | |||
| rt_ret = rtMalloc(&label_info_, label_info_size, RT_MEMORY_HBM); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); | |||
| return false; | |||
| } | |||
| rt_ret = rtLabelListCpy((void**)label_list.data(), label_list.size(), label_info_, label_info_size); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); | |||
| return false; | |||
| } | |||
| rt_ret = rtLabelSwitchByIndex(index_value_, label_list.size(), label_info_, stream_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); | |||
| return false; | |||
| } | |||
| GELOGI("DistributeTask end."); | |||
| return true; | |||
| } | |||
| bool LabelGotoTask::CheckParamValid() { | |||
| if (stream_ == nullptr) { | |||
| GELOGE(PARAM_INVALID, "stream is null!"); | |||
| return false; | |||
| } | |||
| if (label_ == nullptr) { | |||
| GELOGE(PARAM_INVALID, "label is null!"); | |||
| return false; | |||
| } | |||
| if (label_info_ != nullptr) { | |||
| GELOGE(PARAM_INVALID, "label_info_ has dirty data."); | |||
| return false; | |||
| } | |||
| if (index_value_ != nullptr) { | |||
| GELOGE(PARAM_INVALID, "index_value_ has dirty data."); | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| REGISTER_TASK(TaskInfoType::LABEL_GOTO, LabelGotoTask, LabelGotoTaskInfo); | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| @@ -1,45 +0,0 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ | |||
| #define GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ | |||
| #include <memory> | |||
| #include "ge_runtime/task/task.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| class LabelGotoTask : public TaskRepeater<LabelGotoTaskInfo> { | |||
| public: | |||
| LabelGotoTask(const ModelContext &model_context, const std::shared_ptr<LabelGotoTaskInfo> &task_info); | |||
| ~LabelGotoTask() override; | |||
| bool Distribute() override; | |||
| private: | |||
| bool CheckParamValid(); | |||
| std::shared_ptr<LabelGotoTaskInfo> task_info_; | |||
| void *stream_{nullptr}; | |||
| void *label_{nullptr}; | |||
| void *label_info_{nullptr}; | |||
| void *index_value_{nullptr}; | |||
| }; | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| #endif // GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ | |||
| @@ -1,70 +0,0 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ge_runtime/task/label_set_task.h" | |||
| #include "ge_runtime/task/task_factory.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| LabelSetTask::LabelSetTask(const ModelContext &model_context, const std::shared_ptr<LabelSetTaskInfo> &task_info) | |||
| : TaskRepeater<LabelSetTaskInfo>(model_context, task_info), | |||
| task_info_(task_info), | |||
| stream_(nullptr), | |||
| label_(nullptr) { | |||
| if (task_info_ == nullptr) { | |||
| GELOGW("task_info_ is null!"); | |||
| return; | |||
| } | |||
| auto stream_list = model_context.stream_list(); | |||
| auto label_list = model_context.label_list(); | |||
| uint32_t stream_id = task_info->stream_id(); | |||
| uint32_t label_id = task_info->label_id(); | |||
| GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); | |||
| GELOGI("Label list size:%zu, label id:%u.", label_list.size(), label_id); | |||
| if (stream_id >= stream_list.size() || label_id >= label_list.size()) { | |||
| GELOGW("Stream/Label id invalid."); | |||
| return; | |||
| } | |||
| stream_ = stream_list[stream_id]; | |||
| label_ = label_list[label_id]; | |||
| } | |||
| LabelSetTask::~LabelSetTask() {} | |||
| bool LabelSetTask::Distribute() { | |||
| GELOGI("LabelSetTask Distribute start."); | |||
| if (stream_ == nullptr) { | |||
| GELOGE(PARAM_INVALID, "stream is null!"); | |||
| return false; | |||
| } | |||
| if (label_ == nullptr) { | |||
| GELOGE(PARAM_INVALID, "label is null!"); | |||
| return false; | |||
| } | |||
| rtError_t rt_ret = rtLabelSet(label_, stream_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
| return false; | |||
| } | |||
| GELOGI("DistributeTask end."); | |||
| return true; | |||
| } | |||
| REGISTER_TASK(TaskInfoType::LABEL_SET, LabelSetTask, LabelSetTaskInfo); | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| @@ -1,41 +0,0 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_ | |||
| #define GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_ | |||
| #include <memory> | |||
| #include "ge_runtime/task/task.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| class LabelSetTask : public TaskRepeater<LabelSetTaskInfo> { | |||
| public: | |||
| LabelSetTask(const ModelContext &model_context, const std::shared_ptr<LabelSetTaskInfo> &task_info); | |||
| ~LabelSetTask() override; | |||
| bool Distribute() override; | |||
| private: | |||
| std::shared_ptr<LabelSetTaskInfo> task_info_; | |||
| void *stream_; | |||
| void *label_; | |||
| }; | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| #endif // GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_ | |||
| @@ -1,131 +0,0 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ge_runtime/task/label_switch_task.h" | |||
| #include "ge_runtime/task/task_factory.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| LabelSwitchTask::LabelSwitchTask(const ModelContext &model_context, | |||
| const std::shared_ptr<LabelSwitchTaskInfo> &task_info) | |||
| : TaskRepeater<LabelSwitchTaskInfo>(model_context, task_info), | |||
| task_info_(task_info), | |||
| stream_(nullptr), | |||
| all_label_resource_(), | |||
| label_info_(nullptr) { | |||
| if (task_info_ == nullptr) { | |||
| GELOGW("task_info_ is null!"); | |||
| return; | |||
| } | |||
| all_label_resource_ = model_context.label_list(); | |||
| auto stream_list = model_context.stream_list(); | |||
| uint32_t stream_id = task_info->stream_id(); | |||
| GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); | |||
| if (stream_id >= stream_list.size()) { | |||
| GELOGW("Stream id invalid."); | |||
| return; | |||
| } | |||
| stream_ = stream_list[stream_id]; | |||
| } | |||
| LabelSwitchTask::~LabelSwitchTask() { | |||
| if (label_info_ != nullptr) { | |||
| rtError_t rt_ret = rtFree(label_info_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "rtFree fwkOpBuf failed! ret: 0x%X.", rt_ret); | |||
| } | |||
| label_info_ = nullptr; | |||
| } | |||
| } | |||
| bool LabelSwitchTask::Distribute() { | |||
| GELOGI("LabelSwitchTask Distribute start."); | |||
| if (!CheckParamValid()) { | |||
| return false; | |||
| } | |||
| const std::vector<uint32_t> &label_index_list = task_info_->label_list(); | |||
| std::vector<void *> label_list(task_info_->label_size(), nullptr); | |||
| for (size_t i = 0; i < task_info_->label_size(); ++i) { | |||
| uint32_t label_index = label_index_list[i]; | |||
| if (label_index >= all_label_resource_.size()) { | |||
| GELOGE(PARAM_INVALID, "label %zu index is %u, but there are %zu labels in total.", i, label_index, | |||
| all_label_resource_.size()); | |||
| return false; | |||
| } | |||
| label_list[i] = all_label_resource_[label_index]; | |||
| GELOGI("Case %zu: label id %zu.", i, (size_t)label_index); | |||
| } | |||
| uint32_t label_info_size = sizeof(rtLabelDevInfo) * task_info_->label_size(); | |||
| rtError_t rt_ret = rtMalloc(&label_info_, label_info_size, RT_MEMORY_HBM); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
| return false; | |||
| } | |||
| rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info_, label_info_size); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
| return false; | |||
| } | |||
| rt_ret = rtLabelSwitchByIndex(task_info_->cond(), label_list.size(), label_info_, stream_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
| return false; | |||
| } | |||
| GELOGI("DistributeTask end."); | |||
| return true; | |||
| } | |||
| bool LabelSwitchTask::CheckParamValid() { | |||
| if (stream_ == nullptr) { | |||
| GELOGE(PARAM_INVALID, "stream is null!"); | |||
| return false; | |||
| } | |||
| if (task_info_->label_list().empty()) { | |||
| GELOGE(PARAM_INVALID, "label_list is empty."); | |||
| return false; | |||
| } | |||
| if (task_info_->label_size() != task_info_->label_list().size()) { | |||
| GELOGE(PARAM_INVALID, "label_list size %zu but label_size is %u.", task_info_->label_list().size(), | |||
| task_info_->label_size()); | |||
| return false; | |||
| } | |||
| if (task_info_->label_size() >= UINT32_MAX / sizeof(rtLabelDevInfo)) { | |||
| GELOGE(PARAM_INVALID, "label_size %u will overflow.", task_info_->label_size()); | |||
| return false; | |||
| } | |||
| if (label_info_ != nullptr) { | |||
| GELOGE(PARAM_INVALID, "label_info_ has dirty data."); | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| REGISTER_TASK(TaskInfoType::LABEL_SWITCH, LabelSwitchTask, LabelSwitchTaskInfo); | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| @@ -1,44 +0,0 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_ | |||
| #define GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_ | |||
| #include <memory> | |||
| #include "ge_runtime/task/task.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| class LabelSwitchTask : public TaskRepeater<LabelSwitchTaskInfo> { | |||
| public: | |||
| LabelSwitchTask(const ModelContext &model_context, const std::shared_ptr<LabelSwitchTaskInfo> &task_info); | |||
| ~LabelSwitchTask() override; | |||
| bool Distribute() override; | |||
| private: | |||
| bool CheckParamValid(); | |||
| std::shared_ptr<LabelSwitchTaskInfo> task_info_; | |||
| void *stream_; | |||
| std::vector<void *> all_label_resource_; | |||
| void *label_info_; | |||
| }; | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| #endif // GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_ | |||
| @@ -1,57 +0,0 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ge_runtime/task/memcpy_async_task.h" | |||
| #include "ge_runtime/task/task_factory.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| MemcpyAsyncTask::MemcpyAsyncTask(const ModelContext &model_context, | |||
| const std::shared_ptr<MemcpyAsyncTaskInfo> &task_info) | |||
| : TaskRepeater<MemcpyAsyncTaskInfo>(model_context, task_info), task_info_(task_info), stream_(nullptr) { | |||
| if (task_info_ == nullptr) { | |||
| GELOGW("task_info_ is null!"); | |||
| return; | |||
| } | |||
| auto stream_list = model_context.stream_list(); | |||
| uint32_t stream_id = task_info->stream_id(); | |||
| GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); | |||
| if (stream_id >= stream_list.size()) { | |||
| GELOGW("Stream id invalid"); | |||
| return; | |||
| } | |||
| stream_ = stream_list[stream_id]; | |||
| } | |||
| MemcpyAsyncTask::~MemcpyAsyncTask() {} | |||
| bool MemcpyAsyncTask::Distribute() { | |||
| GELOGI("MemcpyAsyncTask Distribute start."); | |||
| GELOGI("dst_max:%lu, count:%lu, kind:%u.", task_info_->dst_max(), task_info_->count(), task_info_->kind()); | |||
| rtError_t rt_ret = rtMemcpyAsync(task_info_->dst(), task_info_->dst_max(), task_info_->src(), task_info_->count(), | |||
| static_cast<rtMemcpyKind_t>(task_info_->kind()), stream_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
| return false; | |||
| } | |||
| GELOGI("DistributeTask end"); | |||
| return true; | |||
| } | |||
| REGISTER_TASK(TaskInfoType::MEMCPY_ASYNC, MemcpyAsyncTask, MemcpyAsyncTaskInfo); | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| @@ -1,40 +0,0 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef GE_GE_RUNTIME_TASK_MEMCPY_ASYNC_TASK_H_ | |||
| #define GE_GE_RUNTIME_TASK_MEMCPY_ASYNC_TASK_H_ | |||
| #include <memory> | |||
| #include "ge_runtime/task/task.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| class MemcpyAsyncTask : public TaskRepeater<MemcpyAsyncTaskInfo> { | |||
| public: | |||
| MemcpyAsyncTask(const ModelContext &model_context, const std::shared_ptr<MemcpyAsyncTaskInfo> &task_info); | |||
| ~MemcpyAsyncTask() override; | |||
| bool Distribute() override; | |||
| private: | |||
| std::shared_ptr<MemcpyAsyncTaskInfo> task_info_; | |||
| rtStream_t stream_; | |||
| }; | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| #endif // GE_GE_RUNTIME_TASK_MEMCPY_ASYNC_TASK_H_ | |||
| @@ -1,55 +0,0 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ge_runtime/task/profiler_task.h" | |||
| #include "ge_runtime/task/task_factory.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| ProfilerTask::ProfilerTask(const ModelContext &model_context, const std::shared_ptr<ProfilerTraceTaskInfo> &task_info) | |||
| : TaskRepeater<ProfilerTraceTaskInfo>(model_context, task_info), task_info_(task_info), stream_(nullptr) { | |||
| if (task_info_ == nullptr) { | |||
| GELOGW("task_info_ is null!"); | |||
| return; | |||
| } | |||
| auto stream_list = model_context.stream_list(); | |||
| uint32_t stream_id = task_info->stream_id(); | |||
| GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); | |||
| if (stream_id >= stream_list.size()) { | |||
| GELOGW("Stream id invalid"); | |||
| return; | |||
| } | |||
| stream_ = stream_list[stream_id]; | |||
| } | |||
| ProfilerTask::~ProfilerTask() {} | |||
| bool ProfilerTask::Distribute() { | |||
| GELOGI("ProfilerTask Distribute start."); | |||
| GELOGI("logid = %lu, notify = %d, flat = %u.", task_info_->log_id(), task_info_->notify(), task_info_->flat()); | |||
| rtError_t rt_ret = rtProfilerTrace(task_info_->log_id(), task_info_->notify(), task_info_->flat(), stream_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
| return false; | |||
| } | |||
| GELOGI("DistributeTask end"); | |||
| return true; | |||
| } | |||
| REGISTER_TASK(TaskInfoType::PROFILER_TRACE, ProfilerTask, ProfilerTraceTaskInfo); | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| @@ -1,40 +0,0 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef GE_GE_RUNTIME_TASK_PROFILER_TASK_H_ | |||
| #define GE_GE_RUNTIME_TASK_PROFILER_TASK_H_ | |||
| #include <memory> | |||
| #include "ge_runtime/task/task.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| class ProfilerTask : public TaskRepeater<ProfilerTraceTaskInfo> { | |||
| public: | |||
| ProfilerTask(const ModelContext &model_context, const std::shared_ptr<ProfilerTraceTaskInfo> &task_info); | |||
| ~ProfilerTask() override; | |||
| bool Distribute() override; | |||
| private: | |||
| std::shared_ptr<ProfilerTraceTaskInfo> task_info_; | |||
| rtStream_t stream_; | |||
| }; | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| #endif // GE_GE_RUNTIME_TASK_PROFILER_TASK_H_ | |||
| @@ -1,60 +0,0 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ge_runtime/task/stream_active_task.h" | |||
| #include "ge_runtime/task/task_factory.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| StreamActiveTask::StreamActiveTask(const ModelContext &model_context, | |||
| const std::shared_ptr<StreamActiveTaskInfo> &task_info) | |||
| : TaskRepeater<StreamActiveTaskInfo>(model_context, task_info), | |||
| task_info_(task_info), | |||
| stream_(nullptr), | |||
| active_stream_(nullptr) { | |||
| if (task_info_ == nullptr) { | |||
| GELOGW("task_info_ is null!"); | |||
| return; | |||
| } | |||
| auto stream_list = model_context.stream_list(); | |||
| uint32_t stream_id = task_info->stream_id(); | |||
| uint32_t active_stream_id = task_info->active_stream_id(); | |||
| GELOGI("Stream list size:%zu, stream id:%u, active stream id:%u", stream_list.size(), stream_id, active_stream_id); | |||
| if (stream_id >= stream_list.size() || active_stream_id >= stream_list.size()) { | |||
| GELOGW("Stream id invalid"); | |||
| return; | |||
| } | |||
| stream_ = stream_list[stream_id]; | |||
| active_stream_ = stream_list[active_stream_id]; | |||
| } | |||
| StreamActiveTask::~StreamActiveTask() {} | |||
| bool StreamActiveTask::Distribute() { | |||
| GELOGI("Distribute start"); | |||
| GELOGI("Stream %u active %u.", task_info_->stream_id(), task_info_->active_stream_id()); | |||
| rtError_t rt_ret = rtStreamActive(active_stream_, stream_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
| return false; | |||
| } | |||
| GELOGI("DistributeTask end"); | |||
| return true; | |||
| } | |||
| REGISTER_TASK(TaskInfoType::STREAM_ACTIVE, StreamActiveTask, StreamActiveTaskInfo); | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| @@ -1,41 +0,0 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef GE_GE_RUNTIME_TASK_STREAM_ACTIVE_TASK_H_ | |||
| #define GE_GE_RUNTIME_TASK_STREAM_ACTIVE_TASK_H_ | |||
| #include <memory> | |||
| #include "ge_runtime/task/task.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| class StreamActiveTask : public TaskRepeater<StreamActiveTaskInfo> { | |||
| public: | |||
| StreamActiveTask(const ModelContext &model_context, const std::shared_ptr<StreamActiveTaskInfo> &task_info); | |||
| ~StreamActiveTask() override; | |||
| bool Distribute() override; | |||
| private: | |||
| std::shared_ptr<StreamActiveTaskInfo> task_info_; | |||
| rtStream_t stream_; | |||
| rtStream_t active_stream_; | |||
| }; | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| #endif // GE_GE_RUNTIME_TASK_STREAM_ACTIVE_TASK_H_ | |||
| @@ -1,82 +0,0 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ge_runtime/task/stream_switch_task.h" | |||
| #include "ge_runtime/task/task_factory.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| StreamSwitchTask::StreamSwitchTask(const ModelContext &model_context, | |||
| const std::shared_ptr<StreamSwitchTaskInfo> &task_info) | |||
| : TaskRepeater<StreamSwitchTaskInfo>(model_context, task_info), | |||
| task_info_(task_info), | |||
| stream_(nullptr), | |||
| stream_list_() { | |||
| if (task_info_ == nullptr) { | |||
| GELOGW("task_info_ is null!"); | |||
| return; | |||
| } | |||
| stream_list_ = model_context.stream_list(); | |||
| if (stream_list_.size() == 1) { | |||
| stream_ = stream_list_[0]; | |||
| } else if (stream_list_.size() > task_info->stream_id()) { | |||
| stream_ = stream_list_[task_info->stream_id()]; | |||
| } else { | |||
| GELOGW("Index: %u >= stream_list.size(): %zu.", task_info->stream_id(), stream_list_.size()); | |||
| } | |||
| } | |||
| StreamSwitchTask::~StreamSwitchTask() {} | |||
| bool StreamSwitchTask::Distribute() { | |||
| GELOGI("Init StreamSwitchTask start."); | |||
| GELOGI("Stream %u active %ld.", task_info_->stream_id(), task_info_->true_stream_id()); | |||
| if (stream_ == nullptr) { | |||
| GELOGE(PARAM_INVALID, "stream_ is null!"); | |||
| return false; | |||
| } | |||
| if (static_cast<uint64_t>(task_info_->true_stream_id()) >= stream_list_.size()) { | |||
| GELOGE(PARAM_INVALID, "true_stream_id %ld must be less than stream_list_ size %zu!", task_info_->true_stream_id(), | |||
| stream_list_.size()); | |||
| return false; | |||
| } | |||
| void *input = reinterpret_cast<void *>(task_info_->input_addr()); | |||
| rtCondition_t cond = static_cast<rtCondition_t>(task_info_->cond()); | |||
| void *value = reinterpret_cast<void *>(task_info_->value_addr()); | |||
| rtStream_t true_stream = stream_list_[task_info_->true_stream_id()]; | |||
| rtSwitchDataType_t data_type = static_cast<rtSwitchDataType_t>(task_info_->data_type()); | |||
| GELOGI("InitStreamSwitchTask, cond:%d, trueStream:%p, trueStreamID:%ld, datatype:%ld.", cond, true_stream, | |||
| task_info_->true_stream_id(), task_info_->data_type()); | |||
| GELOGI("StreamSwitchTask Distribute Start."); | |||
| rtError_t rt_ret = rtStreamSwitchEx(input, cond, value, true_stream, stream_, data_type); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
| return false; | |||
| } | |||
| GELOGI("Distribute StreamSwitch, cond:%d, trueStream:%p, datatype:%ld.", cond, true_stream, task_info_->data_type()); | |||
| return true; | |||
| } | |||
| REGISTER_TASK(TaskInfoType::STREAM_SWITCH, StreamSwitchTask, StreamSwitchTaskInfo); | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| @@ -1,43 +0,0 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef GE_GE_RUNTIME_TASK_STREAM_SWITCH_TASK_H_ | |||
| #define GE_GE_RUNTIME_TASK_STREAM_SWITCH_TASK_H_ | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "ge_runtime/task/task.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| class StreamSwitchTask : public TaskRepeater<StreamSwitchTaskInfo> { | |||
| public: | |||
| StreamSwitchTask(const ModelContext &model_context, const std::shared_ptr<StreamSwitchTaskInfo> &task_info); | |||
| ~StreamSwitchTask() override; | |||
| bool Distribute() override; | |||
| private: | |||
| std::shared_ptr<StreamSwitchTaskInfo> task_info_; | |||
| void *stream_; | |||
| std::vector<rtStream_t> stream_list_; | |||
| }; | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| #endif // GE_GE_RUNTIME_TASK_STREAM_SWITCH_TASK_H_ | |||
| @@ -1,58 +0,0 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef GE_GE_RUNTIME_TASK_TASK_H_ | |||
| #define GE_GE_RUNTIME_TASK_TASK_H_ | |||
| #include <memory> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include <string> | |||
| #include "runtime/rt_model.h" | |||
| #include "ge_runtime/model_context.h" | |||
| #include "ge_runtime/task_info.h" | |||
| #include "external/runtime/rt_error_codes.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| class Task { | |||
| public: | |||
| Task() {} | |||
| virtual ~Task() {} | |||
| virtual bool Distribute() = 0; | |||
| virtual void *Args() { return nullptr; } | |||
| virtual std::string task_name() const { return ""; } | |||
| }; | |||
| template <class T> | |||
| class TaskRepeater : public Task { | |||
| static_assert(std::is_base_of<TaskInfo, T>(), "Wrong TaskInfo Type!"); | |||
| public: | |||
| TaskRepeater(const ModelContext &model_context, std::shared_ptr<T> task_info) {} | |||
| virtual ~TaskRepeater() {} | |||
| virtual bool Distribute() = 0; | |||
| }; | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| #endif // GE_GE_RUNTIME_TASK_TASK_H_ | |||
| @@ -1,87 +0,0 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef GE_GE_RUNTIME_TASK_TASK_FACTORY_H_ | |||
| #define GE_GE_RUNTIME_TASK_TASK_FACTORY_H_ | |||
| #include <functional> | |||
| #include <map> | |||
| #include <memory> | |||
| #include <unordered_map> | |||
| #include "common/ge_inner_error_codes.h" | |||
| #include "framework/common/debug/ge_log.h" | |||
| #include "ge_runtime/task_info.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| class Task; | |||
| class ModelContext; | |||
| using TASK_CREATOR_FUN = std::function<std::shared_ptr<Task>(const ModelContext &, std::shared_ptr<TaskInfo>)>; | |||
| class TaskFactory { | |||
| private: | |||
| TaskFactory() {} | |||
| ~TaskFactory() {} | |||
| void RegisterCreator(const TaskInfoType &type, const TASK_CREATOR_FUN &func) { | |||
| if (creator_map_.find(type) != creator_map_.end()) { | |||
| GELOGW("Creator type %d already exist", static_cast<int32_t>(type)); | |||
| } | |||
| creator_map_[type] = func; | |||
| } | |||
| std::map<TaskInfoType, TASK_CREATOR_FUN> creator_map_; | |||
| public: | |||
| static TaskFactory &GetInstance() { | |||
| static TaskFactory instance; | |||
| return instance; | |||
| } | |||
| std::shared_ptr<Task> Create(const ModelContext &model_context, std::shared_ptr<TaskInfo> &task_info) const { | |||
| if (task_info == nullptr) { | |||
| GELOGE(FAILED, "task_info is null."); | |||
| return nullptr; | |||
| } | |||
| auto iter = creator_map_.find(task_info->type()); | |||
| if (iter == creator_map_.end()) { | |||
| GELOGE(FAILED, "Unknow task type %d", static_cast<int32_t>(task_info->type())); | |||
| return nullptr; | |||
| } | |||
| return iter->second(model_context, task_info); | |||
| } | |||
| class Register { | |||
| public: | |||
| Register(const TaskInfoType &type, const TASK_CREATOR_FUN &func) { | |||
| GELOGI("regist type %d", static_cast<int32_t>(type)); | |||
| TaskFactory::GetInstance().RegisterCreator(type, func); | |||
| } | |||
| ~Register() {} | |||
| }; | |||
| }; | |||
| #define REGISTER_TASK(type, task_clazz, task_info_clazz) \ | |||
| TaskFactory::Register g_##task_clazz##_register( \ | |||
| type, [](const ModelContext &model_context, const std::shared_ptr<TaskInfo> &task_info) -> std::shared_ptr<Task> { \ | |||
| std::shared_ptr<task_info_clazz> concrete_task_info = std::static_pointer_cast<task_info_clazz>(task_info); \ | |||
| return std::make_shared<task_clazz>(model_context, concrete_task_info); \ | |||
| }); | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| #endif // GE_GE_RUNTIME_TASK_TASK_FACTORY_H_ | |||
| @@ -1,112 +0,0 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ge_runtime/task/tbe_task.h" | |||
| #include <vector> | |||
| #include "ge_runtime/task/task_factory.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| TbeTask::TbeTask(const ModelContext &model_context, const std::shared_ptr<TbeTaskInfo> &task_info) | |||
| : TaskRepeater<TbeTaskInfo>(model_context, task_info), | |||
| task_info_(task_info), | |||
| stream_(nullptr), | |||
| stub_func_(nullptr), | |||
| args_(nullptr) { | |||
| if (task_info_ == nullptr) { | |||
| GELOGW("task_info_ is null!"); | |||
| return; | |||
| } | |||
| auto stream_list = model_context.stream_list(); | |||
| if (stream_list.size() == 1) { | |||
| stream_ = stream_list[0]; | |||
| } else if (stream_list.size() > task_info->stream_id()) { | |||
| stream_ = stream_list[task_info->stream_id()]; | |||
| } else { | |||
| GELOGE(PARAM_INVALID, "Index: %u >= stream_list.size(): %zu.", task_info->stream_id(), stream_list.size()); | |||
| return; | |||
| } | |||
| } | |||
| TbeTask::~TbeTask() { | |||
| if (args_ != nullptr) { | |||
| rtError_t rt_ret = rtFree(args_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "rtFree fwkOpBuf failed! ret: 0x%X.", rt_ret); | |||
| } | |||
| args_ = nullptr; | |||
| } | |||
| } | |||
| bool TbeTask::Distribute() { | |||
| GELOGI("InitTbeTask start."); | |||
| if (stream_ == nullptr) { | |||
| GELOGE(PARAM_INVALID, "stream_ is null!"); | |||
| return false; | |||
| } | |||
| // Get stub_func | |||
| if (task_info_->stub_func().empty()) { | |||
| GELOGE(PARAM_INVALID, "kernel_info->stub_func is empty!"); | |||
| return false; | |||
| } | |||
| rtError_t rt_ret = rtGetFunctionByName(const_cast<char *>(task_info_->stub_func().c_str()), &stub_func_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "rtGetFunctionByName failed, ret: %d", static_cast<int32_t>(rt_ret)); | |||
| stub_func_ = nullptr; | |||
| return false; | |||
| } | |||
| GELOGI("TbeTask: stub_func = %s [%p].", task_info_->stub_func().c_str(), stub_func_); | |||
| // Get args | |||
| std::vector<void *> tensor_device_addrs; | |||
| tensor_device_addrs.insert(tensor_device_addrs.end(), task_info_->input_data_addrs().begin(), | |||
| task_info_->input_data_addrs().end()); | |||
| tensor_device_addrs.insert(tensor_device_addrs.end(), task_info_->output_data_addrs().begin(), | |||
| task_info_->output_data_addrs().end()); | |||
| tensor_device_addrs.insert(tensor_device_addrs.end(), task_info_->workspace_addrs().begin(), | |||
| task_info_->workspace_addrs().end()); | |||
| auto args_size = static_cast<uint32_t>(tensor_device_addrs.size() * sizeof(void *)); | |||
| rt_ret = rtMalloc(&args_, args_size, RT_MEMORY_HBM); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "rtMalloc failed, ret: %d", static_cast<int32_t>(rt_ret)); | |||
| return false; | |||
| } | |||
| GE_PRINT_DYNAMIC_MEMORY(rtMalloc, "task args data.", args_size) | |||
| rt_ret = rtMemcpy(args_, args_size, reinterpret_cast<void *>(tensor_device_addrs.data()), args_size, | |||
| RT_MEMCPY_HOST_TO_DEVICE); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "rtMemcpy fail, ret 0x%X.", rt_ret); | |||
| return false; | |||
| } | |||
| GELOGI("DistributeTbeTask start."); | |||
| auto dump_flag = task_info_->dump_flag() ? RT_KERNEL_DUMPFLAG : RT_KERNEL_DEFAULT; | |||
| rt_ret = rtKernelLaunchWithFlag(stub_func_, task_info_->block_dim(), args_, args_size, nullptr, stream_, dump_flag); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api rtKernelLaunch failed, ret: 0x%X", rt_ret); | |||
| return false; | |||
| } | |||
| GELOGI("[DataDump] task name:%s, dump_flag:%d", task_info_->op_name().c_str(), dump_flag); | |||
| return true; | |||
| } | |||
| REGISTER_TASK(TaskInfoType::TBE, TbeTask, TbeTaskInfo); | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| @@ -1,46 +0,0 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef GE_GE_RUNTIME_TASK_TBE_TASK_H_ | |||
| #define GE_GE_RUNTIME_TASK_TBE_TASK_H_ | |||
| #include <memory> | |||
| #include "ge_runtime/task/task.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| class TbeTask : public TaskRepeater<TbeTaskInfo> { | |||
| public: | |||
| TbeTask(const ModelContext &model_context, const std::shared_ptr<TbeTaskInfo> &task_info); | |||
| ~TbeTask() override; | |||
| bool Distribute() override; | |||
| void *Args() override { return args_; } | |||
| std::string task_name() const override { return task_info_->op_name(); } | |||
| private: | |||
| std::shared_ptr<TbeTaskInfo> task_info_; | |||
| void *stream_; | |||
| void *stub_func_; | |||
| void *args_; | |||
| }; | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| #endif // GE_GE_RUNTIME_TASK_TBE_TASK_H_ | |||
| @@ -1,113 +0,0 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef INC_FRAMEWORK_GE_RUNTIME_DAVINCI_MODEL_H_ | |||
| #define INC_FRAMEWORK_GE_RUNTIME_DAVINCI_MODEL_H_ | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "ge_runtime/op_info.h" | |||
| #include "ge_runtime/task_info.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| class DavinciModel { | |||
| public: | |||
| DavinciModel(const std::vector<std::shared_ptr<TaskInfo>> &task_info_list, | |||
| const std::vector<std::shared_ptr<OpInfo>> &data_info_list, | |||
| const std::vector<std::shared_ptr<OpInfo>> &output_info_list, | |||
| const std::vector<std::shared_ptr<OpInfo>> &constant_info_list, | |||
| const std::vector<model_runner::OpInfoPtr> &variable_info_list, | |||
| const std::vector<uint32_t> &wait_active_stream_list, | |||
| const std::vector<uint32_t> &force_copy_stream_list, uint64_t mem_size = 0, uint64_t weight_size = 0, | |||
| uint64_t var_size = 0, uintptr_t logic_mem_base = 0, uintptr_t logic_weight_base = 0, | |||
| uintptr_t logic_var_base = 0, uint32_t stream_num = 0, uint32_t batch_num = 0, uint32_t event_num = 0, | |||
| int32_t priority = 0) | |||
| : task_info_list_(task_info_list), | |||
| data_info_list_(data_info_list), | |||
| output_info_list_(output_info_list), | |||
| constant_info_list_(constant_info_list), | |||
| variable_info_list_(variable_info_list), | |||
| wait_active_stream_list_(wait_active_stream_list), | |||
| force_copy_stream_list_(force_copy_stream_list), | |||
| mem_size_(mem_size), | |||
| weight_size_(weight_size), | |||
| var_size_(var_size), | |||
| logic_mem_base_(logic_mem_base), | |||
| logic_weight_base_(logic_weight_base), | |||
| logic_var_base_(logic_var_base), | |||
| stream_num_(stream_num), | |||
| batch_num_(batch_num), | |||
| event_num_(event_num), | |||
| priority_(priority) {} | |||
| ~DavinciModel() {} | |||
| uint64_t GetMemSize() const { return mem_size_; } | |||
| uint64_t GetWeightSize() const { return weight_size_; } | |||
| uint64_t GetVarSize() const { return var_size_; } | |||
| uintptr_t GetLogicMemBase() const { return logic_mem_base_; } | |||
| uintptr_t GetLogicWeightBase() const { return logic_weight_base_; } | |||
| uintptr_t GetLogicVarBase() const { return logic_var_base_; } | |||
| uint32_t GetStreamNum() const { return stream_num_; } | |||
| uint32_t GetBatchNum() const { return batch_num_; } | |||
| uint32_t GetEventNum() const { return event_num_; } | |||
| const std::vector<uint32_t> &GetWaitActiveStreams() const { return wait_active_stream_list_; } | |||
| const std::vector<uint32_t> &GetForceCopyStreams() const { return force_copy_stream_list_; } | |||
| int32_t GetPriority() const { return priority_; } | |||
| const std::vector<std::shared_ptr<TaskInfo>> &GetTaskInfoList() const { return task_info_list_; } | |||
| const std::vector<std::shared_ptr<OpInfo>> &GetDataInfoList() const { return data_info_list_; } | |||
| const std::vector<std::shared_ptr<OpInfo>> &GetOutputInfoList() const { return output_info_list_; } | |||
| const std::vector<std::shared_ptr<OpInfo>> &GetConstantInfoList() const { return output_info_list_; } | |||
| const std::vector<model_runner::OpInfoPtr> &GetVariableInfoList() const { return variable_info_list_; } | |||
| private: | |||
| std::vector<std::shared_ptr<TaskInfo>> task_info_list_; | |||
| std::vector<std::shared_ptr<OpInfo>> data_info_list_; | |||
| std::vector<std::shared_ptr<OpInfo>> output_info_list_; | |||
| std::vector<std::shared_ptr<OpInfo>> constant_info_list_; | |||
| std::vector<model_runner::OpInfoPtr> variable_info_list_; | |||
| std::vector<uint32_t> wait_active_stream_list_; | |||
| std::vector<uint32_t> force_copy_stream_list_; | |||
| uint64_t mem_size_; | |||
| uint64_t weight_size_; | |||
| uint64_t var_size_; | |||
| uintptr_t logic_mem_base_; | |||
| uintptr_t logic_weight_base_; | |||
| uintptr_t logic_var_base_; | |||
| uint32_t stream_num_; | |||
| uint32_t batch_num_; | |||
| uint32_t event_num_; | |||
| int32_t priority_; | |||
| // Disable to copy constructor and assignment operator | |||
| DavinciModel &operator=(const DavinciModel &) = delete; | |||
| DavinciModel(const DavinciModel &) = delete; | |||
| }; | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| #endif // INC_FRAMEWORK_GE_RUNTIME_DAVINCI_MODEL_H_ | |||
| @@ -1,68 +0,0 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef INC_FRAMEWORK_GE_RUNTIME_MODEL_RUNNER_H_ | |||
| #define INC_FRAMEWORK_GE_RUNTIME_MODEL_RUNNER_H_ | |||
#include <map>
#include <memory>
#include <string>
#include <tuple>
#include <unordered_map>
#include <vector>

#include "common/ge_inner_error_codes.h"
#include "common/ge_types.h"
#include "ge_runtime/davinci_model.h"
| namespace ge { | |||
| namespace model_runner { | |||
| class RuntimeModel; | |||
| using RuntimeInfo = std::tuple<uint32_t, uint32_t, void *>; | |||
| class ModelRunner { | |||
| public: | |||
| static ModelRunner &Instance(); | |||
| bool LoadDavinciModel(uint32_t device_id, uint64_t session_id, uint32_t model_id, | |||
| std::shared_ptr<DavinciModel> davinci_model, std::shared_ptr<ModelListener> listener); | |||
| bool DistributeTask(uint32_t model_id); | |||
| bool LoadModelComplete(uint32_t model_id); | |||
| const std::vector<uint32_t> &GetTaskIdList(uint32_t model_id) const; | |||
| const std::vector<uint32_t> &GetStreamIdList(uint32_t model_id) const; | |||
| const std::map<std::string, std::shared_ptr<RuntimeInfo>> &GetRuntimeInfoMap(uint32_t model_id) const; | |||
| void *GetModelHandle(uint32_t model_id) const; | |||
| bool UnloadModel(uint32_t model_id); | |||
| bool RunModel(uint32_t model_id, const InputData &input_data, OutputData *output_data); | |||
| bool GetInputOutputDescInfo(uint32_t model_id, bool zero_copy, std::vector<InputOutputDescInfo> *input_desc, | |||
| std::vector<InputOutputDescInfo> *output_desc, std::vector<uint32_t> *input_format, | |||
| std::vector<uint32_t> *output_format); | |||
| private: | |||
| ModelRunner() = default; | |||
| ~ModelRunner() = default; | |||
| std::unordered_map<uint32_t, std::shared_ptr<RuntimeModel>> runtime_models_; | |||
| }; | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| #endif // INC_FRAMEWORK_GE_RUNTIME_MODEL_RUNNER_H_ | |||
| @@ -1,72 +0,0 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef INC_FRAMEWORK_GE_RUNTIME_OP_INFO_H_ | |||
| #define INC_FRAMEWORK_GE_RUNTIME_OP_INFO_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| namespace ge { | |||
| namespace model_runner { | |||
// Shape and type metadata for a single tensor.
struct TensorInfo {
  // Returns the element count (product of all dims), or 0 when no dims are
  // recorded. NOTE(review): no overflow guard -- assumes the product fits in
  // int64_t.
  int64_t GetShapeSize() const {
    if (dims.empty()) {
      return 0;
    }
    int64_t res = 1;
    for (auto dim : dims) {
      res *= dim;
    }
    return res;
  }

  // Returns the dimension at |index|, or 0 when out of range.
  // FIX: declared const -- it does not mutate the struct, matching
  // GetShapeSize and allowing calls through const references.
  int64_t GetDim(uint32_t index) const {
    if (index >= dims.size()) {
      return 0;
    }
    return dims[index];
  }

  std::vector<int64_t> dims;
  uint32_t datatype;      // dtype enum -- domain not visible here; TODO confirm
  uint32_t format;        // format enum -- TODO confirm
  uint32_t real_dim_cnt;
  uint32_t size;          // presumably size in bytes -- verify against producer
  bool is_output;
};
// Per-operator description used by the runtime model loader: identity,
// input/output/workspace addresses, and tensor metadata.
struct OpInfo {
  uint32_t index;
  std::string name;
  std::string type;
  bool var_is_broadcast;
  std::vector<uintptr_t> input_addrs;   // device addresses -- presumably; confirm against loader
  std::vector<uintptr_t> output_addrs;
  std::vector<TensorInfo> input_tensors;
  std::vector<TensorInfo> output_tensors;
  std::vector<TensorInfo> weight_tensors;
  std::vector<std::string> src_name;    // names of source ops feeding this op
  std::vector<int64_t> src_index;       // parallel to src_name -- TODO confirm pairing
  std::string weight_data;              // raw weight bytes carried as a string
};

using TensorInfoPtr = std::shared_ptr<TensorInfo>;
using OpInfoPtr = std::shared_ptr<OpInfo>;
| } // namespace model_runner | |||
| } // namespace ge | |||
| #endif // INC_FRAMEWORK_GE_RUNTIME_OP_INFO_H_ | |||
| @@ -1,405 +0,0 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef INC_FRAMEWORK_GE_RUNTIME_TASK_INFO_H_ | |||
| #define INC_FRAMEWORK_GE_RUNTIME_TASK_INFO_H_ | |||
| #include <stdint.h> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "cce/taskdown_api.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
// Discriminator for the concrete TaskInfo subclass. Existing values must not
// be renumbered: they are part of the public enum interface.
enum TaskInfoType {
  CCE = 0,
  TBE,
  AICPU,
  LABEL_SET,
  LABEL_SWITCH,
  LABEL_GOTO,
  EVENT_RECORD,
  EVENT_WAIT,
  FUSION_START,
  FUSION_END,
  HCCL,
  PROFILER_TRACE,
  MEMCPY_ASYNC,
  STREAM_SWITCH,
  STREAM_ACTIVE,
  // Insert new task type here
  REVSERVED = 23,        // historical misspelling, kept for source compatibility
  RESERVED = REVSERVED   // correctly spelled alias; prefer this in new code
};
// Abstract base for all task descriptions. Carries the fields every task
// kind shares: owning op name, target stream id, a type tag identifying the
// concrete subclass, and a dump flag. Only subclasses can construct it.
class TaskInfo {
 public:
  virtual ~TaskInfo() {}

  uint32_t stream_id() const { return stream_id_; }
  TaskInfoType type() const { return type_; }
  std::string op_name() const { return op_name_; }
  bool dump_flag() const { return dump_flag_; }

 protected:
  TaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, bool dump_flag)
      : op_name_(op_name), stream_id_(stream_id), type_(type), dump_flag_(dump_flag) {}

 private:
  std::string op_name_;
  uint32_t stream_id_;
  TaskInfoType type_;
  bool dump_flag_;
};
// Description of a CCE kernel task: op context, stub function name, launch
// configuration, and the serialized argument/flow-table buffers.
class CceTaskInfo : public TaskInfo {
 public:
  CceTaskInfo(const std::string &op_name, uint32_t stream_id, const cce::ccOpContext &ctx, const std::string &stub_func,
              uint32_t block_dim, const std::vector<uint8_t> &args, uint32_t args_size,
              const std::vector<uint8_t> &sm_desc, const std::vector<uint8_t> &flow_table,
              const std::vector<uint8_t> &args_offset, bool is_flowtable)
      : TaskInfo(op_name, stream_id, TaskInfoType::CCE, false),
        ctx_(ctx),
        stub_func_(stub_func),
        block_dim_(block_dim),
        args_(args),
        args_size_(args_size),
        sm_desc_(sm_desc),
        flow_table_(flow_table),
        args_offset_(args_offset),
        is_flowtable_(is_flowtable) {}
  ~CceTaskInfo() override {}

  // NOTE(review): returns by value, copying the whole ccOpContext on every
  // call -- consider a const reference if the context is large.
  cce::ccOpContext cc_context() const { return ctx_; }
  std::string stub_func() const { return stub_func_; }
  uint32_t block_dim() const { return block_dim_; }
  const std::vector<uint8_t> &args() const { return args_; }
  uint32_t args_size() const { return args_size_; }
  const std::vector<uint8_t> &sm_desc() const { return sm_desc_; }
  const std::vector<uint8_t> &flow_table() const { return flow_table_; }
  const std::vector<uint8_t> &args_offset() const { return args_offset_; }
  bool is_flowtable() const { return is_flowtable_; }

 private:
  cce::ccOpContext ctx_;
  std::string stub_func_;
  uint32_t block_dim_;
  std::vector<uint8_t> args_;
  uint32_t args_size_;  // presumably bytes in args_ actually used; TODO confirm
  std::vector<uint8_t> sm_desc_;
  std::vector<uint8_t> flow_table_;
  std::vector<uint8_t> args_offset_;
  bool is_flowtable_;
};
// Description of a TBE kernel task: stub function, launch configuration,
// argument buffers, the kernel binary, and input/output/workspace addresses.
class TbeTaskInfo : public TaskInfo {
 public:
  TbeTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string &stub_func, uint32_t block_dim,
              const std::vector<uint8_t> &args, uint32_t args_size, const std::vector<uint8_t> &sm_desc, void *binary,
              uint32_t binary_size, const std::vector<uint8_t> &meta_data, const std::vector<void *> &input_data_addrs,
              const std::vector<void *> &output_data_addrs, const std::vector<void *> &workspace_addrs, bool dump_flag)
      : TaskInfo(op_name, stream_id, TaskInfoType::TBE, dump_flag),
        stub_func_(stub_func),
        block_dim_(block_dim),
        args_(args),
        args_size_(args_size),
        sm_desc_(sm_desc),
        binary_(binary),
        binary_size_(binary_size),
        meta_data_(meta_data),
        input_data_addrs_(input_data_addrs),
        output_data_addrs_(output_data_addrs),
        workspace_addrs_(workspace_addrs) {}
  ~TbeTaskInfo() override {}

  const std::string &stub_func() const { return stub_func_; }
  uint32_t block_dim() const { return block_dim_; }
  const std::vector<uint8_t> &args() const { return args_; }
  uint32_t args_size() const { return args_size_; }
  const std::vector<uint8_t> &sm_desc() const { return sm_desc_; }
  void *binary() const { return binary_; }
  uint32_t binary_size() const { return binary_size_; }
  const std::vector<uint8_t> &meta_data() const { return meta_data_; }
  const std::vector<void *> &input_data_addrs() const { return input_data_addrs_; }
  const std::vector<void *> &output_data_addrs() const { return output_data_addrs_; }
  const std::vector<void *> &workspace_addrs() const { return workspace_addrs_; }

  // Replaces the kernel binary pointer and its size together.
  void SetBinary(void *binary, uint32_t binary_size) {
    binary_ = binary;
    binary_size_ = binary_size;
  }

 private:
  std::string stub_func_;
  uint32_t block_dim_;
  std::vector<uint8_t> args_;
  uint32_t args_size_;
  std::vector<uint8_t> sm_desc_;
  void *binary_;        // raw pointer to the kernel binary -- ownership not visible here; TODO confirm
  uint32_t binary_size_;
  std::vector<uint8_t> meta_data_;
  std::vector<void *> input_data_addrs_;
  std::vector<void *> output_data_addrs_;
  std::vector<void *> workspace_addrs_;
};
| class AicpuTaskInfo : public TaskInfo { | |||
| public: | |||
| AicpuTaskInfo(const std::string &op_name, uint32_t stream_id, const string &so_name, const std::string &kernel_name, | |||
| const std::string &node_def, const std::string &ext_info, const std::vector<void *> &input_data_addrs, | |||
| const std::vector<void *> &output_data_addrs, bool dump_flag) | |||
| : TaskInfo(op_name, stream_id, TaskInfoType::AICPU, dump_flag), | |||
| so_name_(so_name), | |||
| kernel_name_(kernel_name), | |||
| node_def_(node_def), | |||
| ext_info_(ext_info), | |||
| input_data_addrs_(input_data_addrs), | |||
| output_data_addrs_(output_data_addrs) {} | |||
| ~AicpuTaskInfo() override {} | |||
| const std::string &so_name() const { return so_name_; } | |||
| const std::string &kernel_name() const { return kernel_name_; } | |||
| const std::string &node_def() const { return node_def_; } | |||
| const std::vector<void *> &input_data_addrs() const { return input_data_addrs_; } | |||
| const std::vector<void *> &output_data_addrs() const { return output_data_addrs_; } | |||
| const std::string &ext_info() const { return ext_info_; } | |||
| private: | |||
| std::string so_name_; | |||
| std::string kernel_name_; | |||
| std::string node_def_; | |||
| std::string ext_info_; | |||
| std::vector<void *> input_data_addrs_; | |||
| std::vector<void *> output_data_addrs_; | |||
| }; | |||
// Marks a label (jump target) with the given id on a stream.
class LabelSetTaskInfo : public TaskInfo {
 public:
  LabelSetTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_id)
      : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_SET, false), label_id_(label_id) {}
  ~LabelSetTaskInfo() override {}

  uint32_t label_id() const { return label_id_; }

 private:
  uint32_t label_id_;
};
// Unconditional jump to the label with the given id.
class LabelGotoTaskInfo : public TaskInfo {
 public:
  LabelGotoTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_id)
      : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_GOTO, false), label_id_(label_id) {}
  ~LabelGotoTaskInfo() override {}

  uint32_t label_id() const { return label_id_; }

 private:
  uint32_t label_id_;
};
// Multi-way branch: selects one of label_list based on the value read from
// cond at execution time.
class LabelSwitchTaskInfo : public TaskInfo {
 public:
  LabelSwitchTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_size,
                      const std::vector<uint32_t> &label_list, void *cond)
      : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_SWITCH, false),
        label_size_(label_size),
        label_list_(label_list),
        cond_(cond) {}
  ~LabelSwitchTaskInfo() override {}

  uint32_t label_size() const { return label_size_; }
  const std::vector<uint32_t> &label_list() const { return label_list_; }
  void *cond() const { return cond_; }

 private:
  uint32_t label_size_;   // presumably label_list_.size(); not enforced here -- TODO confirm
  std::vector<uint32_t> label_list_;
  void *cond_;            // address holding the branch selector -- semantics not visible here
};
// Common base for event record/wait tasks; holds the event id. Only
// constructible by the two concrete subclasses.
class EventTaskInfo : public TaskInfo {
 public:
  uint32_t event_id() const { return event_id_; }

 protected:
  EventTaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, uint32_t event_id)
      : TaskInfo(op_name, stream_id, type, false), event_id_(event_id) {}
  ~EventTaskInfo() override {}

  // Protected (not private) so subclasses could touch it, though none
  // visible here do.
  uint32_t event_id_;
};
// Records (signals) the event with the given id on a stream.
class EventRecordTaskInfo : public EventTaskInfo {
 public:
  EventRecordTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t event_id)
      : EventTaskInfo(op_name, stream_id, TaskInfoType::EVENT_RECORD, event_id) {}
  ~EventRecordTaskInfo() override {}
};
// Blocks a stream until the event with the given id is recorded.
class EventWaitTaskInfo : public EventTaskInfo {
 public:
  EventWaitTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t event_id)
      : EventTaskInfo(op_name, stream_id, TaskInfoType::EVENT_WAIT, event_id) {}
  ~EventWaitTaskInfo() override {}
};
// Marks the beginning of a fusion region; carries no payload beyond the
// base fields.
class FusionStartTaskInfo : public TaskInfo {
 public:
  explicit FusionStartTaskInfo(const std::string &op_name, uint32_t stream_id)
      : TaskInfo(op_name, stream_id, TaskInfoType::FUSION_START, false) {}
  ~FusionStartTaskInfo() override {}
};
// Marks the end of a fusion region; carries no payload beyond the base
// fields.
class FusionEndTaskInfo : public TaskInfo {
 public:
  explicit FusionEndTaskInfo(const std::string &op_name, uint32_t stream_id)
      : TaskInfo(op_name, stream_id, TaskInfoType::FUSION_END, false) {}
  ~FusionEndTaskInfo() override {}
};
| class HcclTaskInfo : public TaskInfo { | |||
| public: | |||
| HcclTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string hccl_type, void *input_data_addr, | |||
| void *output_data_addr, int64_t workspace_size, int64_t hccl_stream_num, | |||
| const std::vector<uint8_t> &private_def, void *ops_kernel_store, int32_t count, int64_t root_id, | |||
| int64_t op_type, int64_t data_type, const std::string &group, bool dump_flag) | |||
| : TaskInfo(op_name, stream_id, TaskInfoType::HCCL, dump_flag), | |||
| hccl_type_(hccl_type), | |||
| input_data_addr_(input_data_addr), | |||
| output_data_addr_(output_data_addr), | |||
| workspace_size_(workspace_size), | |||
| hccl_stream_num_(hccl_stream_num), | |||
| private_def_(private_def), | |||
| ops_kernel_store_(ops_kernel_store), | |||
| count_(count), | |||
| root_id_(root_id), | |||
| op_type_(op_type), | |||
| data_type_(data_type), | |||
| group_(group) {} | |||
| ~HcclTaskInfo() override {} | |||
| const std::string &hccl_type() const { return hccl_type_; } | |||
| void *input_data_addr() const { return input_data_addr_; } | |||
| void *output_data_addr() const { return output_data_addr_; } | |||
| int64_t workspace_size() const { return workspace_size_; } | |||
| int64_t hccl_stream_num() const { return hccl_stream_num_; } | |||
| const std::vector<uint8_t> &private_def() const { return private_def_; } | |||
| void *ops_kernel_store() const { return ops_kernel_store_; } | |||
| int32_t count() const { return count_; } | |||
| int64_t root_id() const { return root_id_; } | |||
| int64_t op_type() const { return op_type_; } | |||
| int64_t data_type() const { return data_type_; } | |||
| const std::string &group() const { return group_; } | |||
| private: | |||
| std::string hccl_type_; | |||
| void *input_data_addr_; | |||
| void *output_data_addr_; | |||
| int64_t workspace_size_; | |||
| int64_t hccl_stream_num_; | |||
| std::vector<uint8_t> private_def_; | |||
| void *ops_kernel_store_; | |||
| int32_t count_; | |||
| int64_t root_id_; | |||
| int64_t op_type_; | |||
| int64_t data_type_; | |||
| std::string group_; | |||
| }; | |||
// Emits a profiler trace point identified by log_id.
class ProfilerTraceTaskInfo : public TaskInfo {
 public:
  ProfilerTraceTaskInfo(const std::string &op_name, uint32_t stream_id, uint64_t log_id, bool notify, uint32_t flat)
      : TaskInfo(op_name, stream_id, TaskInfoType::PROFILER_TRACE, false),
        log_id_(log_id),
        notify_(notify),
        flat_(flat) {}
  ~ProfilerTraceTaskInfo() override {}

  uint64_t log_id() const { return log_id_; }
  bool notify() const { return notify_; }
  uint32_t flat() const { return flat_; }

 private:
  uint64_t log_id_;
  bool notify_;
  uint32_t flat_;  // NOTE(review): "flat" may be a misspelling of "flags" -- semantics not visible here; confirm
};
| class MemcpyAsyncTaskInfo : public TaskInfo { | |||
| public: | |||
| MemcpyAsyncTaskInfo(const std::string &op_name, uint32_t stream_id, void *dst, uint64_t dst_max, void *src, | |||
| uint64_t count, uint32_t kind, bool dump_flag) | |||
| : TaskInfo(op_name, stream_id, TaskInfoType::MEMCPY_ASYNC, dump_flag), | |||
| dst_(dst), | |||
| dst_max_(dst_max), | |||
| src_(src), | |||
| count_(count), | |||
| kind_(kind) {} | |||
| ~MemcpyAsyncTaskInfo() override {} | |||
| void *dst() const { return dst_; } | |||
| uint64_t dst_max() const { return dst_max_; } | |||
| void *src() const { return src_; } | |||
| uint64_t count() const { return count_; } | |||
| uint32_t kind() const { return kind_; } | |||
| private: | |||
| void *dst_; | |||
| uint64_t dst_max_; | |||
| void *src_; | |||
| uint64_t count_; | |||
| int32_t kind_; | |||
| }; | |||
// Conditional stream branch: compares the value at input_addr against the
// value at value_addr (condition cond_) and activates true_stream_id when it
// holds.
class StreamSwitchTaskInfo : public TaskInfo {
 public:
  StreamSwitchTaskInfo(const std::string &op_name, uint32_t stream_id, int64_t true_stream_id, void *input_addr,
                       void *value_addr, int64_t cond, int64_t data_type)
      : TaskInfo(op_name, stream_id, TaskInfoType::STREAM_SWITCH, false),
        true_stream_id_(true_stream_id),
        input_addr_(input_addr),
        value_addr_(value_addr),
        cond_(cond),
        data_type_(data_type) {}
  ~StreamSwitchTaskInfo() override {}

  int64_t true_stream_id() const { return true_stream_id_; }
  void *input_addr() const { return input_addr_; }
  void *value_addr() const { return value_addr_; }
  int64_t cond() const { return cond_; }
  int64_t data_type() const { return data_type_; }

 private:
  int64_t true_stream_id_;
  void *input_addr_;
  void *value_addr_;
  int64_t cond_;       // comparison condition code -- enum domain not visible here; TODO confirm
  int64_t data_type_;  // type of the compared values -- TODO confirm
};
// Activates another stream (active_stream_id) from the stream this task runs
// on.
class StreamActiveTaskInfo : public TaskInfo {
 public:
  StreamActiveTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t active_stream_id)
      : TaskInfo(op_name, stream_id, TaskInfoType::STREAM_ACTIVE, false), active_stream_id_(active_stream_id) {}
  ~StreamActiveTaskInfo() override {}

  uint32_t active_stream_id() const { return active_stream_id_; }

 private:
  uint32_t active_stream_id_;
};
| } // namespace model_runner | |||
| } // namespace ge | |||
| #endif // INC_FRAMEWORK_GE_RUNTIME_TASK_INFO_H_ | |||