Browse Source

code_sync_0420

tags/v1.3.0
dingpeifei 3 years ago
parent
commit
4d0b6f9d9a
7 changed files with 242 additions and 107 deletions
  1. +1
    -0
      ge/ge_runtime/CMakeLists.txt
  2. +41
    -57
      ge/ge_runtime/task/label_goto_task.cc
  3. +10
    -6
      ge/ge_runtime/task/label_goto_task.h
  4. +119
    -0
      ge/ge_runtime/task/label_manager.cc
  5. +54
    -0
      ge/ge_runtime/task/label_manager.h
  6. +13
    -42
      ge/ge_runtime/task/label_switch_task.cc
  7. +4
    -2
      ge/ge_runtime/task/label_switch_task.h

+ 1
- 0
ge/ge_runtime/CMakeLists.txt View File

@@ -16,6 +16,7 @@ set(GE_SRC_LIST
"task/label_goto_task.cc"
"task/label_set_task.cc"
"task/label_switch_task.cc"
"task/label_manager.cc"
)

add_library(ge_runtime SHARED ${GE_SRC_LIST})


+ 41
- 57
ge/ge_runtime/task/label_goto_task.cc View File

@@ -16,99 +16,83 @@

#include "ge_runtime/task/label_goto_task.h"
#include "ge_runtime/task/task_factory.h"
#include "framework/common/util.h"

namespace ge {
namespace model_runner {
LabelGotoTask::LabelGotoTask(const ModelContext &model_context, const std::shared_ptr<LabelGotoTaskInfo> &task_info)
: TaskRepeater<LabelGotoTaskInfo>(model_context, task_info), task_info_(task_info) {
: TaskRepeater<LabelGotoTaskInfo>(model_context, task_info),
task_info_(task_info),
stream_(nullptr),
index_value_(nullptr) {
if (task_info_ == nullptr) {
GELOGW("task_info_ is null!");
return;
}
auto stream_list = model_context.stream_list();
auto label_list = model_context.label_list();
rt_model_handle_ = model_context.rt_model_handle();
uint32_t stream_id = task_info->stream_id();
uint32_t label_id = task_info->label_id();
label_id_ = task_info->label_id();
GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id);
GELOGI("Label list size:%zu, label id:%u.", label_list.size(), label_id);
if (stream_id >= stream_list.size() || label_id >= label_list.size()) {
GELOGI("Label list size:%zu, label id:%u.", label_list.size(), label_id_);
if (stream_id >= stream_list.size() || label_id_ >= label_list.size()) {
GELOGW("Stream/Label id invalid.");
return;
}
stream_ = stream_list[stream_id];
label_ = label_list[label_id];
label_manager_ = LabelManager::GetInstance();
if (label_manager_ == nullptr) {
GELOGW("Get label manager instance failed.");
return;
}
label_info_ = label_manager_->GetLabelInfo(rt_model_handle_, {label_id_}, label_list);
}

LabelGotoTask::~LabelGotoTask() {
GE_FREE_RT_LOG(label_info_);
GE_FREE_RT_LOG(index_value_);
if (index_value_ != nullptr) {
rtError_t rt_ret = rtFree(index_value_);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "rtFree index_value_ failed! ret: 0x%X.", rt_ret);
}
index_value_ = nullptr;
}
}

bool LabelGotoTask::Distribute() {
GELOGI("LabelGotoTask Distribute start.");
if (!CheckParamValid()) {
return false;
}

const std::vector<void *> label_list = { label_ };
rtError_t rt_ret = rtMalloc(&index_value_, sizeof(uint64_t), RT_MEMORY_HBM);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret);
return false;
}

uint64_t branch_index = 0;
rt_ret = rtMemcpy(index_value_, sizeof(uint64_t), &branch_index, sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret);
return false;
}

uint32_t label_info_size = sizeof(rtLabelDevInfo) * label_list.size();
rt_ret = rtMalloc(&label_info_, label_info_size, RT_MEMORY_HBM);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret);
return false;
}

rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info_, label_info_size);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret);
return false;
}

rt_ret = rtLabelSwitchByIndex(index_value_, label_list.size(), label_info_, stream_);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret);
return false;
}

GELOGI("DistributeTask end.");
return true;
}

bool LabelGotoTask::CheckParamValid() {
if (stream_ == nullptr) {
GELOGE(PARAM_INVALID, "stream is null!");
return false;
}

if (label_ == nullptr) {
GELOGE(PARAM_INVALID, "label is null!");
if (label_info_ == nullptr) {
GELOGE(PARAM_INVALID, "label info is null!");
return false;
}

if (label_info_ != nullptr) {
GELOGE(PARAM_INVALID, "label_info_ has dirty data.");
return false;
if (index_value_ == nullptr) {
rtError_t rt_ret = rtMalloc(&index_value_, sizeof(uint64_t), RT_MEMORY_HBM);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return false;
}

uint64_t index = 0;
rt_ret = rtMemcpy(index_value_, sizeof(uint64_t), &index, sizeof(index), RT_MEMCPY_HOST_TO_DEVICE);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return false;
}
}

if (index_value_ != nullptr) {
GELOGE(PARAM_INVALID, "index_value_ has dirty data.");
void *label_info = label_info_->GetLabelInfo();
rtError_t rt_ret = rtLabelSwitchByIndex(index_value_, 1, label_info, stream_);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return false;
}

GELOGI("DistributeTask end.");
return true;
}



+ 10
- 6
ge/ge_runtime/task/label_goto_task.h View File

@@ -18,7 +18,11 @@
#define GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_

#include <memory>
#include <vector>
#include <map>
#include <mutex>
#include "ge_runtime/task/task.h"
#include "ge_runtime/task/label_manager.h"

namespace ge {
namespace model_runner {
@@ -31,13 +35,13 @@ class LabelGotoTask : public TaskRepeater<LabelGotoTaskInfo> {
bool Distribute() override;

private:
bool CheckParamValid();

std::shared_ptr<LabelGotoTaskInfo> task_info_;
void *stream_{nullptr};
void *label_{nullptr};
void *label_info_{nullptr};
void *index_value_{nullptr};
void *stream_;
std::shared_ptr<LabelGuard> label_info_;
void *index_value_;
uint32_t label_id_;
rtModel_t rt_model_handle_;
std::shared_ptr<LabelManager> label_manager_;
};
} // namespace model_runner
} // namespace ge


+ 119
- 0
ge/ge_runtime/task/label_manager.cc View File

@@ -0,0 +1,119 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ge_runtime/task/label_manager.h"
#include <algorithm>
#include <string>
#include "runtime/mem.h"
#include "runtime/rt_model.h"
#include "common/ge_inner_error_codes.h"
#include "framework/common/debug/ge_log.h"

namespace ge {
namespace model_runner {
std::weak_ptr<LabelManager> LabelManager::instance_;
std::mutex LabelManager::instance_mutex_;

template <class T>
static std::string GetVectorString(const std::vector<T> &vec) {
std::string ret;
for (size_t i = 0; i < vec.size(); ++i) {
if (i != 0) {
ret.push_back(',');
}
ret += std::to_string(vec[i]);
}
return ret;
}

LabelGuard::~LabelGuard() {
void *label_info = GetLabelInfo();
if (label_info != nullptr) {
rtError_t rt_ret = rtFree(label_info);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "rtFree label_info failed! ret: 0x%X.", rt_ret);
}
}
}

std::shared_ptr<LabelManager> LabelManager::GetInstance() {
std::lock_guard<std::mutex> lock(instance_mutex_);
auto instance = instance_.lock();
if (instance != nullptr) {
return instance;
}

instance = std::make_shared<LabelManager>();
instance_ = instance;
return instance;
}

std::shared_ptr<LabelGuard> LabelManager::GetLabelInfo(rtModel_t model, const std::vector<uint32_t> &label_ids,
const std::vector<void *> &all_label) {
std::lock_guard<std::mutex> lock(model_info_mapping_mutex_);
rtError_t rt_ret;
auto model_iter = model_info_mapping_.find(model);
if (model_iter == model_info_mapping_.end()) {
model_info_mapping_.emplace(model, std::map<std::string, std::weak_ptr<LabelGuard>>());
model_iter = model_info_mapping_.find(model);
}

std::string label_id_str = GetVectorString(label_ids);
auto &label_map = model_iter->second;
auto label_iter = label_map.find(label_id_str);
if (label_iter != label_map.end()) {
auto label_guard = label_iter->second.lock();
if (label_guard != nullptr) {
GELOGI("model %p find same label id %s.", model, label_id_str.c_str());
return label_guard;
}
}

GELOGI("Alloc label id %s for model %p.", label_id_str.c_str(), model);
void *label_info;
std::vector<void *> label_list;
bool status = true;
std::transform(label_ids.begin(), label_ids.end(), std::back_inserter(label_list),
[&all_label, &status](uint32_t idx) -> void * {
if (idx >= all_label.size()) {
GELOGE(PARAM_INVALID, "Invalid label id %u, all label list size %zu.", idx, all_label.size());
status = false;
return nullptr;
}
return all_label[idx];
});
if (!status) {
GELOGE(PARAM_INVALID, "Get label info failed.");
return nullptr;
}
uint32_t label_info_size = sizeof(rtLabelDevInfo) * label_list.size();
rt_ret = rtMalloc(&label_info, label_info_size, RT_MEMORY_HBM);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return nullptr;
}

rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info, label_info_size);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return nullptr;
}

auto label_guard = std::make_shared<LabelGuard>(label_info);
label_map.emplace(label_id_str, label_guard);
return label_guard;
}
} // namespace model_runner
} // namespace ge

+ 54
- 0
ge/ge_runtime/task/label_manager.h View File

@@ -0,0 +1,54 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef GE_GE_RUNTIME_TASK_LABEL_MANAGER_H_
#define GE_GE_RUNTIME_TASK_LABEL_MANAGER_H_

#include <vector>
#include <memory>
#include <mutex>
#include <map>
#include <runtime/base.h>

namespace ge {
namespace model_runner {
class LabelGuard {
public:
explicit LabelGuard(void *label_info) : label_info_(reinterpret_cast<uintptr_t>(label_info)) {}
~LabelGuard();
void *GetLabelInfo() { return reinterpret_cast<void *>(label_info_); }

private:
uintptr_t label_info_;
};

class LabelManager {
public:
static std::shared_ptr<LabelManager> GetInstance();
std::shared_ptr<LabelGuard> GetLabelInfo(rtModel_t model, const std::vector<uint32_t> &label_ids,
const std::vector<void *> &all_label);

private:
std::mutex model_info_mapping_mutex_;
std::map<rtModel_t, std::map<std::string, std::weak_ptr<LabelGuard>>> model_info_mapping_;

static std::weak_ptr<LabelManager> instance_;
static std::mutex instance_mutex_;
};


} // namespace model_runner
} // namespace ge
#endif // GE_GE_RUNTIME_TASK_LABEL_MANAGER_H_

+ 13
- 42
ge/ge_runtime/task/label_switch_task.cc View File

@@ -24,14 +24,14 @@ LabelSwitchTask::LabelSwitchTask(const ModelContext &model_context,
: TaskRepeater<LabelSwitchTaskInfo>(model_context, task_info),
task_info_(task_info),
stream_(nullptr),
all_label_resource_(),
label_info_(nullptr) {
if (task_info_ == nullptr) {
GELOGW("task_info_ is null!");
return;
}

all_label_resource_ = model_context.label_list();
rt_model_handle_ = model_context.rt_model_handle();
auto all_label_resource = model_context.label_list();
auto stream_list = model_context.stream_list();
uint32_t stream_id = task_info->stream_id();
GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id);
@@ -40,52 +40,24 @@ LabelSwitchTask::LabelSwitchTask(const ModelContext &model_context,
return;
}
stream_ = stream_list[stream_id];
}

LabelSwitchTask::~LabelSwitchTask() {
if (label_info_ != nullptr) {
rtError_t rt_ret = rtFree(label_info_);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "rtFree fwkOpBuf failed! ret: 0x%X.", rt_ret);
}
label_info_ = nullptr;
label_manager_ = LabelManager::GetInstance();
if (label_manager_ == nullptr) {
GELOGW("Get label manager instance failed.");
return;
}
label_info_ = label_manager_->GetLabelInfo(rt_model_handle_, task_info_->label_list(), all_label_resource);
}

LabelSwitchTask::~LabelSwitchTask() {}

bool LabelSwitchTask::Distribute() {
GELOGI("LabelSwitchTask Distribute start.");
if (!CheckParamValid()) {
return false;
}

const std::vector<uint32_t> &label_index_list = task_info_->label_list();
std::vector<void *> label_list(task_info_->label_size(), nullptr);

for (size_t i = 0; i < task_info_->label_size(); ++i) {
uint32_t label_index = label_index_list[i];
if (label_index >= all_label_resource_.size()) {
GELOGE(PARAM_INVALID, "label %zu index is %u, but there are %zu labels in total.", i, label_index,
all_label_resource_.size());
return false;
}
label_list[i] = all_label_resource_[label_index];
GELOGI("Case %zu: label id %zu.", i, label_index);
}

uint32_t label_info_size = sizeof(rtLabelDevInfo) * task_info_->label_size();
rtError_t rt_ret = rtMalloc(&label_info_, label_info_size, RT_MEMORY_HBM);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return false;
}

rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info_, label_info_size);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return false;
}

rt_ret = rtLabelSwitchByIndex(task_info_->cond(), label_list.size(), label_info_, stream_);
void *label_info = label_info_->GetLabelInfo();
rtError_t rt_ret = rtLabelSwitchByIndex(task_info_->cond(), task_info_->label_size(), label_info, stream_);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return false;
@@ -117,8 +89,8 @@ bool LabelSwitchTask::CheckParamValid() {
return false;
}

if (label_info_ != nullptr) {
GELOGE(PARAM_INVALID, "label_info_ has dirty data.");
if (label_info_ == nullptr) {
GELOGE(PARAM_INVALID, "CopyLabelList failed, label info is null.");
return false;
}

@@ -126,6 +98,5 @@ bool LabelSwitchTask::CheckParamValid() {
}

REGISTER_TASK(TaskInfoType::LABEL_SWITCH, LabelSwitchTask, LabelSwitchTaskInfo);

} // namespace model_runner
} // namespace ge

+ 4
- 2
ge/ge_runtime/task/label_switch_task.h View File

@@ -19,6 +19,7 @@

#include <memory>
#include "ge_runtime/task/task.h"
#include "ge_runtime/task/label_manager.h"

namespace ge {
namespace model_runner {
@@ -35,8 +36,9 @@ class LabelSwitchTask : public TaskRepeater<LabelSwitchTaskInfo> {

std::shared_ptr<LabelSwitchTaskInfo> task_info_;
void *stream_;
std::vector<void *> all_label_resource_;
void *label_info_;
rtModel_t rt_model_handle_;
std::shared_ptr<LabelGuard> label_info_;
std::shared_ptr<LabelManager> label_manager_;
};
} // namespace model_runner
} // namespace ge


Loading…
Cancel
Save