From 4d0b6f9d9a1b212c0be65b3406ca4b538d2d4616 Mon Sep 17 00:00:00 2001 From: dingpeifei Date: Tue, 20 Apr 2021 17:06:36 +0800 Subject: [PATCH] code_sync_0420 --- ge/ge_runtime/CMakeLists.txt | 1 + ge/ge_runtime/task/label_goto_task.cc | 98 ++++++++----------- ge/ge_runtime/task/label_goto_task.h | 16 ++-- ge/ge_runtime/task/label_manager.cc | 119 ++++++++++++++++++++++++ ge/ge_runtime/task/label_manager.h | 54 +++++++++++ ge/ge_runtime/task/label_switch_task.cc | 55 +++-------- ge/ge_runtime/task/label_switch_task.h | 6 +- 7 files changed, 242 insertions(+), 107 deletions(-) create mode 100644 ge/ge_runtime/task/label_manager.cc create mode 100644 ge/ge_runtime/task/label_manager.h diff --git a/ge/ge_runtime/CMakeLists.txt b/ge/ge_runtime/CMakeLists.txt index b00dd5b3..40113285 100644 --- a/ge/ge_runtime/CMakeLists.txt +++ b/ge/ge_runtime/CMakeLists.txt @@ -16,6 +16,7 @@ set(GE_SRC_LIST "task/label_goto_task.cc" "task/label_set_task.cc" "task/label_switch_task.cc" + "task/label_manager.cc" ) add_library(ge_runtime SHARED ${GE_SRC_LIST}) diff --git a/ge/ge_runtime/task/label_goto_task.cc b/ge/ge_runtime/task/label_goto_task.cc index ad93a98f..c04bd5cf 100644 --- a/ge/ge_runtime/task/label_goto_task.cc +++ b/ge/ge_runtime/task/label_goto_task.cc @@ -16,99 +16,83 @@ #include "ge_runtime/task/label_goto_task.h" #include "ge_runtime/task/task_factory.h" -#include "framework/common/util.h" namespace ge { namespace model_runner { LabelGotoTask::LabelGotoTask(const ModelContext &model_context, const std::shared_ptr &task_info) - : TaskRepeater(model_context, task_info), task_info_(task_info) { + : TaskRepeater(model_context, task_info), + task_info_(task_info), + stream_(nullptr), + index_value_(nullptr) { if (task_info_ == nullptr) { GELOGW("task_info_ is null!"); return; } auto stream_list = model_context.stream_list(); auto label_list = model_context.label_list(); + rt_model_handle_ = model_context.rt_model_handle(); uint32_t stream_id = task_info->stream_id(); - uint32_t label_id = task_info->label_id(); + label_id_ = task_info->label_id(); GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); - GELOGI("Label list size:%zu, label id:%u.", label_list.size(), label_id); - if (stream_id >= stream_list.size() || label_id >= label_list.size()) { + GELOGI("Label list size:%zu, label id:%u.", label_list.size(), label_id_); + if (stream_id >= stream_list.size() || label_id_ >= label_list.size()) { GELOGW("Stream/Label id invalid."); return; } stream_ = stream_list[stream_id]; - label_ = label_list[label_id]; + label_manager_ = LabelManager::GetInstance(); + if (label_manager_ == nullptr) { + GELOGW("Get label manager instance failed."); + return; + } + label_info_ = label_manager_->GetLabelInfo(rt_model_handle_, {label_id_}, label_list); } LabelGotoTask::~LabelGotoTask() { - GE_FREE_RT_LOG(label_info_); - GE_FREE_RT_LOG(index_value_); + if (index_value_ != nullptr) { + rtError_t rt_ret = rtFree(index_value_); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "rtFree index_value_ failed! ret: 0x%X.", rt_ret); + } + index_value_ = nullptr; + } } bool LabelGotoTask::Distribute() { GELOGI("LabelGotoTask Distribute start."); - if (!CheckParamValid()) { - return false; - } - - const std::vector label_list = { label_ }; - rtError_t rt_ret = rtMalloc(&index_value_, sizeof(uint64_t), RT_MEMORY_HBM); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); - return false; - } - - uint64_t branch_index = 0; - rt_ret = rtMemcpy(index_value_, sizeof(uint64_t), &branch_index, sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); - return false; - } - - uint32_t label_info_size = sizeof(rtLabelDevInfo) * label_list.size(); - rt_ret = rtMalloc(&label_info_, label_info_size, RT_MEMORY_HBM); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); - return false; - } - - rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info_, label_info_size); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); - return false; - } - - rt_ret = rtLabelSwitchByIndex(index_value_, label_list.size(), label_info_, stream_); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); - return false; - } - - GELOGI("DistributeTask end."); - return true; -} - -bool LabelGotoTask::CheckParamValid() { if (stream_ == nullptr) { GELOGE(PARAM_INVALID, "stream is null!"); return false; } - if (label_ == nullptr) { - GELOGE(PARAM_INVALID, "label is null!"); + if (label_info_ == nullptr) { + GELOGE(PARAM_INVALID, "label info is null!"); return false; } - if (label_info_ != nullptr) { - GELOGE(PARAM_INVALID, "label_info_ has dirty data."); - return false; + if (index_value_ == nullptr) { + rtError_t rt_ret = rtMalloc(&index_value_, sizeof(uint64_t), RT_MEMORY_HBM); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); + return false; + } + + uint64_t index = 0; + rt_ret = rtMemcpy(index_value_, sizeof(uint64_t), &index, sizeof(index), RT_MEMCPY_HOST_TO_DEVICE); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); + return false; + } } - if (index_value_ != nullptr) { - GELOGE(PARAM_INVALID, "index_value_ has dirty data."); + void *label_info = label_info_->GetLabelInfo(); + rtError_t rt_ret = rtLabelSwitchByIndex(index_value_, 1, label_info, stream_); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); return false; } + GELOGI("DistributeTask end."); return true; } diff --git a/ge/ge_runtime/task/label_goto_task.h b/ge/ge_runtime/task/label_goto_task.h index addbb700..e579c683 100644 --- a/ge/ge_runtime/task/label_goto_task.h +++ b/ge/ge_runtime/task/label_goto_task.h @@ -18,7 +18,11 @@ #define GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ #include +#include +#include +#include #include "ge_runtime/task/task.h" +#include "ge_runtime/task/label_manager.h" namespace ge { namespace model_runner { @@ -31,13 +35,13 @@ class LabelGotoTask : public TaskRepeater { bool Distribute() override; private: - bool CheckParamValid(); - std::shared_ptr task_info_; - void *stream_{nullptr}; - void *label_{nullptr}; - void *label_info_{nullptr}; - void *index_value_{nullptr}; + void *stream_; + std::shared_ptr label_info_; + void *index_value_; + uint32_t label_id_; + rtModel_t rt_model_handle_; + std::shared_ptr label_manager_; }; } // namespace model_runner } // namespace ge diff --git a/ge/ge_runtime/task/label_manager.cc b/ge/ge_runtime/task/label_manager.cc new file mode 100644 index 00000000..a2b0c3aa --- /dev/null +++ b/ge/ge_runtime/task/label_manager.cc @@ -0,0 +1,119 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "ge_runtime/task/label_manager.h" +#include +#include +#include "runtime/mem.h" +#include "runtime/rt_model.h" +#include "common/ge_inner_error_codes.h" +#include "framework/common/debug/ge_log.h" + +namespace ge { +namespace model_runner { +std::weak_ptr LabelManager::instance_; +std::mutex LabelManager::instance_mutex_; + +template +static std::string GetVectorString(const std::vector &vec) { + std::string ret; + for (size_t i = 0; i < vec.size(); ++i) { + if (i != 0) { + ret.push_back(','); + } + ret += std::to_string(vec[i]); + } + return ret; +} + +LabelGuard::~LabelGuard() { + void *label_info = GetLabelInfo(); + if (label_info != nullptr) { + rtError_t rt_ret = rtFree(label_info); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "rtFree label_info failed! ret: 0x%X.", rt_ret); + } + } +} + +std::shared_ptr LabelManager::GetInstance() { + std::lock_guard lock(instance_mutex_); + auto instance = instance_.lock(); + if (instance != nullptr) { + return instance; + } + + instance = std::make_shared(); + instance_ = instance; + return instance; +} + +std::shared_ptr LabelManager::GetLabelInfo(rtModel_t model, const std::vector &label_ids, + const std::vector &all_label) { + std::lock_guard lock(model_info_mapping_mutex_); + rtError_t rt_ret; + auto model_iter = model_info_mapping_.find(model); + if (model_iter == model_info_mapping_.end()) { + model_info_mapping_.emplace(model, std::map>()); + model_iter = model_info_mapping_.find(model); + } + + std::string label_id_str = GetVectorString(label_ids); + auto &label_map = model_iter->second; + auto label_iter = label_map.find(label_id_str); + if (label_iter != label_map.end()) { + auto label_guard = label_iter->second.lock(); + if (label_guard != nullptr) { + GELOGI("model %p find same label id %s.", model, label_id_str.c_str()); + return label_guard; + } + } + + GELOGI("Alloc label id %s for model %p.", label_id_str.c_str(), model); + void *label_info; + std::vector label_list; + bool status = true; + std::transform(label_ids.begin(), label_ids.end(), std::back_inserter(label_list), + [&all_label, &status](uint32_t idx) -> void * { + if (idx >= all_label.size()) { + GELOGE(PARAM_INVALID, "Invalid label id %u, all label list size %zu.", idx, all_label.size()); + status = false; + return nullptr; + } + return all_label[idx]; + }); + if (!status) { + GELOGE(PARAM_INVALID, "Get label info failed."); + return nullptr; + } + uint32_t label_info_size = sizeof(rtLabelDevInfo) * label_list.size(); + rt_ret = rtMalloc(&label_info, label_info_size, RT_MEMORY_HBM); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); + return nullptr; + } + + rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info, label_info_size); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); + return nullptr; + } + + auto label_guard = std::make_shared(label_info); + label_map.emplace(label_id_str, label_guard); + return label_guard; +} +} // namespace model_runner +} // namespace ge diff --git a/ge/ge_runtime/task/label_manager.h b/ge/ge_runtime/task/label_manager.h new file mode 100644 index 00000000..f2c42c29 --- /dev/null +++ b/ge/ge_runtime/task/label_manager.h @@ -0,0 +1,54 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef GE_GE_RUNTIME_TASK_LABEL_MANAGER_H_ +#define GE_GE_RUNTIME_TASK_LABEL_MANAGER_H_ + +#include +#include +#include +#include +#include + +namespace ge { +namespace model_runner { +class LabelGuard { + public: + explicit LabelGuard(void *label_info) : label_info_(reinterpret_cast(label_info)) {} + ~LabelGuard(); + void *GetLabelInfo() { return reinterpret_cast(label_info_); } + + private: + uintptr_t label_info_; +}; + +class LabelManager { + public: + static std::shared_ptr GetInstance(); + std::shared_ptr GetLabelInfo(rtModel_t model, const std::vector &label_ids, + const std::vector &all_label); + + private: + std::mutex model_info_mapping_mutex_; + std::map>> model_info_mapping_; + + static std::weak_ptr instance_; + static std::mutex instance_mutex_; +}; + + +} // namespace model_runner +} // namespace ge +#endif // GE_GE_RUNTIME_TASK_LABEL_MANAGER_H_ \ No newline at end of file diff --git a/ge/ge_runtime/task/label_switch_task.cc b/ge/ge_runtime/task/label_switch_task.cc index a3c2d41a..1f913d74 100644 --- a/ge/ge_runtime/task/label_switch_task.cc +++ b/ge/ge_runtime/task/label_switch_task.cc @@ -24,14 +24,14 @@ LabelSwitchTask::LabelSwitchTask(const ModelContext &model_context, : TaskRepeater(model_context, task_info), task_info_(task_info), stream_(nullptr), - all_label_resource_(), label_info_(nullptr) { if (task_info_ == nullptr) { GELOGW("task_info_ is null!"); return; } - all_label_resource_ = model_context.label_list(); + rt_model_handle_ = model_context.rt_model_handle(); + auto all_label_resource = model_context.label_list(); auto stream_list = model_context.stream_list(); uint32_t stream_id = task_info->stream_id(); GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); @@ -40,52 +40,24 @@ LabelSwitchTask::LabelSwitchTask(const ModelContext &model_context, return; } stream_ = stream_list[stream_id]; -} - -LabelSwitchTask::~LabelSwitchTask() { - if (label_info_ != nullptr) { - rtError_t rt_ret = rtFree(label_info_); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "rtFree fwkOpBuf failed! ret: 0x%X.", rt_ret); - } - label_info_ = nullptr; + label_manager_ = LabelManager::GetInstance(); + if (label_manager_ == nullptr) { + GELOGW("Get label manager instance failed."); + return; } + label_info_ = label_manager_->GetLabelInfo(rt_model_handle_, task_info_->label_list(), all_label_resource); } +LabelSwitchTask::~LabelSwitchTask() {} + bool LabelSwitchTask::Distribute() { GELOGI("LabelSwitchTask Distribute start."); if (!CheckParamValid()) { return false; } - const std::vector &label_index_list = task_info_->label_list(); - std::vector label_list(task_info_->label_size(), nullptr); - - for (size_t i = 0; i < task_info_->label_size(); ++i) { - uint32_t label_index = label_index_list[i]; - if (label_index >= all_label_resource_.size()) { - GELOGE(PARAM_INVALID, "label %zu index is %u, but there are %zu labels in total.", i, label_index, - all_label_resource_.size()); - return false; - } - label_list[i] = all_label_resource_[label_index]; - GELOGI("Case %zu: label id %zu.", i, label_index); - } - - uint32_t label_info_size = sizeof(rtLabelDevInfo) * task_info_->label_size(); - rtError_t rt_ret = rtMalloc(&label_info_, label_info_size, RT_MEMORY_HBM); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return false; - } - - rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info_, label_info_size); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return false; - } - - rt_ret = rtLabelSwitchByIndex(task_info_->cond(), label_list.size(), label_info_, stream_); + void *label_info = label_info_->GetLabelInfo(); + rtError_t rt_ret = rtLabelSwitchByIndex(task_info_->cond(), task_info_->label_size(), label_info, stream_); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); return false; @@ -117,8 +89,8 @@ bool LabelSwitchTask::CheckParamValid() { return false; } - if (label_info_ != nullptr) { - GELOGE(PARAM_INVALID, "label_info_ has dirty data."); + if (label_info_ == nullptr) { + GELOGE(PARAM_INVALID, "CopyLabelList failed, label info is null."); return false; } @@ -126,6 +98,5 @@ bool LabelSwitchTask::CheckParamValid() { } REGISTER_TASK(TaskInfoType::LABEL_SWITCH, LabelSwitchTask, LabelSwitchTaskInfo); - } // namespace model_runner } // namespace ge diff --git a/ge/ge_runtime/task/label_switch_task.h b/ge/ge_runtime/task/label_switch_task.h index 463faa31..cfa6877c 100644 --- a/ge/ge_runtime/task/label_switch_task.h +++ b/ge/ge_runtime/task/label_switch_task.h @@ -19,6 +19,7 @@ #include #include "ge_runtime/task/task.h" +#include "ge_runtime/task/label_manager.h" namespace ge { namespace model_runner { @@ -35,8 +36,9 @@ class LabelSwitchTask : public TaskRepeater { std::shared_ptr task_info_; void *stream_; - std::vector all_label_resource_; - void *label_info_; + rtModel_t rt_model_handle_; + std::shared_ptr label_info_; + std::shared_ptr label_manager_; }; } // namespace model_runner } // namespace ge