| @@ -16,6 +16,7 @@ set(GE_SRC_LIST | |||
| "task/label_goto_task.cc" | |||
| "task/label_set_task.cc" | |||
| "task/label_switch_task.cc" | |||
| "task/label_manager.cc" | |||
| ) | |||
| add_library(ge_runtime SHARED ${GE_SRC_LIST}) | |||
| @@ -19,14 +19,10 @@ | |||
| namespace ge { | |||
| namespace model_runner { | |||
| std::weak_ptr<LabelGotoTask::LabelManager> LabelGotoTask::LabelManager::instance_; | |||
| std::mutex LabelGotoTask::LabelManager::instance_mutex_; | |||
| LabelGotoTask::LabelGotoTask(const ModelContext &model_context, const std::shared_ptr<LabelGotoTaskInfo> &task_info) | |||
| : TaskRepeater<LabelGotoTaskInfo>(model_context, task_info), | |||
| task_info_(task_info), | |||
| stream_(nullptr), | |||
| label_(nullptr), | |||
| index_value_(nullptr) { | |||
| if (task_info_ == nullptr) { | |||
| GELOGW("task_info_ is null!"); | |||
| @@ -44,13 +40,12 @@ LabelGotoTask::LabelGotoTask(const ModelContext &model_context, const std::share | |||
| return; | |||
| } | |||
| stream_ = stream_list[stream_id]; | |||
| label_ = label_list[label_id_]; | |||
| label_manager_ = LabelManager::GetInstance(); | |||
| if (label_manager_ == nullptr) { | |||
| GELOGW("Get label manager instance failed."); | |||
| return; | |||
| } | |||
| label_info_ = label_manager_->GetLabelInfo(rt_model_handle_, label_id_, label_); | |||
| label_info_ = label_manager_->GetLabelInfo(rt_model_handle_, {label_id_}, label_list); | |||
| } | |||
| LabelGotoTask::~LabelGotoTask() { | |||
| @@ -69,10 +64,6 @@ bool LabelGotoTask::Distribute() { | |||
| GELOGE(PARAM_INVALID, "stream is null!"); | |||
| return false; | |||
| } | |||
| if (label_ == nullptr) { | |||
| GELOGE(PARAM_INVALID, "label is null!"); | |||
| return false; | |||
| } | |||
| if (label_info_ == nullptr) { | |||
| GELOGE(PARAM_INVALID, "label info is null!"); | |||
| @@ -105,69 +96,6 @@ bool LabelGotoTask::Distribute() { | |||
| return true; | |||
| } | |||
| LabelGotoTask::LabelGuard::~LabelGuard() { | |||
| void *label_info = GetLabelInfo(); | |||
| if (label_info != nullptr) { | |||
| rtError_t rt_ret = rtFree(label_info); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "rtFree label_info failed! ret: 0x%X.", rt_ret); | |||
| } | |||
| } | |||
| } | |||
| std::shared_ptr<LabelGotoTask::LabelManager> LabelGotoTask::LabelManager::GetInstance() { | |||
| std::lock_guard<std::mutex> lock(instance_mutex_); | |||
| auto instance = instance_.lock(); | |||
| if (instance != nullptr) { | |||
| return instance; | |||
| } | |||
| instance = std::make_shared<LabelManager>(); | |||
| instance_ = instance; | |||
| return instance; | |||
| } | |||
| std::shared_ptr<LabelGotoTask::LabelGuard> LabelGotoTask::LabelManager::GetLabelInfo(rtModel_t model, uint32_t label_id, | |||
| void *label) { | |||
| std::lock_guard<std::mutex> lock(model_info_mapping_mutex_); | |||
| rtError_t rt_ret; | |||
| auto model_iter = model_info_mapping_.find(model); | |||
| if (model_iter == model_info_mapping_.end()) { | |||
| model_info_mapping_.emplace(model, std::map<uint32_t, std::weak_ptr<LabelGuard>>()); | |||
| model_iter = model_info_mapping_.find(model); | |||
| } | |||
| std::map<uint32_t, std::weak_ptr<LabelGuard>> &label_map = model_iter->second; | |||
| auto label_iter = label_map.find(label_id); | |||
| if (label_iter != label_map.end()) { | |||
| auto label_guard = label_iter->second.lock(); | |||
| if (label_guard != nullptr) { | |||
| GELOGI("model %p find same label id.", model, label_id); | |||
| return label_guard; | |||
| } | |||
| } | |||
| GELOGI("Alloc label id %u for model %p.", label_id, model); | |||
| void *label_info; | |||
| std::vector<void *> label_list = {label}; | |||
| uint32_t label_info_size = sizeof(rtLabelDevInfo) * label_list.size(); | |||
| rt_ret = rtMalloc(&label_info, label_info_size, RT_MEMORY_HBM); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
| return nullptr; | |||
| } | |||
| rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info, label_info_size); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
| return nullptr; | |||
| } | |||
| auto label_guard = std::make_shared<LabelGuard>(label_info); | |||
| label_map.emplace(label_id, label_guard); | |||
| return label_guard; | |||
| } | |||
| REGISTER_TASK(TaskInfoType::LABEL_GOTO, LabelGotoTask, LabelGotoTaskInfo); | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| @@ -22,6 +22,7 @@ | |||
| #include <map> | |||
| #include <mutex> | |||
| #include "ge_runtime/task/task.h" | |||
| #include "ge_runtime/task/label_manager.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| @@ -34,41 +35,14 @@ class LabelGotoTask : public TaskRepeater<LabelGotoTaskInfo> { | |||
| bool Distribute() override; | |||
| private: | |||
| class LabelGuard; | |||
| class LabelManager; | |||
| std::shared_ptr<LabelGotoTaskInfo> task_info_; | |||
| void *stream_; | |||
| void *label_; | |||
| std::shared_ptr<LabelGuard> label_info_; | |||
| void *index_value_; | |||
| uint32_t label_id_; | |||
| rtModel_t rt_model_handle_; | |||
| std::shared_ptr<LabelManager> label_manager_; | |||
| }; | |||
| class LabelGotoTask::LabelGuard { | |||
| public: | |||
| explicit LabelGuard(void *label_info) : label_info_(reinterpret_cast<uintptr_t>(label_info)) {} | |||
| ~LabelGuard(); | |||
| void *GetLabelInfo() { return reinterpret_cast<void *>(label_info_); } | |||
| private: | |||
| uintptr_t label_info_; | |||
| }; | |||
| class LabelGotoTask::LabelManager { | |||
| public: | |||
| static std::shared_ptr<LabelManager> GetInstance(); | |||
| std::shared_ptr<LabelGuard> GetLabelInfo(rtModel_t model, uint32_t label_id, void *label); | |||
| private: | |||
| std::mutex model_info_mapping_mutex_; | |||
| std::map<rtModel_t, std::map<uint32_t, std::weak_ptr<LabelGuard>>> model_info_mapping_; | |||
| static std::weak_ptr<LabelManager> instance_; | |||
| static std::mutex instance_mutex_; | |||
| }; | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| @@ -0,0 +1,119 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ge_runtime/task/label_manager.h" | |||
| #include <algorithm> | |||
| #include <string> | |||
| #include "runtime/mem.h" | |||
| #include "runtime/rt_model.h" | |||
| #include "common/ge_inner_error_codes.h" | |||
| #include "framework/common/debug/ge_log.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| std::weak_ptr<LabelManager> LabelManager::instance_; | |||
| std::mutex LabelManager::instance_mutex_; | |||
| template <class T> | |||
| static std::string GetVectorString(const std::vector<T> &vec) { | |||
| std::string ret; | |||
| for (size_t i = 0; i < vec.size(); ++i) { | |||
| if (i != 0) { | |||
| ret.push_back(','); | |||
| } | |||
| ret += std::to_string(vec[i]); | |||
| } | |||
| return ret; | |||
| } | |||
| LabelGuard::~LabelGuard() { | |||
| void *label_info = GetLabelInfo(); | |||
| if (label_info != nullptr) { | |||
| rtError_t rt_ret = rtFree(label_info); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "rtFree label_info failed! ret: 0x%X.", rt_ret); | |||
| } | |||
| } | |||
| } | |||
| std::shared_ptr<LabelManager> LabelManager::GetInstance() { | |||
| std::lock_guard<std::mutex> lock(instance_mutex_); | |||
| auto instance = instance_.lock(); | |||
| if (instance != nullptr) { | |||
| return instance; | |||
| } | |||
| instance = std::make_shared<LabelManager>(); | |||
| instance_ = instance; | |||
| return instance; | |||
| } | |||
| std::shared_ptr<LabelGuard> LabelManager::GetLabelInfo(rtModel_t model, const std::vector<uint32_t> &label_ids, | |||
| const std::vector<void *> &all_label) { | |||
| std::lock_guard<std::mutex> lock(model_info_mapping_mutex_); | |||
| rtError_t rt_ret; | |||
| auto model_iter = model_info_mapping_.find(model); | |||
| if (model_iter == model_info_mapping_.end()) { | |||
| model_info_mapping_.emplace(model, std::map<std::string, std::weak_ptr<LabelGuard>>()); | |||
| model_iter = model_info_mapping_.find(model); | |||
| } | |||
| std::string label_id_str = GetVectorString(label_ids); | |||
| auto &label_map = model_iter->second; | |||
| auto label_iter = label_map.find(label_id_str); | |||
| if (label_iter != label_map.end()) { | |||
| auto label_guard = label_iter->second.lock(); | |||
| if (label_guard != nullptr) { | |||
| GELOGI("model %p find same label id %s.", model, label_id_str.c_str()); | |||
| return label_guard; | |||
| } | |||
| } | |||
| GELOGI("Alloc label id %s for model %p.", label_id_str.c_str(), model); | |||
| void *label_info; | |||
| std::vector<void *> label_list; | |||
| bool status = true; | |||
| std::transform(label_ids.begin(), label_ids.end(), std::back_inserter(label_list), | |||
| [&all_label, &status](uint32_t idx) -> void * { | |||
| if (idx >= all_label.size()) { | |||
| GELOGE(PARAM_INVALID, "Invalid label id %u, all label list size %zu.", idx, all_label.size()); | |||
| status = false; | |||
| return nullptr; | |||
| } | |||
| return all_label[idx]; | |||
| }); | |||
| if (!status) { | |||
| GELOGE(PARAM_INVALID, "Get label info failed."); | |||
| return nullptr; | |||
| } | |||
| uint32_t label_info_size = sizeof(rtLabelDevInfo) * label_list.size(); | |||
| rt_ret = rtMalloc(&label_info, label_info_size, RT_MEMORY_HBM); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
| return nullptr; | |||
| } | |||
| rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info, label_info_size); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
| return nullptr; | |||
| } | |||
| auto label_guard = std::make_shared<LabelGuard>(label_info); | |||
| label_map.emplace(label_id_str, label_guard); | |||
| return label_guard; | |||
| } | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| @@ -0,0 +1,54 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef GE_GE_RUNTIME_TASK_LABEL_MANAGER_H_ | |||
| #define GE_GE_RUNTIME_TASK_LABEL_MANAGER_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <mutex> | |||
| #include <map> | |||
| #include <runtime/base.h> | |||
| namespace ge { | |||
| namespace model_runner { | |||
| class LabelGuard { | |||
| public: | |||
| explicit LabelGuard(void *label_info) : label_info_(reinterpret_cast<uintptr_t>(label_info)) {} | |||
| ~LabelGuard(); | |||
| void *GetLabelInfo() { return reinterpret_cast<void *>(label_info_); } | |||
| private: | |||
| uintptr_t label_info_; | |||
| }; | |||
| class LabelManager { | |||
| public: | |||
| static std::shared_ptr<LabelManager> GetInstance(); | |||
| std::shared_ptr<LabelGuard> GetLabelInfo(rtModel_t model, const std::vector<uint32_t> &label_ids, | |||
| const std::vector<void *> &all_label); | |||
| private: | |||
| std::mutex model_info_mapping_mutex_; | |||
| std::map<rtModel_t, std::map<std::string, std::weak_ptr<LabelGuard>>> model_info_mapping_; | |||
| static std::weak_ptr<LabelManager> instance_; | |||
| static std::mutex instance_mutex_; | |||
| }; | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| #endif // GE_GE_RUNTIME_TASK_LABEL_MANAGER_H_ | |||
| @@ -15,7 +15,6 @@ | |||
| */ | |||
| #include "ge_runtime/task/label_switch_task.h" | |||
| #include <vector> | |||
| #include "ge_runtime/task/task_factory.h" | |||
| namespace ge { | |||
| @@ -25,14 +24,14 @@ LabelSwitchTask::LabelSwitchTask(const ModelContext &model_context, | |||
| : TaskRepeater<LabelSwitchTaskInfo>(model_context, task_info), | |||
| task_info_(task_info), | |||
| stream_(nullptr), | |||
| all_label_resource_(), | |||
| label_info_(nullptr) { | |||
| if (task_info_ == nullptr) { | |||
| GELOGW("task_info_ is null!"); | |||
| return; | |||
| } | |||
| all_label_resource_ = model_context.label_list(); | |||
| rt_model_handle_ = model_context.rt_model_handle(); | |||
| auto all_label_resource = model_context.label_list(); | |||
| auto stream_list = model_context.stream_list(); | |||
| uint32_t stream_id = task_info->stream_id(); | |||
| GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); | |||
| @@ -41,31 +40,24 @@ LabelSwitchTask::LabelSwitchTask(const ModelContext &model_context, | |||
| return; | |||
| } | |||
| stream_ = stream_list[stream_id]; | |||
| CopyLabelList(); | |||
| } | |||
| LabelSwitchTask::~LabelSwitchTask() { | |||
| if (label_info_ != nullptr) { | |||
| rtError_t rt_ret = rtFree(label_info_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "rtFree fwkOpBuf failed! ret: 0x%X.", rt_ret); | |||
| } | |||
| label_info_ = nullptr; | |||
| label_manager_ = LabelManager::GetInstance(); | |||
| if (label_manager_ == nullptr) { | |||
| GELOGW("Get label manager instance failed."); | |||
| return; | |||
| } | |||
| label_info_ = label_manager_->GetLabelInfo(rt_model_handle_, task_info_->label_list(), all_label_resource); | |||
| } | |||
| LabelSwitchTask::~LabelSwitchTask() {} | |||
| bool LabelSwitchTask::Distribute() { | |||
| GELOGI("LabelSwitchTask Distribute start."); | |||
| if (!CheckParamValid()) { | |||
| return false; | |||
| } | |||
| if (label_info_ == nullptr) { | |||
| GELOGE(PARAM_INVALID, "CopyLabelList failed, label info is null."); | |||
| return false; | |||
| } | |||
| rtError_t rt_ret = rtLabelSwitchByIndex(task_info_->cond(), task_info_->label_size(), label_info_, stream_); | |||
| void *label_info = label_info_->GetLabelInfo(); | |||
| rtError_t rt_ret = rtLabelSwitchByIndex(task_info_->cond(), task_info_->label_size(), label_info, stream_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
| return false; | |||
| @@ -97,48 +89,14 @@ bool LabelSwitchTask::CheckParamValid() { | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| void LabelSwitchTask::CopyLabelList() { | |||
| if (!CheckParamValid()) { | |||
| return; | |||
| } | |||
| if (label_info_ != nullptr) { | |||
| GELOGE(PARAM_INVALID, "label_info_ has dirty data."); | |||
| return; | |||
| } | |||
| const std::vector<uint32_t> &label_index_list = task_info_->label_list(); | |||
| std::vector<void *> label_list(task_info_->label_size(), nullptr); | |||
| for (size_t i = 0; i < task_info_->label_size(); ++i) { | |||
| uint32_t label_index = label_index_list[i]; | |||
| if (label_index >= all_label_resource_.size()) { | |||
| GELOGE(PARAM_INVALID, "label %zu index is %u, but there are %zu labels in total.", i, label_index, | |||
| all_label_resource_.size()); | |||
| return; | |||
| } | |||
| label_list[i] = all_label_resource_[label_index]; | |||
| GELOGI("Case %zu: label id %zu.", i, label_index); | |||
| } | |||
| uint32_t label_info_size = sizeof(rtLabelDevInfo) * task_info_->label_size(); | |||
| rtError_t rt_ret = rtMalloc(&label_info_, label_info_size, RT_MEMORY_HBM); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
| return; | |||
| if (label_info_ == nullptr) { | |||
| GELOGE(PARAM_INVALID, "CopyLabelList failed, label info is null."); | |||
| return false; | |||
| } | |||
| rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info_, label_info_size); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
| return; | |||
| } | |||
| return true; | |||
| } | |||
| REGISTER_TASK(TaskInfoType::LABEL_SWITCH, LabelSwitchTask, LabelSwitchTaskInfo); | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| @@ -19,6 +19,7 @@ | |||
| #include <memory> | |||
| #include "ge_runtime/task/task.h" | |||
| #include "ge_runtime/task/label_manager.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| @@ -32,12 +33,12 @@ class LabelSwitchTask : public TaskRepeater<LabelSwitchTaskInfo> { | |||
| private: | |||
| bool CheckParamValid(); | |||
| void CopyLabelList(); | |||
| std::shared_ptr<LabelSwitchTaskInfo> task_info_; | |||
| void *stream_; | |||
| std::vector<void *> all_label_resource_; | |||
| void *label_info_; | |||
| rtModel_t rt_model_handle_; | |||
| std::shared_ptr<LabelGuard> label_info_; | |||
| std::shared_ptr<LabelManager> label_manager_; | |||
| }; | |||
| } // namespace model_runner | |||
| } // namespace ge | |||