Browse Source

remove ge_runtime

pull/1909/head
lujiale 4 years ago
parent
commit
d753fdf159
42 changed files with 0 additions and 3992 deletions
  1. +0
    -1
      ge/CMakeLists.txt
  2. +0
    -78
      ge/ge_runtime/CMakeLists.txt
  3. +0
    -62
      ge/ge_runtime/model_context.h
  4. +0
    -173
      ge/ge_runtime/model_runner.cc
  5. +0
    -66
      ge/ge_runtime/module.mk
  6. +0
    -94
      ge/ge_runtime/output.cc
  7. +0
    -53
      ge/ge_runtime/output.h
  8. +0
    -27
      ge/ge_runtime/proto/task.pb.h
  9. +0
    -547
      ge/ge_runtime/runtime_model.cc
  10. +0
    -92
      ge/ge_runtime/runtime_model.h
  11. +0
    -168
      ge/ge_runtime/task/aicpu_task.cc
  12. +0
    -50
      ge/ge_runtime/task/aicpu_task.h
  13. +0
    -160
      ge/ge_runtime/task/cce_task.cc
  14. +0
    -47
      ge/ge_runtime/task/cce_task.h
  15. +0
    -61
      ge/ge_runtime/task/event_record_task.cc
  16. +0
    -41
      ge/ge_runtime/task/event_record_task.h
  17. +0
    -67
      ge/ge_runtime/task/event_wait_task.cc
  18. +0
    -41
      ge/ge_runtime/task/event_wait_task.h
  19. +0
    -268
      ge/ge_runtime/task/hccl_task.cc
  20. +0
    -69
      ge/ge_runtime/task/hccl_task.h
  21. +0
    -117
      ge/ge_runtime/task/label_goto_task.cc
  22. +0
    -45
      ge/ge_runtime/task/label_goto_task.h
  23. +0
    -70
      ge/ge_runtime/task/label_set_task.cc
  24. +0
    -41
      ge/ge_runtime/task/label_set_task.h
  25. +0
    -131
      ge/ge_runtime/task/label_switch_task.cc
  26. +0
    -44
      ge/ge_runtime/task/label_switch_task.h
  27. +0
    -57
      ge/ge_runtime/task/memcpy_async_task.cc
  28. +0
    -40
      ge/ge_runtime/task/memcpy_async_task.h
  29. +0
    -55
      ge/ge_runtime/task/profiler_task.cc
  30. +0
    -40
      ge/ge_runtime/task/profiler_task.h
  31. +0
    -60
      ge/ge_runtime/task/stream_active_task.cc
  32. +0
    -41
      ge/ge_runtime/task/stream_active_task.h
  33. +0
    -82
      ge/ge_runtime/task/stream_switch_task.cc
  34. +0
    -43
      ge/ge_runtime/task/stream_switch_task.h
  35. +0
    -58
      ge/ge_runtime/task/task.h
  36. +0
    -87
      ge/ge_runtime/task/task_factory.h
  37. +0
    -112
      ge/ge_runtime/task/tbe_task.cc
  38. +0
    -46
      ge/ge_runtime/task/tbe_task.h
  39. +0
    -113
      inc/framework/ge_runtime/davinci_model.h
  40. +0
    -68
      inc/framework/ge_runtime/model_runner.h
  41. +0
    -72
      inc/framework/ge_runtime/op_info.h
  42. +0
    -405
      inc/framework/ge_runtime/task_info.h

+ 0
- 1
ge/CMakeLists.txt View File

@@ -6,7 +6,6 @@ if (NOT ENABLE_D AND NOT ENABLE_ACL AND NOT ENABLE_MS_TESTCASES)
add_subdirectory(offline) add_subdirectory(offline)
elseif (ENABLE_D) elseif (ENABLE_D)
add_subdirectory(common) add_subdirectory(common)
add_subdirectory(ge_runtime)
endif () endif ()


set(GRAPHENGINE_PROTO_LIST set(GRAPHENGINE_PROTO_LIST


+ 0
- 78
ge/ge_runtime/CMakeLists.txt View File

@@ -1,78 +0,0 @@
############ libge_runtime.so ############
set(GE_SRC_LIST
"model_runner.cc"
"runtime_model.cc"
"output.cc"
"task/aicpu_task.cc"
"task/cce_task.cc"
"task/tbe_task.cc"
"task/event_record_task.cc"
"task/event_wait_task.cc"
"task/stream_active_task.cc"
"task/stream_switch_task.cc"
"task/hccl_task.cc"
"task/memcpy_async_task.cc"
"task/profiler_task.cc"
"task/label_goto_task.cc"
"task/label_set_task.cc"
"task/label_switch_task.cc"
)

add_library(ge_runtime SHARED ${GE_SRC_LIST})

target_compile_options(ge_runtime PRIVATE
-Werror
-O2
-Wno-deprecated-declarations
-fno-common
)

target_compile_definitions(ge_runtime PRIVATE
PROTOBUF_INLINE_NOT_IN_HEADERS=0
LOG_CPP
)

target_include_directories(ge_runtime PRIVATE
${CMAKE_CURRENT_LIST_DIR}
${GE_CODE_DIR}
${GE_CODE_DIR}/ge
${GE_CODE_DIR}/inc
${GE_CODE_DIR}/inc/graph
${GE_CODE_DIR}/inc/external
${GE_CODE_DIR}/inc/framework
${GE_CODE_DIR}/inc/framework/common
${GE_CODE_DIR}/inc/framework/ge_runtime
${GE_CODE_DIR}/inc/cce
${GE_CODE_DIR}/third_party/fwkacllib/inc
${METADEF_DIR}
${METADEF_DIR}/inc
${METADEF_DIR}/inc/external/graph
${METADEF_DIR}/inc/external
${METADEF_DIR}/inc/graph
${CMAKE_BINARY_DIR}
${CMAKE_BINARY_DIR}/proto/ge
)

target_link_options(ge_runtime PRIVATE
-Wl,-Bsymbolic
)

target_link_libraries(ge_runtime PRIVATE
$<BUILD_INTERFACE:intf_pub>
-Wl,--no-as-needed
slog
runtime
c_sec
graph
-Wl,--as-needed
-lrt
-ldl
)

############ install ############
set(INSTALL_BASE_DIR "")
set(INSTALL_LIBRARY_DIR lib)

install(TARGETS ge_runtime OPTIONAL
LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR}
)

+ 0
- 62
ge/ge_runtime/model_context.h View File

@@ -1,62 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_GE_RUNTIME_MODEL_CONTEXT_H_
#define GE_GE_RUNTIME_MODEL_CONTEXT_H_

#include <vector>
#include "runtime/rt_model.h"

namespace ge {
namespace model_runner {
class ModelContext {
public:
ModelContext(uint32_t device_id, uint64_t session_id, int32_t priority, rtModel_t rt_model_handle,
rtStream_t rt_model_stream, const std::vector<rtStream_t> &stream_list,
const std::vector<rtLabel_t> &label_list, const std::vector<rtEvent_t> &event_list)
: device_id_(device_id),
session_id_(session_id),
priority_(priority),
rt_model_handle_(rt_model_handle),
rt_model_stream_(rt_model_stream),
stream_list_(stream_list),
label_list_(label_list),
event_list_(event_list) {}
~ModelContext() {}

uint64_t device_id() const { return device_id_; }
uint64_t session_id() const { return session_id_; }
int32_t priority() const { return priority_; }
const rtModel_t &rt_model_handle() const { return rt_model_handle_; }
const rtStream_t &rt_model_stream() const { return rt_model_stream_; }
const std::vector<rtStream_t> &stream_list() const { return stream_list_; }
const std::vector<rtLabel_t> &label_list() const { return label_list_; }
const std::vector<rtEvent_t> &event_list() const { return event_list_; }

private:
uint32_t device_id_;
uint64_t session_id_;
int32_t priority_;
rtModel_t rt_model_handle_;
rtStream_t rt_model_stream_;
std::vector<rtStream_t> stream_list_;
std::vector<rtLabel_t> label_list_;
std::vector<rtEvent_t> event_list_;
};
} // namespace model_runner
} // namespace ge

#endif // GE_GE_RUNTIME_MODEL_CONTEXT_H_

+ 0
- 173
ge/ge_runtime/model_runner.cc View File

@@ -1,173 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ge_runtime/model_runner.h"
#include "./runtime_model.h"
#include "framework/common/debug/ge_log.h"
#include "common/ge_inner_error_codes.h"
#include "common/ge/ge_util.h"
#include "ge_runtime/davinci_model.h"
#include "graph/op_desc.h"

namespace ge {
namespace model_runner {

using RuntimeModelPtr = std::shared_ptr<RuntimeModel>;
using DavinciModelPtr = std::shared_ptr<DavinciModel>;

ModelRunner &ModelRunner::Instance() {
static ModelRunner instance; // Guaranteed to be destroyed.
return instance;
}

bool ModelRunner::LoadDavinciModel(uint32_t device_id, uint64_t session_id, uint32_t model_id,
std::shared_ptr<DavinciModel> davinci_model,
std::shared_ptr<ModelListener> listener) {
std::shared_ptr<RuntimeModel> model = MakeShared<RuntimeModel>();
if (model == nullptr) {
return false;
}
bool status = model->Load(device_id, session_id, davinci_model);
if (!status) {
return false;
}

runtime_models_[model_id] = model;
return true;
}

bool ModelRunner::DistributeTask(uint32_t model_id) {
auto model_iter = runtime_models_.find(model_id);
if (model_iter == runtime_models_.end()) {
GELOGE(PARAM_INVALID, "Model id %u not found.", model_id);
return false;
}
return model_iter->second->DistributeTask();
}

bool ModelRunner::LoadModelComplete(uint32_t model_id) {
auto model_iter = runtime_models_.find(model_id);
if (model_iter == runtime_models_.end()) {
GELOGE(PARAM_INVALID, "Model id %u not found.", model_id);
return false;
}
return model_iter->second->LoadComplete();
}

const std::vector<uint32_t> &ModelRunner::GetTaskIdList(uint32_t model_id) const {
auto model_iter = runtime_models_.find(model_id);
if (model_iter == runtime_models_.end()) {
GELOGE(PARAM_INVALID, "Model id %u not found.", model_id);
static const std::vector<uint32_t> empty_ret;
return empty_ret;
}

return model_iter->second->GetTaskIdList();
}

const std::vector<uint32_t> &ModelRunner::GetStreamIdList(uint32_t model_id) const {
auto model_iter = runtime_models_.find(model_id);
if (model_iter == runtime_models_.end()) {
GELOGE(PARAM_INVALID, "Model id %u not found.", model_id);
static const std::vector<uint32_t> empty_ret;
return empty_ret;
}

return model_iter->second->GetStreamIdList();
}

const std::map<std::string, std::shared_ptr<RuntimeInfo>> &ModelRunner::GetRuntimeInfoMap(uint32_t model_id) const {
auto model_iter = runtime_models_.find(model_id);
if (model_iter == runtime_models_.end()) {
GELOGW("Model id %u not found.", model_id);
static const std::map<std::string, std::shared_ptr<RuntimeInfo>> empty_ret;
return empty_ret;
}

return model_iter->second->GetRuntimeInfoMap();
}

void *ModelRunner::GetModelHandle(uint32_t model_id) const {
auto model_iter = runtime_models_.find(model_id);
if (model_iter == runtime_models_.end()) {
GELOGW("Model id %u not found.", model_id);
return nullptr;
}

return model_iter->second->GetModelHandle();
}

bool ModelRunner::UnloadModel(uint32_t model_id) {
auto iter = runtime_models_.find(model_id);
if (iter != runtime_models_.end()) {
(void)runtime_models_.erase(iter);
return true;
}

return false;
}

bool ModelRunner::RunModel(uint32_t model_id, const InputData &input_data, OutputData *output_data) {
if (output_data == nullptr) {
GELOGW("Output data point is null.");
}

auto model_iter = runtime_models_.find(model_id);
if (model_iter == runtime_models_.end()) {
GELOGE(PARAM_INVALID, "Model id %u not found.", model_id);
return false;
}

bool status = model_iter->second->CopyInputData(input_data);
if (!status) {
GELOGE(FAILED, "Copy input data fail.");
return false;
}

status = model_iter->second->Run();
if (!status) {
GELOGE(FAILED, "Run model fail.");
return false;
}

return true;
}

bool ModelRunner::GetInputOutputDescInfo(uint32_t model_id, bool zero_copy,
std::vector<InputOutputDescInfo> *input_desc,
std::vector<InputOutputDescInfo> *output_desc,
std::vector<uint32_t> *input_format, std::vector<uint32_t> *output_format) {
if (runtime_models_.find(model_id) == runtime_models_.end()) {
GELOGE(PARAM_INVALID, "Model id %u not found.", model_id);
return false;
}

auto model = runtime_models_[model_id];
if (input_desc == nullptr || output_desc == nullptr) {
GELOGE(PARAM_INVALID, "input_desc or output_desc is null.");
return false;
}

bool status = model->GetInputOutputDescInfo(zero_copy, input_desc, output_desc, input_format, output_format);
if (!status) {
GELOGE(FAILED, "Get input output desc info fail.");
return false;
}

return true;
}
} // namespace model_runner
} // namespace ge

+ 0
- 66
ge/ge_runtime/module.mk View File

@@ -1,66 +0,0 @@
LOCAL_PATH := $(call my-dir)

# task.proto is old task, add it for ops_kernel_info_store
local_ge_runtime_src_files := \
model_runner.cc \
runtime_model.cc \
output.cc \
task/aicpu_task.cc \
task/cce_task.cc \
task/tbe_task.cc \
task/event_record_task.cc \
task/event_wait_task.cc \
task/stream_active_task.cc \
task/stream_switch_task.cc \
task/hccl_task.cc \
task/memcpy_async_task.cc \
task/profiler_task.cc \

local_ge_runtime_include := \
$(LOCAL_PATH)/ \
$(TOPDIR)libc_sec/include \
$(TOPDIR)inc/external \
$(TOPDIR)inc/external/graph \
$(TOPDIR)inc/framework \
$(TOPDIR)inc/graph \
$(TOPDIR)inc \
$(LOCAL_PATH)/../ \
third_party/protobuf/include

local_ge_runtime_shared_library := \
libruntime \
libslog \
libc_sec

local_ge_runtime_ldflags := -lrt -ldl

# compile device libge_runtime
include $(CLEAR_VARS)

LOCAL_MODULE := libge_runtime
LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -O2
LOCAL_CFLAGS += -Werror
LOCAL_SRC_FILES := $(local_ge_runtime_src_files)
LOCAL_C_INCLUDES := $(local_ge_runtime_include)
LOCAL_SHARED_LIBRARIES := $(local_ge_runtime_shared_library)
LOCAL_LDFLAGS += $(local_ge_runtime_ldflags)

include $(BUILD_SHARED_LIBRARY)

# compile host libge_runtime
include $(CLEAR_VARS)

LOCAL_MODULE := libge_runtime
LOCAL_CFLAGS += -Werror
LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0
ifeq ($(DEBUG), 1)
LOCAL_CFLAGS += -g -O0
else
LOCAL_CFLAGS += -O2
endif
LOCAL_SRC_FILES := $(local_ge_runtime_src_files)
LOCAL_C_INCLUDES := $(local_ge_runtime_include)
LOCAL_SHARED_LIBRARIES := $(local_ge_runtime_shared_library)
LOCAL_LDFLAGS += $(local_ge_runtime_ldflags)

include $(BUILD_HOST_SHARED_LIBRARY)

+ 0
- 94
ge/ge_runtime/output.cc View File

@@ -1,94 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ge_runtime/output.h"
#include "common/ge_inner_error_codes.h"
#include "common/util.h"
#include "framework/common/debug/ge_log.h"

namespace ge {
namespace model_runner {
Output::Output(const OpInfoPtr &op_info, const std::shared_ptr<DavinciModel> &model)
: model_(model), op_info_(op_info), input_num_(0) {}

Output::~Output() {}

bool Output::Init() {
if (op_info_ == nullptr || model_ == nullptr) {
GELOGE(INTERNAL_ERROR, "The op_desc_ or model_ is nullptr.");
return false;
}

input_num_ = op_info_->input_tensors.size();
v_input_size_.clear();
v_input_data_addr_.clear();

auto input_vector = op_info_->input_addrs;
if (input_num_ != input_vector.size()) {
GELOGE(INTERNAL_ERROR, "The input desc size: %zu != input addr size: %zu.", input_num_, input_vector.size());
return false;
}

for (size_t i = 0; i < input_num_; i++) {
uint32_t tensorSize = 0;
const auto &input_info = op_info_->input_tensors.at(i);
tensorSize = input_info.size;
v_input_size_.push_back(tensorSize);
v_input_data_addr_.push_back(reinterpret_cast<uint8_t *>(input_vector.at(i)));
}

GELOGI("Init output:%zu, %zu, %zu", input_num_, v_input_size_.size(), v_input_data_addr_.size());

return true;
}

///
/// @ingroup domi_ome
/// @brief Copy Op Output to user space.
/// @brief when model running, Add one DataOp as input node, Add one Output Op as output node.
/// @return Status
///
bool Output::CopyRslt(OutputData *rslt, uint32_t data_begin, uint32_t &data_index, bool support_mem_share) {
if (rslt == nullptr) {
GELOGE(FAILED, "OutputData is null.");
return false;
}
uint32_t data_count = 0;
if (v_input_size_.empty() || v_input_data_addr_.empty()) {
GELOGE(INTERNAL_ERROR, "v_output_size_ or v_output_data_addr_ is empty!");
return false;
}

for (size_t i = 0; i < input_num_; i++) {
DataBuffer data_buf = rslt->blobs[data_begin + data_count];
bool ret = SetDataBuf(data_buf, data_begin, data_count, i, support_mem_share);
if (!ret) {
GELOGE(FAILED, "Copy data to host failed. index: %lu, addr: %p", i, v_input_data_addr_[i]);
return ret;
}
data_index = data_begin + data_count;
}

return true;
}

bool Output::SetDataBuf(DataBuffer &data_buf, uint32_t data_begin, uint32_t &data_count, size_t i,
bool support_mem_share) {
return true;
}

} // namespace model_runner
} // namespace ge

+ 0
- 53
ge/ge_runtime/output.h View File

@@ -1,53 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_GE_RUNTIME_OUTPUT_H_
#define GE_GE_RUNTIME_OUTPUT_H_

#include <memory>
#include <vector>
#include "ge_runtime/davinci_model.h"
#include "common/ge_types.h"

namespace ge {
namespace model_runner {

class Output {
public:
Output(const OpInfoPtr &op_info, const std::shared_ptr<DavinciModel> &model);
virtual ~Output();
bool Init();

bool CopyRslt(OutputData *rslt, uint32_t data_begin, uint32_t &data_index, bool support_mem_share);

bool SetDataBuf(DataBuffer &data_buf, uint32_t data_begin, uint32_t &data_count, size_t i, bool support_mem_share);

// Copy assignment operator and copy constructor are deleted
Output &operator=(const Output &output) = delete;
Output(const Output &output) = delete;

protected:
std::shared_ptr<DavinciModel> model_;
OpInfoPtr op_info_;

// Input descriptions
size_t input_num_;
vector<void *> v_input_data_addr_; // Init as:buf_base + op_def_->input(i));
vector<uint32_t> v_input_size_;
};
} // namespace model_runner
} // namespace ge
#endif // GE_GE_RUNTIME_OUTPUT_H_

+ 0
- 27
ge/ge_runtime/proto/task.pb.h View File

@@ -1,27 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// Generated by the protocol buffer compiler. DO NOT EDIT!
// source: task.proto

#ifndef STUB_TASK_PROTO_H
#define STUB_TASK_PROTO_H

namespace domi {
class TaskDef;
}

#endif // STUB_TASK_PROTO_H

+ 0
- 547
ge/ge_runtime/runtime_model.cc View File

@@ -1,547 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ge_runtime/runtime_model.h"
#include <set>
#include "./model_context.h"
#include "./task/task.h"
#include "common/ge_inner_error_codes.h"
#include "common/types.h"
#include "common/util.h"
#include "framework/common/debug/ge_log.h"
#include "framework/common/op/op_parser_util.h"
#include "graph/types.h"
#include "task/task_factory.h"
#include "ge/common/math/math_util.h"

namespace ge {
namespace model_runner {
namespace {
const int kOffsetUnit = 8;
const uint32_t kStringHeadElems = 2;
} // namespace
RuntimeModel::~RuntimeModel() {
GELOGI("RuntimeModel destructor start");

// Unbind rtModel from all task related streams
RtModelUnbindStream();

// Release task first, hccl task hold stream
task_list_.clear();

// Release all task related streams
RtStreamDestory();

// Release rtlabel resource
RtLabelDestory();

// Release rtEvent resourece
RtEventDestory();

GELOGI("Do RtModelDestory");
// Release all rt_model
RtModelDestory();
}

bool RuntimeModel::InitStream(std::shared_ptr<DavinciModel> &davinci_model) {
if (davinci_model == nullptr) {
GELOGE(PARAM_INVALID, "Davinci model is null.");
return false;
}

std::set<int64_t> wait_active_streams;
std::set<int64_t> force_copy_streams;

for (const auto &stream_id : davinci_model->GetWaitActiveStreams()) {
GELOGI("stream id %u is wait active stream.", stream_id);
(void)wait_active_streams.insert(stream_id);
}

for (const auto &stream_id : davinci_model->GetForceCopyStreams()) {
GELOGI("stream id %u is force copy stream.", stream_id);
(void)force_copy_streams.insert(stream_id);
}

GELOGI("stream number:%u", davinci_model->GetStreamNum());
for (uint32_t i = 0; i < davinci_model->GetStreamNum(); ++i) {
rtStream_t stream = nullptr;
uint32_t flag = (force_copy_streams.find(i) != force_copy_streams.end())
? (RT_STREAM_PERSISTENT | RT_STREAM_FORCE_COPY)
: (RT_STREAM_PERSISTENT);

rtError_t rt_ret = rtStreamCreateWithFlags(&stream, davinci_model->GetPriority(), flag);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api rtStreamCreate failed, ret: 0x%X", rt_ret);
return false;
}

GELOGI("rtStreamCreateWithFlags end.");

stream_list_.emplace_back(stream);

// Bind rt_model_handle_ to all task related streams
flag = (wait_active_streams.find(i) != wait_active_streams.end()) ? (static_cast<uint32_t>(RT_INVALID_FLAG))
: (static_cast<uint32_t>(RT_HEAD_STREAM));
rt_ret = rtModelBindStream(rt_model_handle_, stream, flag);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api rtModelBindStream failed, ret: 0x%X", rt_ret);
return false;
}
GELOGI("stream index:%u, stream:%p.", i, stream);
}

return true;
}

bool RuntimeModel::InitEvent(uint32_t event_num) {
GELOGI("event number:%u.", event_num);
for (uint32_t i = 0; i < event_num; ++i) {
rtEvent_t rt_event;
rtError_t rt_ret = rtEventCreate(&rt_event);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api rtEventCreate failed, i; %u; ret: 0x%X", i, rt_ret);
return false;
}
event_list_.push_back(rt_event);
}
return true;
}

bool RuntimeModel::InitLabel(std::shared_ptr<DavinciModel> &davinci_model) {
GELOGI("batch number:%u.", davinci_model->GetBatchNum());
label_list_.resize(davinci_model->GetBatchNum());
for (auto &task_info : davinci_model->GetTaskInfoList()) {
if (task_info == nullptr) {
GELOGE(PARAM_INVALID, "task_info is null.");
continue;
}

if (task_info->type() != TaskInfoType::LABEL_SET) {
continue;
}
auto label_set_task_info = std::static_pointer_cast<LabelSetTaskInfo>(task_info);

if (label_set_task_info->stream_id() >= stream_list_.size()) {
GELOGE(PARAM_INVALID, "Invalid stream id.");
return false;
}

rtLabel_t rt_label = nullptr;
rtError_t rt_ret = rtLabelCreateEx(&rt_label, stream_list_[label_set_task_info->stream_id()]);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api rtLabelCreate failed, ret: 0x%X", rt_ret);
return false;
}
label_list_[label_set_task_info->label_id()] = rt_label;
}

return true;
}

bool RuntimeModel::InitResource(std::shared_ptr<DavinciModel> &davinci_model) {
GELOGI("InitResource start");
if (davinci_model == nullptr) {
GELOGE(PARAM_INVALID, "davinci model is null");
return false;
}
rtError_t rt_ret = rtModelCreate(&rt_model_handle_, 0);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api rtModelCreate failed, ret: 0x%X", rt_ret);
return false;
}

// Create rtStream for rt_model_handle_
rt_ret = rtStreamCreate(&rt_model_stream_, davinci_model->GetPriority());
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api rtStreamCreate failed, ret: 0x%X", rt_ret);
return false;
}
GELOGI("rtStreamCreate end");

if (!InitStream(davinci_model)) {
return false;
}

if (!InitEvent(davinci_model->GetEventNum())) {
return false;
}

if (!InitLabel(davinci_model)) {
return false;
}

GELOGI("InitResource succ");
return true;
}

void RuntimeModel::GenerateTask(uint32_t device_id, uint64_t session_id, std::shared_ptr<DavinciModel> &davinci_model) {
GELOGI("GenerateTask start.");
if (davinci_model == nullptr) {
GELOGE(PARAM_INVALID, "davinci model is null");
return;
}
auto task_infos = davinci_model->GetTaskInfoList();
ModelContext model_context(device_id, session_id, davinci_model->GetPriority(), rt_model_handle_, rt_model_stream_,
stream_list_, label_list_, event_list_);
for (auto &task_info : task_infos) {
auto task = TaskFactory::GetInstance().Create(model_context, task_info);
task_list_.push_back(task);
}
GELOGI("GenerateTask succ.");
}

bool RuntimeModel::LoadTask() {
GELOGI("LoadTask start.");
for (auto &task : task_list_) {
if (task == nullptr) {
GELOGE(PARAM_INVALID, "task is null.");
continue;
}
bool ret = task->Distribute();
if (!ret) {
GELOGE(FAILED, "task distribute fail.");
return false;
}

uint32_t task_id = 0;
uint32_t stream_id = 0;
rtError_t rt_ret = rtModelGetTaskId(rt_model_handle_, &task_id, &stream_id);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X.", rt_ret);
return false;
}
task_id_list_.push_back(task_id);
stream_id_list_.push_back(stream_id);
if (task->Args() != nullptr) {
std::shared_ptr<RuntimeInfo> runtime_tuple = nullptr;
GE_MAKE_SHARED(runtime_tuple = std::make_shared<RuntimeInfo>(task_id, stream_id, task->Args()), return false);
auto emplace_ret = runtime_info_map_.emplace(task->task_name(), runtime_tuple);
if (!emplace_ret.second) {
GELOGW("Task name exist:%s", task->task_name().c_str());
}
}
}
if (task_list_.empty()) {
GELOGE(FAILED, "Task list is empty");
return false;
}

GELOGI("LoadTask succ.");
return true;
}

bool RuntimeModel::LoadComplete() {
uint32_t task_id = 0;
uint32_t stream_id = 0;
auto rt_ret = rtModelGetTaskId(rt_model_handle_, &task_id, &stream_id);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rtModelGetTaskId failed, ret:0x%X", rt_ret);
return RT_FAILED;
}
task_id_list_.push_back(task_id);
stream_id_list_.push_back(stream_id);

rt_ret = rtModelLoadComplete(rt_model_handle_);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api rtModelLoadComplete failed, ret: 0x%X.", rt_ret);
return false;
}
return true;
}

bool RuntimeModel::Load(uint32_t device_id, uint64_t session_id, std::shared_ptr<DavinciModel> &davinci_model) {
bool status = InitResource(davinci_model);
if (!status) {
GELOGE(FAILED, "InitResource failed.");
return status;
}

status = InitDataInfo(davinci_model);
if (!status) {
GELOGE(FAILED, "InitDataInfo failed.");
return status;
}

status = InitOutputInfo(davinci_model);
if (!status) {
GELOGE(FAILED, "InitOutputInfo failed.");
return status;
}

status = InitConstantInfo(davinci_model);
if (!status) {
GELOGE(FAILED, "InitConstantInfo failed.");
return status;
}

GenerateTask(device_id, session_id, davinci_model);
return status;
}

bool RuntimeModel::DistributeTask() {
bool status = LoadTask();
if (!status) {
GELOGE(FAILED, "DistributeTask failed");
return false;
}
return true;
}

bool RuntimeModel::Run() {
GELOGI("Davinci task run start");
rtError_t ret = rtModelExecute(rt_model_handle_, rt_model_stream_, 0);
if (ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Model execute failed, ret = 0x%X", ret);
return false;
}

GELOGI("Run rtModelExecute success, ret = 0x%X", ret);

ret = rtStreamSynchronize(rt_model_stream_);
if (ret != RT_ERROR_NONE) {
if (ret == ACL_ERROR_RT_END_OF_SEQUENCE) {
GELOGI("Model stream ACL_ERROR_RT_END_OF_SEQUENCE signal received, ret = 0x%X", ret);
return true;
}
GELOGE(RT_FAILED, "Model stream sync failed, ret = 0x%X", ret);
return false;
}

GELOGI("Davinci task run succ.");
return true;
}

void RuntimeModel::RtModelUnbindStream() noexcept {
for (size_t i = 0; i < stream_list_.size(); i++) {
if (rtModelUnbindStream(rt_model_handle_, stream_list_[i]) != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Unbind stream from model failed! Index: %zu", i);
return;
}
}
}

void RuntimeModel::RtStreamDestory() noexcept {
if (rtStreamDestroy(rt_model_stream_) != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Destroy stream for rt_model failed!");
return;
}

for (size_t i = 0; i < stream_list_.size(); i++) {
if (rtStreamDestroy(stream_list_[i]) != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Destroy stream failed! Index: %zu", i);
return;
}
}
}

void RuntimeModel::RtLabelDestory() noexcept {
for (size_t i = 0; i < label_list_.size(); i++) {
if (label_list_[i] == nullptr) {
continue;
}
if (rtLabelDestroy(label_list_[i]) != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Destroy label failed! Index: %zu.", i);
return;
}
}
}

void RuntimeModel::RtModelDestory() noexcept {
rtError_t ret = rtModelDestroy(rt_model_handle_);
if (ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", ret);
return;
}
}

void RuntimeModel::RtEventDestory() noexcept {
for (size_t i = 0; i < event_list_.size(); i++) {
if (rtEventDestroy(event_list_[i]) != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Destroy event failed! Index: %zu", i);
return;
}
}
}

bool RuntimeModel::InitDataInfo(std::shared_ptr<DavinciModel> &davinci_model) { return true; }

bool RuntimeModel::InitOutputInfo(std::shared_ptr<DavinciModel> &davinci_model) {
if (davinci_model == nullptr) {
GELOGE(PARAM_INVALID, "davinci model is null");
return false;
}
output_info_list_ = davinci_model->GetOutputInfoList();
return true;
}

bool RuntimeModel::CopyInputData(const InputData &input_data) {
if (input_data.blobs.size() != data_info_list_.size()) {
GELOGE(PARAM_INVALID, "The input data list size (%zu) does not match the model input list size (%zu)",
input_data.blobs.size(), data_info_list_.size());
return false;
}

for (const auto &data_info : data_info_list_) {
if (data_info == nullptr) {
GELOGE(PARAM_INVALID, "data info is null.");
return false;
}

bool ret = CopyInputDataToModel(input_data.blobs, data_info);
if (!ret) {
GELOGE(FAILED, "Copy input data to model ret fail, data_info: %s, model id: %u", data_info->name.c_str(),
input_data.model_id);
return false;
}
}

return true;
}

bool RuntimeModel::CopyInputDataToModel(const std::vector<DataBuffer> &data, const std::shared_ptr<OpInfo> &data_info) {
return true;
}

bool RuntimeModel::CopyHostData(const std::vector<DataBuffer> &data, const std::shared_ptr<OpInfo> &data_info) const {
GELOGI("Start CopyHostData.");
if (data.empty()) {
GELOGE(PARAM_INVALID, "data buffer is empty.");
return false;
}

if (data_info == nullptr) {
GELOGE(PARAM_INVALID, "data info is null.");
return false;
}

void *host_data_addr = data[data_info->index].data;
uint32_t copy_size = data[data_info->index].length;
GELOGD("data output tensor is aipp tensor,copy data only.");

const std::vector<uintptr_t> &outputs = data_info->output_addrs;
if (outputs.empty()) {
GELOGE(PARAM_INVALID, "Output addrs is empty.");
return false;
}

// Copy input data to data nodes
void *data_out_addr = reinterpret_cast<void *>(outputs[0]);

rtError_t rt_ret = rtMemcpy(data_out_addr, copy_size, host_data_addr, copy_size, RT_MEMCPY_HOST_TO_DEVICE);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return false;
}

return true;
}

bool RuntimeModel::CopyTransData(const std::vector<DataBuffer> &data, const std::shared_ptr<OpInfo> &data_info) {
return true;
}

// Initializes const-op outputs: copies each constant's host weight data into
// the device memory backing its single output. For DT_STRING weights the
// offset header embedded in the weight bytes is rebased in place from host
// offsets to absolute device addresses before the copy.
// Returns false on invalid model data or any runtime-API failure.
bool RuntimeModel::InitConstantInfo(std::shared_ptr<DavinciModel> &davinci_model) {
  // Const has no input and only 1 output, and that output carries no data;
  // the weight data is copied into the output's memory instead.
  if (davinci_model == nullptr) {
    GELOGE(PARAM_INVALID, "Davinci model is null.");
    return false;
  }
  constant_info_list_ = davinci_model->GetConstantInfoList();

  for (const auto &constant : constant_info_list_) {
    if (constant == nullptr) {
      // Best-effort: log and skip a missing entry instead of failing the load
      // (preserves the original behavior).
      GELOGE(PARAM_INVALID, "constant is null");
      continue;
    }
    if (constant->output_tensors.empty()) {
      GELOGE(PARAM_INVALID, "Output tensors is empty");
      return false;
    }

    if (constant->weight_tensors.empty()) {
      GELOGE(PARAM_INVALID, "Weight tensors is empty");
      return false;
    }

    // Fix: output_addrs[0] is dereferenced below; validate it like the tensor
    // lists to avoid an out-of-bounds access on malformed model data.
    if (constant->output_addrs.empty()) {
      GELOGE(PARAM_INVALID, "Output addrs is empty");
      return false;
    }

    if (constant->output_tensors[0].size < constant->weight_data.size()) {
      GELOGE(PARAM_INVALID, "Output size:%u is less than weight data size:%zu", constant->output_tensors[0].size,
             constant->weight_data.size());
      return false;
    }

    if (constant->weight_data.empty()) {
      GELOGW("Const op:%s has no weight data.", constant->name.c_str());
      continue;
    }

    if (constant->weight_tensors[0].datatype == DT_STRING) {
      /// If a tensor is a scalar, its shape size is zero, according to ge_tensor.cc.
      /// The logic of GetShapeSize is wrong: a scalar tensor's GetShapeSize is zero,
      /// and that of an unknown shape is zero too.
      /// Unknown shapes will not appear here, so zero reliably identifies a scalar.
      int64_t elem_num =
          (constant->weight_tensors[0].GetShapeSize() == 0) ? 1 : constant->weight_tensors[0].GetShapeSize();
      if (constant->weight_data.size() < sizeof(uint64_t)) {
        GELOGE(FAILED, "weight_data size is smaller than sizeof(uint64_t)");
        return false;
      }
      uint64_t *buff = reinterpret_cast<uint64_t *>(const_cast<char *>(constant->weight_data.data()));
      uint32_t head_len = kOffsetUnit * kStringHeadElems;
      if (CheckInt64Uint32MulOverflow(elem_num, head_len) != SUCCESS) {
        GELOGE(FAILED, "Shape size is invalid");
        return false;
      }
      // Raw string bytes start right after the elem_num-entry offset header;
      // rewrite each host-relative offset into an absolute device address.
      int64_t offset = elem_num * head_len;
      uintptr_t hbm_raw_data_base_addr = reinterpret_cast<uintptr_t>(constant->output_addrs[0]) + offset;
      for (int64_t i = elem_num - 1; i >= 0; --i) {
        buff[i * kStringHeadElems] = hbm_raw_data_base_addr + (buff[i * kStringHeadElems] - buff[0]);
      }
    }

    rtError_t rt_ret = rtMemcpy(reinterpret_cast<void *>(constant->output_addrs[0]), constant->output_tensors[0].size,
                                constant->weight_data.data(), constant->weight_data.size(), RT_MEMCPY_HOST_TO_DEVICE);
    if (rt_ret != RT_ERROR_NONE) {
      // Fix: the old message blamed rtGetFunctionByName; this call is rtMemcpy.
      GELOGE(RT_FAILED, "Call rt api rtMemcpy failed, ret: 0x%X", rt_ret);
      return false;
    }
  }

  return true;
}

// Queries the model's input/output tensor descriptions and formats.
// NOTE(review): stub — returns success without populating any output
// parameter; callers must not rely on the out-vectors being filled.
bool RuntimeModel::GetInputOutputDescInfo(bool zero_copy, std::vector<InputOutputDescInfo> *input_desc,
                                          std::vector<InputOutputDescInfo> *output_desc,
                                          std::vector<uint32_t> *input_format, std::vector<uint32_t> *output_format) {
  return true;
}

// Queries the model's input tensor descriptions and formats.
// NOTE(review): stub — returns success without populating the out-vectors.
bool RuntimeModel::GetInputDescInfo(std::vector<InputOutputDescInfo> *input_desc, std::vector<uint32_t> *formats) {
  return true;
}

// Queries the model's output tensor descriptions and formats.
// NOTE(review): stub — returns success without populating the out-vectors.
bool RuntimeModel::GetOutputDescInfo(std::vector<InputOutputDescInfo> *output_desc, std::vector<uint32_t> *formats) {
  return true;
}

// Builds the description for one model output.
// NOTE(review): stub — intentionally empty; nothing is written to |output|
// or |format_result|.
void RuntimeModel::CreateOutput(uint32_t index, const OpInfo &op_info, InputOutputDescInfo *output,
                                uint32_t *format_result) {}

// Accessor for the cached task id list (read-only view).
const std::vector<uint32_t> &RuntimeModel::GetTaskIdList() const { return task_id_list_; }

// Accessor for the cached stream id list (read-only view).
const std::vector<uint32_t> &RuntimeModel::GetStreamIdList() const { return stream_id_list_; }
} // namespace model_runner
} // namespace ge

+ 0
- 92
ge/ge_runtime/runtime_model.h View File

@@ -1,92 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_GE_RUNTIME_RUNTIME_MODEL_H_
#define GE_GE_RUNTIME_RUNTIME_MODEL_H_
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "ge_runtime/davinci_model.h"
#include "common/ge_types.h"
#include "runtime/base.h"
#include "runtime/rt_model.h"

namespace ge {
namespace model_runner {
using RuntimeInfo = std::tuple<uint32_t, uint32_t, void *>;
class Task;
// Executable form of a loaded DavinciModel: owns the runtime model handle,
// the streams/labels/events bound to it, the generated task list, and the
// per-op data/output/constant descriptions needed to feed and run the model.
class RuntimeModel {
 public:
  RuntimeModel() = default;
  ~RuntimeModel();

  // Builds all runtime resources from |davinci_model| and loads its tasks.
  bool Load(uint32_t device_id, uint64_t session_id, std::shared_ptr<DavinciModel> &davinci_model);
  // Distributes every generated task to the runtime.
  bool DistributeTask();
  // Finalizes loading after all tasks are distributed.
  bool LoadComplete();
  // Read-only accessors for ids recorded during task distribution.
  const std::vector<uint32_t> &GetTaskIdList() const;
  const std::vector<uint32_t> &GetStreamIdList() const;
  const std::map<std::string, std::shared_ptr<RuntimeInfo>> &GetRuntimeInfoMap() const { return runtime_info_map_; }
  rtModel_t GetModelHandle() const { return rt_model_handle_; }
  // Executes the loaded model once.
  bool Run();
  // Copies host input blobs into the model's data nodes.
  bool CopyInputData(const InputData &input_data);
  bool GetInputOutputDescInfo(bool zero_copy, std::vector<InputOutputDescInfo> *input_desc,
                              std::vector<InputOutputDescInfo> *output_desc, std::vector<uint32_t> *input_format,
                              std::vector<uint32_t> *output_format);

 private:
  // Resource construction helpers used by Load().
  bool InitResource(std::shared_ptr<DavinciModel> &davinci_model);
  void GenerateTask(uint32_t device_id, uint64_t session_id, std::shared_ptr<DavinciModel> &davinci_model);
  bool LoadTask();
  bool InitStream(std::shared_ptr<DavinciModel> &davinci_model);
  bool InitEvent(uint32_t event_num);
  bool InitLabel(std::shared_ptr<DavinciModel> &davinci_model);
  bool InitDataInfo(std::shared_ptr<DavinciModel> &davinci_model);
  bool InitOutputInfo(std::shared_ptr<DavinciModel> &davinci_model);
  bool InitConstantInfo(std::shared_ptr<DavinciModel> &davinci_model);
  // Teardown helpers; noexcept so they are safe from the destructor.
  void RtModelUnbindStream() noexcept;
  void RtStreamDestory() noexcept;
  void RtModelDestory() noexcept;
  void RtLabelDestory() noexcept;
  void RtEventDestory() noexcept;
  // Input-copy helpers for the different data-node copy paths.
  bool CopyInputDataToModel(const std::vector<DataBuffer> &data, const std::shared_ptr<OpInfo> &data_info);
  bool CopyHostData(const std::vector<DataBuffer> &data, const std::shared_ptr<OpInfo> &data_info) const;
  bool CopyTransData(const std::vector<DataBuffer> &data, const std::shared_ptr<OpInfo> &data_info);
  bool GetInputDescInfo(std::vector<InputOutputDescInfo> *input_desc, std::vector<uint32_t> *formats);
  bool GetOutputDescInfo(std::vector<InputOutputDescInfo> *output_desc, std::vector<uint32_t> *formats);
  void CreateOutput(uint32_t index, const OpInfo &op_info, InputOutputDescInfo *output, uint32_t *format);

  // Runtime handles owned by this model.
  rtModel_t rt_model_handle_{};
  rtStream_t rt_model_stream_{};

  std::vector<rtStream_t> stream_list_{};
  std::vector<rtLabel_t> label_list_{};
  std::vector<rtEvent_t> event_list_{};

  // Tasks generated from the DavinciModel and per-op descriptors.
  std::vector<std::shared_ptr<Task>> task_list_{};
  std::vector<std::shared_ptr<OpInfo>> data_info_list_{};
  std::vector<std::shared_ptr<OpInfo>> output_info_list_{};
  std::vector<std::shared_ptr<OpInfo>> constant_info_list_{};

  // Ids recorded while distributing tasks, keyed by distribution order.
  std::vector<uint32_t> task_id_list_{};
  std::vector<uint32_t> stream_id_list_{};
  std::map<std::string, std::shared_ptr<RuntimeInfo>> runtime_info_map_;
};

} // namespace model_runner
} // namespace ge

#endif // GE_GE_RUNTIME_RUNTIME_MODEL_H_

+ 0
- 168
ge/ge_runtime/task/aicpu_task.cc View File

@@ -1,168 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ge_runtime/task/aicpu_task.h"
#include <vector>
#include "ge_runtime/task/task_factory.h"
#include "aicpu/common/aicpu_task_struct.h"

namespace ge {
namespace model_runner {
// Binds the task to its execution stream. With exactly one stream in the
// model context that stream is used unconditionally; otherwise the task's
// stream id selects one, and an out-of-range id leaves stream_ null with a
// warning.
AicpuTask::AicpuTask(const ModelContext &model_context, const std::shared_ptr<AicpuTaskInfo> &task_info)
    : TaskRepeater<AicpuTaskInfo>(model_context, task_info),
      task_info_(task_info),
      stream_(nullptr),
      args_(nullptr),
      ext_info_(nullptr),
      input_output_addr_(nullptr) {
  if (task_info_ == nullptr) {
    GELOGW("task_info_ is null!");
    return;
  }

  const auto streams = model_context.stream_list();
  if (streams.size() == 1) {
    stream_ = streams[0];
    return;
  }
  if (task_info->stream_id() < streams.size()) {
    stream_ = streams[task_info->stream_id()];
  } else {
    GELOGW("index: %u >= stream_list.size(): %zu.", task_info->stream_id(), streams.size());
  }
}

// Frees the device buffers allocated in Distribute(); safe when nothing was
// allocated because ReleaseRtMem ignores null pointers.
AicpuTask::~AicpuTask() {
  ReleaseRtMem(&args_);
  ReleaseRtMem(&ext_info_);
}

// Assembles the AICPU kernel argument buffer in device memory and launches
// the kernel on stream_.
//
// Device args layout (offsets computed below):
//   [AicpuParamHead][input+output addrs][node_def length (uint32)][node_def bytes]
// Optional kernel ext info is copied to a separate device buffer whose
// address/length are recorded in the param head.
// Returns false on any runtime-API failure; buffers allocated here are
// released by the destructor, not on the failure paths.
bool AicpuTask::Distribute() {
  GELOGI("InitAicpuTask start.");
  // Flatten input then output addresses into one contiguous table.
  vector<void *> io_addrs;
  io_addrs.insert(io_addrs.end(), task_info_->input_data_addrs().begin(), task_info_->input_data_addrs().end());
  io_addrs.insert(io_addrs.end(), task_info_->output_data_addrs().begin(), task_info_->output_data_addrs().end());
  auto io_addrs_num = static_cast<uint32_t>(io_addrs.size());
  auto io_addrs_size = static_cast<uint32_t>(io_addrs_num * sizeof(void *));
  // Offsets of the sections within the args buffer (see layout above).
  constexpr uint32_t io_addr_offset = sizeof(aicpu::AicpuParamHead);
  uint32_t node_def_len_offset = io_addr_offset + io_addrs_size;
  uint32_t node_def_addr_offset = node_def_len_offset + sizeof(uint32_t);
  uint32_t args_size = sizeof(aicpu::AicpuParamHead) + io_addrs_size +
                       static_cast<uint32_t>(task_info_->node_def().size()) + sizeof(uint32_t);

  aicpu::AicpuParamHead aicpu_param_head;
  aicpu_param_head.length = args_size;
  aicpu_param_head.ioAddrNum = io_addrs_num;
  auto ext_info = task_info_->ext_info();
  uint32_t ext_size = ext_info.size();
  if (ext_info.empty()) {
    aicpu_param_head.extInfoLength = 0;
    aicpu_param_head.extInfoAddr = 0;
  } else {
    // Copy ext info to its own device buffer and reference it from the head.
    rtError_t flag = rtMalloc(&ext_info_, ext_size, RT_MEMORY_HBM);
    if (flag != RT_ERROR_NONE) {
      GELOGE(RT_FAILED, "Call rt api(rtMalloc) failed, ret: 0x%X.", flag);
      return false;
    }

    flag = rtMemcpy(ext_info_, ext_size, const_cast<void *>(reinterpret_cast<const void *>(ext_info.data())), ext_size,
                    RT_MEMCPY_HOST_TO_DEVICE);
    if (flag != RT_ERROR_NONE) {
      GELOGE(RT_FAILED, "Call rt api(rtMemCpy) failed, ret: 0x%X.", flag);
      return false;
    }

    GELOGI("ext info size: %u", ext_size);
    aicpu_param_head.extInfoLength = ext_size;
    aicpu_param_head.extInfoAddr = reinterpret_cast<uintptr_t>(ext_info_);
  }

  // Malloc device memory for args
  rtError_t rt_ret = rtMalloc(&args_, args_size, RT_MEMORY_HBM);
  if (rt_ret != RT_ERROR_NONE) {
    GELOGE(RT_FAILED, "Call rt api(rtMalloc) failed, ret: 0x%X.", rt_ret);
    return false;
  }
  GE_PRINT_DYNAMIC_MEMORY(rtMalloc, "task args data.", args_size)
  // Memcpy AicpuParamHead
  rt_ret = rtMemcpy(args_, sizeof(aicpu::AicpuParamHead), reinterpret_cast<void *>(&aicpu_param_head),
                    sizeof(aicpu::AicpuParamHead), RT_MEMCPY_HOST_TO_DEVICE);
  if (rt_ret != RT_ERROR_NONE) {
    GELOGE(RT_FAILED, "Call rt api(rtMemcpy) failed, ret: 0x%X.", rt_ret);
    return false;
  }

  // Memcpy io addrs
  if (io_addrs_num != 0) {
    rt_ret = rtMemcpy(reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(args_) + io_addr_offset), io_addrs_size,
                      reinterpret_cast<void *>(io_addrs.data()), io_addrs_size, RT_MEMCPY_HOST_TO_DEVICE);
    if (rt_ret != RT_ERROR_NONE) {
      GELOGE(RT_FAILED, "Call rt api(rtMemcpy) failed, ret: 0x%X.", rt_ret);
      return false;
    }
  }

  // Memcpy node def length (uint32) ahead of the node def bytes.
  auto size = task_info_->node_def().size();
  rt_ret =
    rtMemcpy(reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(args_) + node_def_len_offset), sizeof(uint32_t),
             reinterpret_cast<const void *>(&size), sizeof(uint32_t), RT_MEMCPY_HOST_TO_DEVICE);
  if (rt_ret != RT_ERROR_NONE) {
    GELOGE(RT_FAILED, "Call rt api(rtMemcpy) failed, ret: 0x%X.", rt_ret);
    return false;
  }

  // Memcpy node def
  rt_ret = rtMemcpy(reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(args_) + node_def_addr_offset),
                    task_info_->node_def().size(), reinterpret_cast<const void *>(task_info_->node_def().data()),
                    task_info_->node_def().size(), RT_MEMCPY_HOST_TO_DEVICE);
  if (rt_ret != RT_ERROR_NONE) {
    GELOGE(RT_FAILED, "Call rt api(rtMemcpy) failed, ret: 0x%X.", rt_ret);
    return false;
  }

  // Remember where the io-addr table lives for Args() consumers.
  input_output_addr_ = reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(args_) + io_addr_offset);

  auto dump_flag = task_info_->dump_flag() ? RT_KERNEL_DUMPFLAG : RT_KERNEL_DEFAULT;
  GELOGI(
    "Distribute AicpuTask start, args_size = %u, io_addrs_num = %u, so_name = %s, kernel_name = %s, dump_flag = %d.",
    args_size, io_addrs_num, task_info_->so_name().data(), task_info_->kernel_name().data(), dump_flag);
  rt_ret = rtCpuKernelLaunchWithFlag(reinterpret_cast<const void *>(task_info_->so_name().data()),
                                     reinterpret_cast<const void *>(task_info_->kernel_name().data()), 1, args_,
                                     args_size, nullptr, stream_, dump_flag);
  if (rt_ret != RT_ERROR_NONE) {
    GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
    return false;
  }

  GELOGI("Distribute AicpuTask end.");
  return true;
}

void AicpuTask::ReleaseRtMem(void **ptr) noexcept {
if (ptr == nullptr || *ptr == nullptr) {
return;
}

rtError_t rt_ret = rtFree(*ptr);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "ReleaseRtMem failed, ret: 0x%X", rt_ret);
return;
}
*ptr = nullptr;
}

REGISTER_TASK(TaskInfoType::AICPU, AicpuTask, AicpuTaskInfo);
} // namespace model_runner
} // namespace ge

+ 0
- 50
ge/ge_runtime/task/aicpu_task.h View File

@@ -1,50 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_GE_RUNTIME_TASK_AICPU_TASK_H_
#define GE_GE_RUNTIME_TASK_AICPU_TASK_H_

#include <memory>
#include <string>
#include "ge_runtime/task/task.h"

namespace ge {
namespace model_runner {
// Task that launches an AICPU kernel: builds the kernel's argument buffer in
// device memory during Distribute() and frees it in the destructor.
class AicpuTask : public TaskRepeater<AicpuTaskInfo> {
 public:
  AicpuTask(const ModelContext &model_context, const std::shared_ptr<AicpuTaskInfo> &task_info);

  ~AicpuTask() override;

  // Assembles device-side args and launches the kernel on the bound stream.
  bool Distribute() override;

  // Device address of the io-addr table inside the args buffer (set by
  // Distribute(); null before that).
  void *Args() override { return input_output_addr_; }

  std::string task_name() const override { return task_info_->op_name(); }

 private:
  // Frees *ptr via rtFree and nulls it; ignores null inputs.
  static void ReleaseRtMem(void **ptr) noexcept;

  std::shared_ptr<AicpuTaskInfo> task_info_;
  void *stream_;             // execution stream chosen from the model context
  void *args_;               // device args buffer (owned)
  void *ext_info_;           // device ext-info buffer (owned, may stay null)
  void *input_output_addr_;  // points into args_; not separately owned
};
} // namespace model_runner
} // namespace ge

#endif // GE_GE_RUNTIME_TASK_AICPU_TASK_H_

+ 0
- 160
ge/ge_runtime/task/cce_task.cc View File

@@ -1,160 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ge_runtime/task/cce_task.h"
#include "ge_runtime/task/task_factory.h"

namespace ge {
namespace model_runner {
// Binds the task to its execution stream: a single-stream context is used
// directly, otherwise the task's stream id selects one; an out-of-range id
// leaves stream_ null with a warning.
CceTask::CceTask(const ModelContext &model_context, const std::shared_ptr<CceTaskInfo> &task_info)
    : TaskRepeater<CceTaskInfo>(model_context, task_info),
      task_info_(task_info),
      stream_(nullptr),
      stub_func_(nullptr),
      args_(nullptr),
      sm_desc_(nullptr),
      flowtable_(nullptr),
      is_flowtable_(false) {
  if (task_info_ == nullptr) {
    GELOGW("task_info_ is null!");
    return;
  }

  auto stream_list = model_context.stream_list();
  if (stream_list.size() == 1) {
    stream_ = stream_list[0];
  } else if (stream_list.size() > task_info->stream_id()) {
    stream_ = stream_list[task_info->stream_id()];
  } else {
    GELOGW("index: %u >= stream_list.size(): %zu.", task_info->stream_id(), stream_list.size());
  }
}

// Releases the device buffers created in Distribute(): args and flowtable
// via rtFree, and the managed L2 sm_desc via rtMemFreeManaged when present.
CceTask::~CceTask() {
  FreeRtMem(&args_);
  FreeRtMem(&flowtable_);
  if (sm_desc_ != nullptr) {
    const rtError_t ret = rtMemFreeManaged(sm_desc_);
    if (ret != RT_ERROR_NONE) {
      GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", ret);
    }
    sm_desc_ = nullptr;
  }
}

// Frees a device buffer obtained from rtMalloc and clears the caller's
// pointer; null handles are ignored.
// Note: unlike AicpuTask::ReleaseRtMem, the pointer is nulled even when
// rtFree fails — the error is only logged.
void CceTask::FreeRtMem(void **ptr) noexcept {
  if (ptr == nullptr || *ptr == nullptr) {
    return;
  }
  rtError_t ret = rtFree(*ptr);
  if (ret != RT_ERROR_NONE) {
    GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", ret);
  }

  *ptr = nullptr;
}

// Prepares and launches a CCE kernel: resolves the stub function by name,
// optionally materializes a flow table on device and patches its address
// into the kernel args, copies args to device, optionally sets up the L2
// sm_desc, then launches on stream_.
// Returns false on any invalid precondition or runtime-API failure; device
// buffers allocated here are released by the destructor.
bool CceTask::Distribute() {
  GELOGI("Distribute CCETask start.");
  if (stream_ == nullptr) {
    GELOGE(PARAM_INVALID, "stream_ is null!");
    return false;
  }
  // Get stub_func
  if (task_info_->stub_func().empty()) {
    GELOGE(PARAM_INVALID, "kernel_info->stub_func is empty!");
    return false;
  }

  rtError_t rt_ret = rtGetFunctionByName(const_cast<char *>(task_info_->stub_func().c_str()), &stub_func_);
  if (rt_ret != RT_ERROR_NONE) {
    GELOGE(RT_FAILED, "rtGetFunctionByName failed, ret: 0x%X", rt_ret);
    stub_func_ = nullptr;
    return false;
  }
  GELOGI("CCETask: stub_func = %s [%p].", task_info_->stub_func().c_str(), stub_func_);

  // Flowtable
  // NOTE(review): in the visible code is_flowtable_ is only ever initialized
  // to false in the constructor, which would make this branch dead — confirm
  // whether anything else sets it before removing.
  if (is_flowtable_) {
    rt_ret = rtMalloc(&flowtable_, task_info_->flow_table().size(), RT_MEMORY_HBM);
    if (rt_ret != RT_ERROR_NONE) {
      GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
      return false;
    }
    GE_PRINT_DYNAMIC_MEMORY(rtMalloc, "task information.", task_info_->flow_table().size())

    rt_ret = rtMemcpy(flowtable_, task_info_->flow_table().size(), task_info_->flow_table().data(),
                      task_info_->flow_table().size(), RT_MEMCPY_HOST_TO_DEVICE);
    if (rt_ret != RT_ERROR_NONE) {
      GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
      return false;
    }

    // Modify flowtable addr in args: args_offset()[0] holds the byte offset
    // of the pointer slot inside the host args blob.
    auto args = const_cast<uint8_t *>(task_info_->args().data());
    auto task_offset = reinterpret_cast<uint16_t *>(const_cast<uint8_t *>(task_info_->args_offset().data()));

    if (task_info_->args().size() < (task_offset[0] + sizeof(uint64_t))) {
      GELOGE(FAILED, "(context.args_offset().data()))[0]:%u + sizeof(uint64_t):%zu > kernelDef.args().size():%zu",
             static_cast<uint32_t>(task_offset[0]), sizeof(uint64_t), task_info_->args().size());
      return false;
    }

    *(reinterpret_cast<uintptr_t *>(args + task_offset[0])) = reinterpret_cast<uintptr_t>(flowtable_);
  }

  // Args
  rt_ret = rtMalloc(&args_, task_info_->args_size(), RT_MEMORY_HBM);
  if (rt_ret != RT_ERROR_NONE) {
    GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
    return false;
  }
  GE_PRINT_DYNAMIC_MEMORY(rtMalloc, "task information.", task_info_->args_size())

  rt_ret = rtMemcpy(args_, task_info_->args_size(), task_info_->args().data(), task_info_->args_size(),
                    RT_MEMCPY_HOST_TO_DEVICE);
  if (rt_ret != RT_ERROR_NONE) {
    GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
    return false;
  }

  // L2 sm_desc (managed scratchpad memory, only when the task provides one)
  if (!task_info_->sm_desc().empty()) {
    rt_ret = rtMemAllocManaged(&sm_desc_, task_info_->sm_desc().size(), RT_MEMORY_SPM);
    if (rt_ret != RT_ERROR_NONE) {
      GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
      return false;
    }

    rt_ret = rtMemcpy(sm_desc_, task_info_->sm_desc().size(), task_info_->sm_desc().data(),
                      task_info_->sm_desc().size(), RT_MEMCPY_HOST_TO_DEVICE);
    if (rt_ret != RT_ERROR_NONE) {
      GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
      return false;
    }
  }

  // Kernel launch
  rt_ret = rtKernelLaunch(stub_func_, task_info_->block_dim(), args_, task_info_->args_size(),
                          static_cast<rtSmDesc_t *>(sm_desc_), stream_);
  if (rt_ret != RT_ERROR_NONE) {
    GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
    return false;
  }
  return true;
}

REGISTER_TASK(TaskInfoType::CCE, CceTask, CceTaskInfo);
} // namespace model_runner
} // namespace ge

+ 0
- 47
ge/ge_runtime/task/cce_task.h View File

@@ -1,47 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_GE_RUNTIME_TASK_CCE_TASK_H_
#define GE_GE_RUNTIME_TASK_CCE_TASK_H_

#include <memory>
#include "ge_runtime/task/task.h"

namespace ge {
namespace model_runner {
// Task that launches a CCE kernel: resolves the stub function, stages args /
// flow table / L2 descriptor to device memory in Distribute(), and releases
// them in the destructor.
class CceTask : public TaskRepeater<CceTaskInfo> {
 public:
  CceTask(const ModelContext &model_context, const std::shared_ptr<CceTaskInfo> &task_info);

  ~CceTask() override;

  bool Distribute() override;

  // Frees *ptr via rtFree and nulls it (even on free failure); null-safe.
  static void FreeRtMem(void **ptr) noexcept;

 private:
  std::shared_ptr<CceTaskInfo> task_info_;
  void *stream_;       // execution stream chosen from the model context
  void *stub_func_;    // kernel handle resolved by rtGetFunctionByName
  void *args_;         // device args buffer (owned)
  void *sm_desc_;      // managed L2 descriptor buffer (owned, may stay null)
  void *flowtable_;    // device flow-table buffer (owned, may stay null)
  bool is_flowtable_;  // whether a flow table must be staged before launch
};
} // namespace model_runner
} // namespace ge

#endif // GE_GE_RUNTIME_TASK_CCE_TASK_H_

+ 0
- 61
ge/ge_runtime/task/event_record_task.cc View File

@@ -1,61 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ge_runtime/task/event_record_task.h"
#include "ge_runtime/task/task_factory.h"

namespace ge {
namespace model_runner {
// Resolves the stream and event this record task operates on from the model
// context; on an out-of-range stream or event id both handles stay null and
// only a warning is logged.
EventRecordTask::EventRecordTask(const ModelContext &model_context,
                                 const std::shared_ptr<EventRecordTaskInfo> &task_info)
    : TaskRepeater<EventRecordTaskInfo>(model_context, task_info),
      task_info_(task_info),
      stream_(nullptr),
      event_(nullptr) {
  if (task_info_ == nullptr) {
    GELOGW("task_info_ is null!");
    return;
  }
  auto stream_list = model_context.stream_list();
  auto event_list = model_context.event_list();
  uint32_t stream_id = task_info->stream_id();
  uint32_t event_id = task_info->event_id();
  // Both ids must index into the context lists before the handles are taken.
  if (stream_id >= stream_list.size() || event_id >= event_list.size()) {
    GELOGW("stream_list size:%zu, stream_id:%u, event_list size:%zu, event_id:%u", stream_list.size(), stream_id,
           event_list.size(), event_id);
    return;
  }
  stream_ = stream_list[stream_id];
  event_ = event_list[event_id];
}

// No owned resources: stream/event handles belong to the model context.
EventRecordTask::~EventRecordTask() {}

bool EventRecordTask::Distribute() {
GELOGI("EventRecordTask Distribute start, stream: %p, event: %p, stream_id: %u, event_id: %u.", stream_, event_,
task_info_->stream_id(), task_info_->event_id());
rtError_t rt_ret = rtEventRecord(event_, stream_);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return false;
}
GELOGI("Distribute end.");
return true;
}

REGISTER_TASK(TaskInfoType::EVENT_RECORD, EventRecordTask, EventRecordTaskInfo);
} // namespace model_runner
} // namespace ge

+ 0
- 41
ge/ge_runtime/task/event_record_task.h View File

@@ -1,41 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_GE_RUNTIME_TASK_EVENT_RECORD_TASK_H_
#define GE_GE_RUNTIME_TASK_EVENT_RECORD_TASK_H_

#include <memory>
#include "ge_runtime/task/task.h"

namespace ge {
namespace model_runner {
// Task that records a runtime event on a stream (rtEventRecord) when
// distributed.
class EventRecordTask : public TaskRepeater<EventRecordTaskInfo> {
 public:
  EventRecordTask(const ModelContext &model_context, const std::shared_ptr<EventRecordTaskInfo> &task_info);

  ~EventRecordTask() override;

  bool Distribute() override;

 private:
  std::shared_ptr<EventRecordTaskInfo> task_info_;
  rtStream_t stream_;  // borrowed from the model context; not owned
  rtEvent_t event_;    // borrowed from the model context; not owned
};
} // namespace model_runner
} // namespace ge

#endif // GE_GE_RUNTIME_TASK_EVENT_RECORD_TASK_H_

+ 0
- 67
ge/ge_runtime/task/event_wait_task.cc View File

@@ -1,67 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ge_runtime/task/event_wait_task.h"
#include "ge_runtime/task/task_factory.h"

namespace ge {
namespace model_runner {
// Resolves the stream and event this wait task operates on from the model
// context; on an out-of-range stream or event id both handles stay null and
// only a warning is logged.
EventWaitTask::EventWaitTask(const ModelContext &model_context, const std::shared_ptr<EventWaitTaskInfo> &task_info)
    : TaskRepeater<EventWaitTaskInfo>(model_context, task_info),
      task_info_(task_info),
      stream_(nullptr),
      event_(nullptr) {
  if (task_info_ == nullptr) {
    GELOGW("task_info_ is null!");
    return;
  }
  auto stream_list = model_context.stream_list();
  auto event_list = model_context.event_list();
  uint32_t stream_id = task_info->stream_id();
  uint32_t event_id = task_info->event_id();
  // Both ids must index into the context lists before the handles are taken.
  if (stream_id >= stream_list.size() || event_id >= event_list.size()) {
    GELOGW("stream_list size:%zu, stream_id:%u, event_list size:%zu, event_id:%u", stream_list.size(), stream_id,
           event_list.size(), event_id);
    return;
  }
  stream_ = stream_list[stream_id];
  event_ = event_list[event_id];
}

// No owned resources: stream/event handles belong to the model context.
EventWaitTask::~EventWaitTask() {}

bool EventWaitTask::Distribute() {
GELOGI("EventWaitTask Distribute start, stream: %p, event: %p, stream_id: %u, event_id: %u.", stream_, event_,
task_info_->stream_id(), task_info_->event_id());

rtError_t rt_ret = rtStreamWaitEvent(stream_, event_);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api rtStreamWaitEvent failed, ret: 0x%X", rt_ret);
return false;
}

rt_ret = rtEventReset(event_, stream_);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api rtEventReset failed, ret: 0x%X", rt_ret);
return false;
}
GELOGI("Distribute end.");
return true;
}

REGISTER_TASK(TaskInfoType::EVENT_WAIT, EventWaitTask, EventWaitTaskInfo);
} // namespace model_runner
} // namespace ge

+ 0
- 41
ge/ge_runtime/task/event_wait_task.h View File

@@ -1,41 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_GE_RUNTIME_TASK_EVENT_WAIT_TASK_H_
#define GE_GE_RUNTIME_TASK_EVENT_WAIT_TASK_H_

#include <memory>
#include "ge_runtime/task/task.h"

namespace ge {
namespace model_runner {
// Task that makes a stream wait on a runtime event and then resets the event
// (rtStreamWaitEvent + rtEventReset) when distributed.
class EventWaitTask : public TaskRepeater<EventWaitTaskInfo> {
 public:
  EventWaitTask(const ModelContext &model_context, const std::shared_ptr<EventWaitTaskInfo> &task_info);

  ~EventWaitTask() override;

  bool Distribute() override;

 private:
  std::shared_ptr<EventWaitTaskInfo> task_info_;
  rtStream_t stream_;  // borrowed from the model context; not owned
  rtEvent_t event_;    // borrowed from the model context; not owned
};
} // namespace model_runner
} // namespace ge

#endif // GE_GE_RUNTIME_TASK_EVENT_WAIT_TASK_H_

+ 0
- 268
ge/ge_runtime/task/hccl_task.cc View File

@@ -1,268 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ge_runtime/task/hccl_task.h"
#include <algorithm>
#include "ge_runtime/task/task_factory.h"
#include "common/opskernel/ops_kernel_info_store.h"
#include "common/opskernel/ge_task_info.h"

namespace ge {
namespace model_runner {
std::map<rtModel_t, std::map<uint32_t, std::vector<std::weak_ptr<HcclTask::StreamGuard>>>>
HcclTask::model_stream_mapping_;
std::mutex HcclTask::model_stream_mapping_mutex_;

// Captures the model handle and priority from the context and binds the task
// to its execution stream: a single-stream context is used directly,
// otherwise the task's stream id selects one; an out-of-range id leaves
// stream_ null with a warning.
HcclTask::HcclTask(const ModelContext &model_context, const std::shared_ptr<HcclTaskInfo> &task_info)
    : TaskRepeater<HcclTaskInfo>(model_context, task_info),
      task_info_(task_info),
      stream_(nullptr),
      workspace_mem_(nullptr),
      rt_model_handle_(nullptr),
      priority_(0),
      secondary_stream_list_() {
  if (task_info_ == nullptr) {
    GELOGW("task_info_ is null!");
    return;
  }

  priority_ = model_context.priority();
  rt_model_handle_ = model_context.rt_model_handle();
  auto stream_list = model_context.stream_list();

  if (stream_list.size() == 1) {
    stream_ = stream_list[0];
  } else if (stream_list.size() > task_info->stream_id()) {
    stream_ = stream_list[task_info->stream_id()];
  } else {
    GELOGW("Index: %u >= stream_list.size(): %zu.", task_info->stream_id(), stream_list.size());
  }
}

// Releases the device workspace allocated in Distribute(), if any.
HcclTask::~HcclTask() {
  if (workspace_mem_ == nullptr) {
    return;
  }
  const rtError_t rt_ret = rtFree(workspace_mem_);
  workspace_mem_ = nullptr;
  if (rt_ret != RT_ERROR_NONE) {
    GELOGE(RT_FAILED, "rtFree workspace_mem_ failed! ret: 0x%X.", rt_ret);
  }
}

bool HcclTask::Distribute() {
// Ops kernel info store
// Get privateDef and opsKernelStorePtr
GELOGI("Get custom info in modelTaskDef");
void *ops_kernel_store = task_info_->ops_kernel_store();
OpsKernelInfoStore *ops_kernel_info_store = reinterpret_cast<OpsKernelInfoStore *>(ops_kernel_store);
if (ops_kernel_store == nullptr) {
GELOGE(PARAM_INVALID, "No hcom distribute function ptr and no ops kernel store.");
return false;
}

char *private_def = reinterpret_cast<char *>(const_cast<char unsigned *>(task_info_->private_def().data()));
auto private_def_len = static_cast<uint32_t>(task_info_->private_def().size());
GELOGI("The first address of the custom info, privateDef=%p", private_def);
SetSecondaryStream();

if (task_info_->workspace_size() > 0) {
rtError_t rt_ret = rtMalloc(&workspace_mem_, task_info_->workspace_size(), RT_MEMORYINFO_HBM);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return false;
}
}

GELOGI("HcclTaskInfo Distribute Start. begin to call function LoadTask in hccl.");
GETaskInfo ge_task;
ge_task.id = 0;
ge_task.type = static_cast<uint16_t>(RT_MODEL_TASK_HCCL);
ge_task.stream = stream_;

ge_task.kernelHcclInfo = std::vector<GETaskKernelHcclInfo>(1);
ge_task.kernelHcclInfo[0].hccl_type = task_info_->hccl_type();
ge_task.kernelHcclInfo[0].inputDataAddr = task_info_->input_data_addr();
ge_task.kernelHcclInfo[0].outputDataAddr = task_info_->output_data_addr();
ge_task.kernelHcclInfo[0].workSpaceAddr = workspace_mem_;
ge_task.kernelHcclInfo[0].workSpaceMemSize = task_info_->workspace_size();
ge_task.kernelHcclInfo[0].count = task_info_->count();
ge_task.kernelHcclInfo[0].dataType = static_cast<int32_t>(task_info_->data_type());
ge_task.kernelHcclInfo[0].opType = static_cast<int32_t>(task_info_->op_type());
ge_task.kernelHcclInfo[0].rootId = task_info_->root_id();

std::vector<rtStream_t> secondary_stream_list;
std::transform(secondary_stream_list_.begin(), secondary_stream_list_.end(),
std::back_inserter(secondary_stream_list),
[](const std::shared_ptr<StreamGuard> &stream) -> rtStream_t { return stream->GetStream(); });
ge_task.kernelHcclInfo[0].hcclStreamList = secondary_stream_list;

ge_task.privateDef = private_def;
ge_task.privateDefLen = private_def_len;
ge_task.opsKernelStorePtr = ops_kernel_store;

auto result = ops_kernel_info_store->LoadTask(ge_task);
// tagHcclResult::HCCL_SUCCESS is 0
if (result != 0) {
GELOGE(INTERNAL_ERROR, "davinci_model : load task fail, return ret: %u", result);
return false;
}

GELOGI("Call function LoadTask end.");
return true;
}

bool HcclTask::SetSecondaryStream() {
const uint32_t master_stream_id = task_info_->stream_id();
const int64_t hccl_secondary_stream_num = task_info_->hccl_stream_num();
Status ret;
std::lock_guard<std::mutex> lock(model_stream_mapping_mutex_);
if (model_stream_mapping_.find(rt_model_handle_) == model_stream_mapping_.end()) {
GELOGI("Need to create map for rt_model_handle_:%p with new mainstream %u.", rt_model_handle_, master_stream_id);
ret = CreateStream(hccl_secondary_stream_num, master_stream_id);
if (!ret) {
GELOGE(RT_FAILED, "Create hccl stream failed.");
return false;
}
return true;
}

std::map<uint32_t, std::vector<std::weak_ptr<StreamGuard>>> &master_secondary_stream_map =
model_stream_mapping_.at(rt_model_handle_);
auto iter = master_secondary_stream_map.find(master_stream_id);
if (iter != master_secondary_stream_map.end()) {
std::vector<std::weak_ptr<StreamGuard>> &secondary_stream_vec = iter->second;
auto lock_weak_ptr = [&secondary_stream_vec, this](int64_t index) -> bool {
auto stream = secondary_stream_vec[index].lock();
if (stream == nullptr) {
rtStream_t new_stream = nullptr;
bool ret = CreateStream(rt_model_handle_, &new_stream);
if (!ret) {
GELOGE(FAILED, "CreateStream failed.");
return false;
}
stream = std::make_shared<HcclTask::StreamGuard>(rt_model_handle_, new_stream);
GE_RT_FALSE_CHECK_NOTNULL(stream);
secondary_stream_vec[index] = stream;
}
secondary_stream_list_.push_back(stream);
return true;
};

if (static_cast<size_t>(hccl_secondary_stream_num) <= secondary_stream_vec.size()) {
GELOGI("Number of secondary stream is enough to be reused.");
for (int64_t i = 0; i < hccl_secondary_stream_num; ++i) {
if (!lock_weak_ptr(i)) {
GELOGE(FAILED, "Lock weak ptr failed.");
return false;
}
}
} else {
GELOGI("Need to reuse secondary stream and create new secondary stream.");
size_t created_stream_num = secondary_stream_vec.size();
for (size_t i = 0; i < secondary_stream_vec.size(); ++i) {
if (!lock_weak_ptr(i)) {
GELOGE(FAILED, "Lock weak ptr failed.");
return false;
}
}
ret = CreateStream(hccl_secondary_stream_num - created_stream_num, master_stream_id);
if (ret != SUCCESS) {
GELOGE(RT_FAILED, "Create hccl stream failed.");
return false;
}
}
GELOGI("Initialize hccl secondary stream success, hccl_secondary_stream_num =%ld", hccl_secondary_stream_num);
} else {
GELOGI("Need to create secondary stream for %s with new mainstream %u.", task_info_->op_name().c_str(),
master_stream_id);
ret = CreateStream(hccl_secondary_stream_num, master_stream_id);
if (!ret) {
GELOGE(RT_FAILED, "Create hccl stream failed.");
return false;
}
}
return true;
}

bool HcclTask::CreateStream(int64_t stream_num, int64_t master_stream_id) {
GELOGI("Start to create %ld hccl secondary stream.", stream_num);
for (int64_t i = 0; i < stream_num; ++i) {
rtStream_t stream = nullptr;
bool ret = CreateStream(rt_model_handle_, &stream);
if (!ret) {
GELOGE(FAILED, "CreateStream failed.");
return false;
}

GELOGD("hccl_stream addr is=%p", stream);
auto shared_stream = std::make_shared<StreamGuard>(rt_model_handle_, stream);
if (shared_stream == nullptr) {
GELOGE(FAILED, "MakeShared failed.");
return false;
}
SaveHcclSecondaryStream(master_stream_id, shared_stream);
secondary_stream_list_.push_back(shared_stream);
}
GELOGI("CreateStream success.");
return true;
}

// Allocates one runtime stream with this task's priority and binds it to
// |model| as a wait-active stream (inactive until hccl activates it).
// Writes the new stream into |stream|; returns false on any rt failure.
bool HcclTask::CreateStream(rtModel_t model, rtStream_t *stream) const {
  if (stream == nullptr) {
    GELOGE(FAILED, "Output param stream is null.");
    return false;
  }

  rtError_t status = rtStreamCreateWithFlags(stream, priority_, RT_STREAM_PERSISTENT | RT_STREAM_FORCE_COPY);
  if (status != RT_ERROR_NONE) {
    GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", status);
    return false;
  }

  // Create secondary stream, inactive by default, activated by hccl
  status = rtModelBindStream(model, *stream, RT_MODEL_WAIT_ACTIVE_STREAM);
  if (status != RT_ERROR_NONE) {
    GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", status);
    return false;
  }
  return true;
}

// Records |stream| as a weak reference in the static per-model cache under
// |master_stream_id|, creating the per-model map on first use. Weak storage
// lets streams be destroyed when no task holds them anymore.
void HcclTask::SaveHcclSecondaryStream(int64_t master_stream_id, const std::shared_ptr<StreamGuard> &stream) {
  auto model_iter = model_stream_mapping_.find(rt_model_handle_);
  if (model_iter == model_stream_mapping_.end()) {
    auto inserted =
        model_stream_mapping_.emplace(rt_model_handle_, std::map<uint32_t, std::vector<std::weak_ptr<StreamGuard>>>());
    model_iter = inserted.first;
  }
  model_iter->second[master_stream_id].emplace_back(stream);
}

// Unbinds the guarded secondary stream from its model, then destroys it.
// Failures are logged and swallowed: a destructor must not propagate errors.
// If unbinding fails, the stream is deliberately NOT destroyed (it may still
// be bound to the model).
HcclTask::StreamGuard::~StreamGuard() {
  rtError_t rt_ret = rtModelUnbindStream(model_, stream_);
  if (rt_ret != RT_ERROR_NONE) {
    GELOGE(RT_FAILED, "Unbind stream from model failed!");
    return;
  }

  rt_ret = rtStreamDestroy(stream_);
  if (rt_ret != RT_ERROR_NONE) {
    GELOGE(RT_FAILED, "Destroy stream failed!");
    return;
  }
}

REGISTER_TASK(TaskInfoType::HCCL, HcclTask, HcclTaskInfo);
} // namespace model_runner
} // namespace ge

+ 0
- 69
ge/ge_runtime/task/hccl_task.h View File

@@ -1,69 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_GE_RUNTIME_TASK_HCCL_TASK_H_
#define GE_GE_RUNTIME_TASK_HCCL_TASK_H_

#include <memory>
#include <set>
#include <map>
#include <vector>
#include <mutex>
#include "ge_runtime/task/task.h"

namespace ge {
namespace model_runner {
// Task that loads an HCCL (collective communication) kernel onto the runtime
// via the ops kernel store. Manages the secondary (slave) streams HCCL needs,
// caching them per (model handle, master stream id) for reuse across tasks.
class HcclTask : public TaskRepeater<HcclTaskInfo> {
 public:
  HcclTask(const ModelContext &model_context, const std::shared_ptr<HcclTaskInfo> &task_info);

  ~HcclTask() override;

  // Sets up secondary streams and hands the task to hccl's LoadTask.
  bool Distribute() override;

 private:
  class StreamGuard;
  // Ensures the secondary streams for this task exist (reused or created).
  bool SetSecondaryStream();
  // Creates |stream_num| secondary streams cached under |master_stream_id|.
  bool CreateStream(int64_t stream_num, int64_t master_stream_id);
  // Creates one stream bound to |model|, inactive until activated by hccl.
  bool CreateStream(rtModel_t model, rtStream_t *stream) const;
  void SaveHcclSecondaryStream(int64_t master_stream_id, const std::shared_ptr<StreamGuard> &stream);

  std::shared_ptr<HcclTaskInfo> task_info_;
  // Default-initialized so a partially constructed task never exposes
  // indeterminate pointer values (backward compatible with ctor init lists).
  void *stream_{nullptr};
  void *workspace_mem_{nullptr};
  rtModel_t rt_model_handle_{nullptr};
  int32_t priority_{0};
  std::vector<std::shared_ptr<StreamGuard>> secondary_stream_list_;

  // map<key: model pointer, value: map<key: primary stream id, value: vector<secondary stream pointer>>>
  static std::map<rtModel_t, std::map<uint32_t, std::vector<std::weak_ptr<StreamGuard>>>> model_stream_mapping_;
  static std::mutex model_stream_mapping_mutex_;
};

// RAII holder for one secondary stream: the destructor (defined in the .cc)
// unbinds the stream from |model_| and destroys it.
class HcclTask::StreamGuard {
 public:
  StreamGuard(rtModel_t model, rtStream_t stream) : model_(model), stream_(stream) {}
  ~StreamGuard();
  // Raw stream handle; lifetime is tied to this guard.
  rtStream_t GetStream() const { return stream_; }

 private:
  rtModel_t model_;
  rtStream_t stream_;
};
} // namespace model_runner
} // namespace ge

#endif // GE_GE_RUNTIME_TASK_HCCL_TASK_H_

+ 0
- 117
ge/ge_runtime/task/label_goto_task.cc View File

@@ -1,117 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ge_runtime/task/label_goto_task.h"
#include "ge_runtime/task/task_factory.h"
#include "framework/common/util.h"

namespace ge {
namespace model_runner {
// Resolves the stream and label referenced by |task_info| from the model
// context. On any invalid input, stream_/label_ stay null and Distribute()
// rejects the task later via CheckParamValid().
LabelGotoTask::LabelGotoTask(const ModelContext &model_context, const std::shared_ptr<LabelGotoTaskInfo> &task_info)
    : TaskRepeater<LabelGotoTaskInfo>(model_context, task_info), task_info_(task_info) {
  if (task_info_ == nullptr) {
    GELOGW("task_info_ is null!");
    return;
  }
  const auto streams = model_context.stream_list();
  const auto labels = model_context.label_list();
  const uint32_t stream_id = task_info->stream_id();
  const uint32_t label_id = task_info->label_id();
  GELOGI("Stream list size:%zu, stream id:%u.", streams.size(), stream_id);
  GELOGI("Label list size:%zu, label id:%u.", labels.size(), label_id);
  if (stream_id >= streams.size() || label_id >= labels.size()) {
    GELOGW("Stream/Label id invalid.");
    return;
  }
  stream_ = streams[stream_id];
  label_ = labels[label_id];
}

// Frees the device buffers allocated in Distribute(); the macro is a no-op
// for null pointers.
LabelGotoTask::~LabelGotoTask() {
  GE_FREE_RT_LOG(label_info_);
  GE_FREE_RT_LOG(index_value_);
}

// Distributes the label-goto: implemented as a one-entry label switch whose
// device-side index is fixed to 0, so execution always jumps to label_.
// NOTE(review): if an rt call fails mid-way, the already-allocated buffers
// are left set; CheckParamValid() then rejects a retry as "dirty data", and
// the destructor releases the memory.
bool LabelGotoTask::Distribute() {
  GELOGI("LabelGotoTask Distribute start.");
  if (!CheckParamValid()) {
    return false;
  }

  // Single-label table; the switch below always selects entry 0.
  const std::vector<void *> label_list = { label_ };
  rtError_t rt_ret = rtMalloc(&index_value_, sizeof(uint64_t), RT_MEMORY_HBM);
  if (rt_ret != RT_ERROR_NONE) {
    GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret);
    return false;
  }

  // Device-side branch index, fixed to 0 (the only table entry).
  uint64_t branch_index = 0;
  rt_ret = rtMemcpy(index_value_, sizeof(uint64_t), &branch_index, sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE);
  if (rt_ret != RT_ERROR_NONE) {
    GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret);
    return false;
  }

  uint32_t label_info_size = sizeof(rtLabelDevInfo) * label_list.size();
  rt_ret = rtMalloc(&label_info_, label_info_size, RT_MEMORY_HBM);
  if (rt_ret != RT_ERROR_NONE) {
    GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret);
    return false;
  }

  // Convert the host label table into the runtime's device representation.
  rt_ret = rtLabelListCpy((void**)label_list.data(), label_list.size(), label_info_, label_info_size);
  if (rt_ret != RT_ERROR_NONE) {
    GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret);
    return false;
  }

  rt_ret = rtLabelSwitchByIndex(index_value_, label_list.size(), label_info_, stream_);
  if (rt_ret != RT_ERROR_NONE) {
    GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret);
    return false;
  }

  GELOGI("DistributeTask end.");
  return true;
}

// Validates the preconditions for Distribute(): a resolved stream and label,
// and no leftover device buffers from a previous (failed) distribution.
bool LabelGotoTask::CheckParamValid() {
  if (stream_ == nullptr) {
    GELOGE(PARAM_INVALID, "stream is null!");
    return false;
  }
  if (label_ == nullptr) {
    GELOGE(PARAM_INVALID, "label is null!");
    return false;
  }
  // Non-null buffers mean a prior Distribute() ran (possibly failed) and was
  // never cleaned up; refuse to overwrite them.
  if (label_info_ != nullptr) {
    GELOGE(PARAM_INVALID, "label_info_ has dirty data.");
    return false;
  }
  if (index_value_ != nullptr) {
    GELOGE(PARAM_INVALID, "index_value_ has dirty data.");
    return false;
  }
  return true;
}

REGISTER_TASK(TaskInfoType::LABEL_GOTO, LabelGotoTask, LabelGotoTaskInfo);
} // namespace model_runner
} // namespace ge

+ 0
- 45
ge/ge_runtime/task/label_goto_task.h View File

@@ -1,45 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_
#define GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_

#include <memory>
#include "ge_runtime/task/task.h"

namespace ge {
namespace model_runner {
// Task that makes execution on a stream jump unconditionally to a label,
// implemented via a one-entry rtLabelSwitchByIndex table (see the .cc).
class LabelGotoTask : public TaskRepeater<LabelGotoTaskInfo> {
 public:
  LabelGotoTask(const ModelContext &model_context, const std::shared_ptr<LabelGotoTaskInfo> &task_info);

  ~LabelGotoTask() override;

  bool Distribute() override;

 private:
  // Rejects unresolved stream/label and leftover device buffers.
  bool CheckParamValid();

  std::shared_ptr<LabelGotoTaskInfo> task_info_;
  void *stream_{nullptr};       // resolved from the model context's stream list
  void *label_{nullptr};        // resolved from the model context's label list
  void *label_info_{nullptr};   // device buffer for the label table
  void *index_value_{nullptr};  // device buffer holding the branch index (0)
};
} // namespace model_runner
} // namespace ge

#endif // GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_

+ 0
- 70
ge/ge_runtime/task/label_set_task.cc View File

@@ -1,70 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ge_runtime/task/label_set_task.h"
#include "ge_runtime/task/task_factory.h"

namespace ge {
namespace model_runner {
// Resolves the stream and label referenced by |task_info| from the model
// context. Invalid ids leave stream_/label_ null; Distribute() then fails.
LabelSetTask::LabelSetTask(const ModelContext &model_context, const std::shared_ptr<LabelSetTaskInfo> &task_info)
    : TaskRepeater<LabelSetTaskInfo>(model_context, task_info),
      task_info_(task_info),
      stream_(nullptr),
      label_(nullptr) {
  if (task_info_ == nullptr) {
    GELOGW("task_info_ is null!");
    return;
  }
  const auto streams = model_context.stream_list();
  const auto labels = model_context.label_list();
  const uint32_t stream_id = task_info->stream_id();
  const uint32_t label_id = task_info->label_id();
  GELOGI("Stream list size:%zu, stream id:%u.", streams.size(), stream_id);
  GELOGI("Label list size:%zu, label id:%u.", labels.size(), label_id);
  if (stream_id >= streams.size() || label_id >= labels.size()) {
    GELOGW("Stream/Label id invalid.");
    return;
  }
  stream_ = streams[stream_id];
  label_ = labels[label_id];
}

// Nothing to release: stream and label are owned by the model context.
LabelSetTask::~LabelSetTask() {}

bool LabelSetTask::Distribute() {
GELOGI("LabelSetTask Distribute start.");
if (stream_ == nullptr) {
GELOGE(PARAM_INVALID, "stream is null!");
return false;
}
if (label_ == nullptr) {
GELOGE(PARAM_INVALID, "label is null!");
return false;
}
rtError_t rt_ret = rtLabelSet(label_, stream_);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return false;
}

GELOGI("DistributeTask end.");
return true;
}

REGISTER_TASK(TaskInfoType::LABEL_SET, LabelSetTask, LabelSetTaskInfo);

} // namespace model_runner
} // namespace ge

+ 0
- 41
ge/ge_runtime/task/label_set_task.h View File

@@ -1,41 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_
#define GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_

#include <memory>
#include "ge_runtime/task/task.h"

namespace ge {
namespace model_runner {
// Task that places a label (rtLabelSet) on a stream so that label-goto /
// label-switch tasks can jump to it.
class LabelSetTask : public TaskRepeater<LabelSetTaskInfo> {
 public:
  LabelSetTask(const ModelContext &model_context, const std::shared_ptr<LabelSetTaskInfo> &task_info);

  ~LabelSetTask() override;

  bool Distribute() override;

 private:
  std::shared_ptr<LabelSetTaskInfo> task_info_;
  void *stream_;  // resolved from the model context's stream list
  void *label_;   // resolved from the model context's label list
};
} // namespace model_runner
} // namespace ge

#endif // GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_

+ 0
- 131
ge/ge_runtime/task/label_switch_task.cc View File

@@ -1,131 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ge_runtime/task/label_switch_task.h"
#include "ge_runtime/task/task_factory.h"

namespace ge {
namespace model_runner {
// Captures the model's full label list and resolves this task's stream. An
// invalid stream id leaves stream_ null; CheckParamValid() rejects it later.
LabelSwitchTask::LabelSwitchTask(const ModelContext &model_context,
                                 const std::shared_ptr<LabelSwitchTaskInfo> &task_info)
    : TaskRepeater<LabelSwitchTaskInfo>(model_context, task_info),
      task_info_(task_info),
      stream_(nullptr),
      all_label_resource_(),
      label_info_(nullptr) {
  if (task_info_ == nullptr) {
    GELOGW("task_info_ is null!");
    return;
  }

  all_label_resource_ = model_context.label_list();
  const auto streams = model_context.stream_list();
  const uint32_t stream_id = task_info->stream_id();
  GELOGI("Stream list size:%zu, stream id:%u.", streams.size(), stream_id);
  if (stream_id < streams.size()) {
    stream_ = streams[stream_id];
  } else {
    GELOGW("Stream id invalid.");
  }
}

// Frees the device-side label table allocated in Distribute(), if any.
LabelSwitchTask::~LabelSwitchTask() {
  if (label_info_ != nullptr) {
    rtError_t rt_ret = rtFree(label_info_);
    if (rt_ret != RT_ERROR_NONE) {
      GELOGE(RT_FAILED, "rtFree fwkOpBuf failed! ret: 0x%X.", rt_ret);
    }
    label_info_ = nullptr;
  }
}

// Builds the device-side label table for this switch and issues
// rtLabelSwitchByIndex: at runtime the value at task_info_->cond() selects
// which label in the table execution jumps to.
// NOTE(review): on a mid-sequence rt failure, label_info_ stays allocated;
// CheckParamValid() then rejects a retry, and the destructor frees it.
bool LabelSwitchTask::Distribute() {
  GELOGI("LabelSwitchTask Distribute start.");
  if (!CheckParamValid()) {
    return false;
  }

  // Map each label index in the task to the actual label resource.
  const std::vector<uint32_t> &label_index_list = task_info_->label_list();
  std::vector<void *> label_list(task_info_->label_size(), nullptr);

  for (size_t i = 0; i < task_info_->label_size(); ++i) {
    uint32_t label_index = label_index_list[i];
    if (label_index >= all_label_resource_.size()) {
      GELOGE(PARAM_INVALID, "label %zu index is %u, but there are %zu labels in total.", i, label_index,
             all_label_resource_.size());
      return false;
    }
    label_list[i] = all_label_resource_[label_index];
    GELOGI("Case %zu: label id %zu.", i, (size_t)label_index);
  }

  // Overflow of this multiply is excluded by CheckParamValid().
  uint32_t label_info_size = sizeof(rtLabelDevInfo) * task_info_->label_size();
  rtError_t rt_ret = rtMalloc(&label_info_, label_info_size, RT_MEMORY_HBM);
  if (rt_ret != RT_ERROR_NONE) {
    GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
    return false;
  }

  // Convert the host label table into the runtime's device representation.
  rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info_, label_info_size);
  if (rt_ret != RT_ERROR_NONE) {
    GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
    return false;
  }

  rt_ret = rtLabelSwitchByIndex(task_info_->cond(), label_list.size(), label_info_, stream_);
  if (rt_ret != RT_ERROR_NONE) {
    GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
    return false;
  }

  GELOGI("DistributeTask end.");
  return true;
}

// Validates preconditions for Distribute(): a resolved stream, a coherent
// non-empty label list, no size-computation overflow, and no leftover device
// buffer from a previous distribution.
bool LabelSwitchTask::CheckParamValid() {
  if (stream_ == nullptr) {
    GELOGE(PARAM_INVALID, "stream is null!");
    return false;
  }
  if (task_info_->label_list().empty()) {
    GELOGE(PARAM_INVALID, "label_list is empty.");
    return false;
  }
  // The declared label count must match the index list actually supplied.
  if (task_info_->label_size() != task_info_->label_list().size()) {
    GELOGE(PARAM_INVALID, "label_list size %zu but label_size is %u.", task_info_->label_list().size(),
           task_info_->label_size());
    return false;
  }
  // Guard the 32-bit multiply in Distribute() against overflow.
  if (task_info_->label_size() >= UINT32_MAX / sizeof(rtLabelDevInfo)) {
    GELOGE(PARAM_INVALID, "label_size %u will overflow.", task_info_->label_size());
    return false;
  }
  if (label_info_ != nullptr) {
    GELOGE(PARAM_INVALID, "label_info_ has dirty data.");
    return false;
  }
  return true;
}

REGISTER_TASK(TaskInfoType::LABEL_SWITCH, LabelSwitchTask, LabelSwitchTaskInfo);

} // namespace model_runner
} // namespace ge

+ 0
- 44
ge/ge_runtime/task/label_switch_task.h View File

@@ -1,44 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_
#define GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_

#include <memory>
#include "ge_runtime/task/task.h"

namespace ge {
namespace model_runner {
// Task that issues an indexed multi-way branch (rtLabelSwitchByIndex): a
// device-side condition value selects which label execution jumps to.
class LabelSwitchTask : public TaskRepeater<LabelSwitchTaskInfo> {
 public:
  LabelSwitchTask(const ModelContext &model_context, const std::shared_ptr<LabelSwitchTaskInfo> &task_info);

  ~LabelSwitchTask() override;

  bool Distribute() override;

 private:
  // Rejects unresolved stream, inconsistent label list, overflow, dirty buffer.
  bool CheckParamValid();

  std::shared_ptr<LabelSwitchTaskInfo> task_info_;
  void *stream_;                             // resolved from the model context
  std::vector<void *> all_label_resource_;   // full label list of the model
  void *label_info_;                         // device buffer for the label table
};
} // namespace model_runner
} // namespace ge

#endif // GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_

+ 0
- 57
ge/ge_runtime/task/memcpy_async_task.cc View File

@@ -1,57 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ge_runtime/task/memcpy_async_task.h"
#include "ge_runtime/task/task_factory.h"

namespace ge {
namespace model_runner {
// Resolves the execution stream for this async memcpy from the model context.
// An invalid stream id leaves stream_ null and Distribute() will fail.
MemcpyAsyncTask::MemcpyAsyncTask(const ModelContext &model_context,
                                 const std::shared_ptr<MemcpyAsyncTaskInfo> &task_info)
    : TaskRepeater<MemcpyAsyncTaskInfo>(model_context, task_info), task_info_(task_info), stream_(nullptr) {
  if (task_info_ == nullptr) {
    GELOGW("task_info_ is null!");
    return;
  }
  const auto streams = model_context.stream_list();
  const uint32_t stream_id = task_info->stream_id();
  GELOGI("Stream list size:%zu, stream id:%u.", streams.size(), stream_id);
  if (stream_id < streams.size()) {
    stream_ = streams[stream_id];
  } else {
    GELOGW("Stream id invalid");
  }
}

// Nothing to release: the stream is owned by the model context.
MemcpyAsyncTask::~MemcpyAsyncTask() {}

bool MemcpyAsyncTask::Distribute() {
GELOGI("MemcpyAsyncTask Distribute start.");
GELOGI("dst_max:%lu, count:%lu, kind:%u.", task_info_->dst_max(), task_info_->count(), task_info_->kind());
rtError_t rt_ret = rtMemcpyAsync(task_info_->dst(), task_info_->dst_max(), task_info_->src(), task_info_->count(),
static_cast<rtMemcpyKind_t>(task_info_->kind()), stream_);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return false;
}
GELOGI("DistributeTask end");
return true;
}

REGISTER_TASK(TaskInfoType::MEMCPY_ASYNC, MemcpyAsyncTask, MemcpyAsyncTaskInfo);
} // namespace model_runner
} // namespace ge

+ 0
- 40
ge/ge_runtime/task/memcpy_async_task.h View File

@@ -1,40 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_GE_RUNTIME_TASK_MEMCPY_ASYNC_TASK_H_
#define GE_GE_RUNTIME_TASK_MEMCPY_ASYNC_TASK_H_

#include <memory>
#include "ge_runtime/task/task.h"

namespace ge {
namespace model_runner {
// Task that enqueues one asynchronous memory copy (rtMemcpyAsync) on a stream.
class MemcpyAsyncTask : public TaskRepeater<MemcpyAsyncTaskInfo> {
 public:
  MemcpyAsyncTask(const ModelContext &model_context, const std::shared_ptr<MemcpyAsyncTaskInfo> &task_info);

  ~MemcpyAsyncTask() override;

  bool Distribute() override;

 private:
  std::shared_ptr<MemcpyAsyncTaskInfo> task_info_;
  rtStream_t stream_;  // resolved from the model context's stream list
};
} // namespace model_runner
} // namespace ge

#endif // GE_GE_RUNTIME_TASK_MEMCPY_ASYNC_TASK_H_

+ 0
- 55
ge/ge_runtime/task/profiler_task.cc View File

@@ -1,55 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ge_runtime/task/profiler_task.h"
#include "ge_runtime/task/task_factory.h"

namespace ge {
namespace model_runner {
// Resolves the execution stream for this profiler-trace task from the model
// context. An invalid stream id leaves stream_ null.
ProfilerTask::ProfilerTask(const ModelContext &model_context, const std::shared_ptr<ProfilerTraceTaskInfo> &task_info)
    : TaskRepeater<ProfilerTraceTaskInfo>(model_context, task_info), task_info_(task_info), stream_(nullptr) {
  if (task_info_ == nullptr) {
    GELOGW("task_info_ is null!");
    return;
  }
  const auto streams = model_context.stream_list();
  const uint32_t stream_id = task_info->stream_id();
  GELOGI("Stream list size:%zu, stream id:%u.", streams.size(), stream_id);
  if (stream_id < streams.size()) {
    stream_ = streams[stream_id];
  } else {
    GELOGW("Stream id invalid");
  }
}

// Nothing to release: the stream is owned by the model context.
ProfilerTask::~ProfilerTask() {}

bool ProfilerTask::Distribute() {
GELOGI("ProfilerTask Distribute start.");
GELOGI("logid = %lu, notify = %d, flat = %u.", task_info_->log_id(), task_info_->notify(), task_info_->flat());
rtError_t rt_ret = rtProfilerTrace(task_info_->log_id(), task_info_->notify(), task_info_->flat(), stream_);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return false;
}
GELOGI("DistributeTask end");
return true;
}

REGISTER_TASK(TaskInfoType::PROFILER_TRACE, ProfilerTask, ProfilerTraceTaskInfo);

} // namespace model_runner
} // namespace ge

+ 0
- 40
ge/ge_runtime/task/profiler_task.h View File

@@ -1,40 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_GE_RUNTIME_TASK_PROFILER_TASK_H_
#define GE_GE_RUNTIME_TASK_PROFILER_TASK_H_

#include <memory>
#include "ge_runtime/task/task.h"

namespace ge {
namespace model_runner {
// Task that inserts a profiler trace point (rtProfilerTrace) on a stream.
class ProfilerTask : public TaskRepeater<ProfilerTraceTaskInfo> {
 public:
  ProfilerTask(const ModelContext &model_context, const std::shared_ptr<ProfilerTraceTaskInfo> &task_info);

  ~ProfilerTask() override;

  bool Distribute() override;

 private:
  std::shared_ptr<ProfilerTraceTaskInfo> task_info_;
  rtStream_t stream_;  // resolved from the model context's stream list
};
} // namespace model_runner
} // namespace ge

#endif // GE_GE_RUNTIME_TASK_PROFILER_TASK_H_

+ 0
- 60
ge/ge_runtime/task/stream_active_task.cc View File

@@ -1,60 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ge_runtime/task/stream_active_task.h"
#include "ge_runtime/task/task_factory.h"

namespace ge {
namespace model_runner {
// Resolves both the issuing stream and the stream to activate from the model
// context. Either id out of range leaves both members null.
StreamActiveTask::StreamActiveTask(const ModelContext &model_context,
                                   const std::shared_ptr<StreamActiveTaskInfo> &task_info)
    : TaskRepeater<StreamActiveTaskInfo>(model_context, task_info),
      task_info_(task_info),
      stream_(nullptr),
      active_stream_(nullptr) {
  if (task_info_ == nullptr) {
    GELOGW("task_info_ is null!");
    return;
  }
  const auto streams = model_context.stream_list();
  const uint32_t stream_id = task_info->stream_id();
  const uint32_t active_stream_id = task_info->active_stream_id();
  GELOGI("Stream list size:%zu, stream id:%u, active stream id:%u", streams.size(), stream_id, active_stream_id);
  if (stream_id >= streams.size() || active_stream_id >= streams.size()) {
    GELOGW("Stream id invalid");
    return;
  }
  stream_ = streams[stream_id];
  active_stream_ = streams[active_stream_id];
}

// Nothing to release: streams are owned by the model context.
StreamActiveTask::~StreamActiveTask() {}

bool StreamActiveTask::Distribute() {
GELOGI("Distribute start");
GELOGI("Stream %u active %u.", task_info_->stream_id(), task_info_->active_stream_id());
rtError_t rt_ret = rtStreamActive(active_stream_, stream_);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return false;
}
GELOGI("DistributeTask end");
return true;
}

REGISTER_TASK(TaskInfoType::STREAM_ACTIVE, StreamActiveTask, StreamActiveTaskInfo);
} // namespace model_runner
} // namespace ge

+ 0
- 41
ge/ge_runtime/task/stream_active_task.h View File

@@ -1,41 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_GE_RUNTIME_TASK_STREAM_ACTIVE_TASK_H_
#define GE_GE_RUNTIME_TASK_STREAM_ACTIVE_TASK_H_

#include <memory>
#include "ge_runtime/task/task.h"

namespace ge {
namespace model_runner {
// Task that activates another stream (rtStreamActive) from its own stream.
class StreamActiveTask : public TaskRepeater<StreamActiveTaskInfo> {
 public:
  StreamActiveTask(const ModelContext &model_context, const std::shared_ptr<StreamActiveTaskInfo> &task_info);

  ~StreamActiveTask() override;

  bool Distribute() override;

 private:
  std::shared_ptr<StreamActiveTaskInfo> task_info_;
  rtStream_t stream_;         // stream that issues the activation
  rtStream_t active_stream_;  // stream being activated
};
} // namespace model_runner
} // namespace ge

#endif // GE_GE_RUNTIME_TASK_STREAM_ACTIVE_TASK_H_

+ 0
- 82
ge/ge_runtime/task/stream_switch_task.cc View File

@@ -1,82 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ge_runtime/task/stream_switch_task.h"
#include "ge_runtime/task/task_factory.h"

namespace ge {
namespace model_runner {
// Resolves the submission stream from the model context. With a single
// stream the task always runs on stream 0; otherwise the task's stream id
// indexes into the model's stream list.
StreamSwitchTask::StreamSwitchTask(const ModelContext &model_context,
                                   const std::shared_ptr<StreamSwitchTaskInfo> &task_info)
    : TaskRepeater<StreamSwitchTaskInfo>(model_context, task_info),
      task_info_(task_info),
      stream_(nullptr),
      stream_list_() {
  // Nothing to resolve without task info; Distribute() will report the error.
  if (task_info_ == nullptr) {
    GELOGW("task_info_ is null!");
    return;
  }

  stream_list_ = model_context.stream_list();
  const size_t stream_count = stream_list_.size();
  if (stream_count == 1) {
    stream_ = stream_list_[0];
    return;
  }
  if (task_info->stream_id() < stream_count) {
    stream_ = stream_list_[task_info->stream_id()];
  } else {
    GELOGW("Index: %u >= stream_list.size(): %zu.", task_info->stream_id(), stream_count);
  }
}

StreamSwitchTask::~StreamSwitchTask() {}

// Distributes the stream-switch: tells the runtime to jump to true_stream
// when the value at input_addr compares (cond) against the value at
// value_addr. Returns false on any validation or runtime failure.
bool StreamSwitchTask::Distribute() {
  GELOGI("Init StreamSwitchTask start.");
  // BUG FIX: the constructor tolerates a null task_info_ (it only logs a
  // warning), so guard here before the first dereference below.
  if (task_info_ == nullptr) {
    GELOGE(PARAM_INVALID, "task_info_ is null!");
    return false;
  }
  GELOGI("Stream %u active %ld.", task_info_->stream_id(), task_info_->true_stream_id());

  if (stream_ == nullptr) {
    GELOGE(PARAM_INVALID, "stream_ is null!");
    return false;
  }

  if (static_cast<uint64_t>(task_info_->true_stream_id()) >= stream_list_.size()) {
    GELOGE(PARAM_INVALID, "true_stream_id %ld must be less than stream_list_ size %zu!", task_info_->true_stream_id(),
           stream_list_.size());
    return false;
  }

  void *input = reinterpret_cast<void *>(task_info_->input_addr());
  rtCondition_t cond = static_cast<rtCondition_t>(task_info_->cond());
  void *value = reinterpret_cast<void *>(task_info_->value_addr());
  rtStream_t true_stream = stream_list_[task_info_->true_stream_id()];
  rtSwitchDataType_t data_type = static_cast<rtSwitchDataType_t>(task_info_->data_type());

  GELOGI("InitStreamSwitchTask, cond:%d, trueStream:%p, trueStreamID:%ld, datatype:%ld.", cond, true_stream,
         task_info_->true_stream_id(), task_info_->data_type());

  GELOGI("StreamSwitchTask Distribute Start.");
  rtError_t rt_ret = rtStreamSwitchEx(input, cond, value, true_stream, stream_, data_type);
  if (rt_ret != RT_ERROR_NONE) {
    GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
    return false;
  }

  GELOGI("Distribute StreamSwitch, cond:%d, trueStream:%p, datatype:%ld.", cond, true_stream, task_info_->data_type());
  return true;
}

REGISTER_TASK(TaskInfoType::STREAM_SWITCH, StreamSwitchTask, StreamSwitchTaskInfo);
} // namespace model_runner
} // namespace ge

+ 0
- 43
ge/ge_runtime/task/stream_switch_task.h View File

@@ -1,43 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_GE_RUNTIME_TASK_STREAM_SWITCH_TASK_H_
#define GE_GE_RUNTIME_TASK_STREAM_SWITCH_TASK_H_

#include <memory>
#include <vector>
#include "ge_runtime/task/task.h"

namespace ge {
namespace model_runner {
// Task that conditionally switches execution to another stream at runtime.
class StreamSwitchTask : public TaskRepeater<StreamSwitchTaskInfo> {
 public:
  StreamSwitchTask(const ModelContext &model_context, const std::shared_ptr<StreamSwitchTaskInfo> &task_info);

  ~StreamSwitchTask() override;

  // Issues the runtime stream-switch call; returns false on failure.
  bool Distribute() override;

 private:
  std::shared_ptr<StreamSwitchTaskInfo> task_info_;

  // CONSISTENCY FIX: was `void *`; use rtStream_t (the same alias already
  // used by stream_list_ below and by the sibling task headers).
  rtStream_t stream_;
  std::vector<rtStream_t> stream_list_;
};

} // namespace model_runner
} // namespace ge
#endif // GE_GE_RUNTIME_TASK_STREAM_SWITCH_TASK_H_

+ 0
- 58
ge/ge_runtime/task/task.h View File

@@ -1,58 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_GE_RUNTIME_TASK_TASK_H_
#define GE_GE_RUNTIME_TASK_TASK_H_

#include <memory>
#include <utility>
#include <vector>
#include <string>
#include "runtime/rt_model.h"
#include "ge_runtime/model_context.h"
#include "ge_runtime/task_info.h"
#include "external/runtime/rt_error_codes.h"

namespace ge {
namespace model_runner {
// Abstract unit of work that can be distributed to the device runtime.
class Task {
public:
Task() {}

virtual ~Task() {}

// Submits the task to the runtime; returns false on failure.
virtual bool Distribute() = 0;

// Device-side argument buffer, when the concrete task owns one.
virtual void *Args() { return nullptr; }

// Name of the originating op; empty when not applicable.
virtual std::string task_name() const { return ""; }
};

// Intermediate base binding a Task to its concrete TaskInfo type T.
template <class T>
class TaskRepeater : public Task {
static_assert(std::is_base_of<TaskInfo, T>(), "Wrong TaskInfo Type!");

public:
// Parameters are intentionally unused here; subclasses extract what they
// need from the model context and task info themselves.
TaskRepeater(const ModelContext &model_context, std::shared_ptr<T> task_info) {}

virtual ~TaskRepeater() {}

virtual bool Distribute() = 0;
};
} // namespace model_runner
} // namespace ge

#endif // GE_GE_RUNTIME_TASK_TASK_H_

+ 0
- 87
ge/ge_runtime/task/task_factory.h View File

@@ -1,87 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_GE_RUNTIME_TASK_TASK_FACTORY_H_
#define GE_GE_RUNTIME_TASK_TASK_FACTORY_H_

#include <functional>
#include <map>
#include <memory>
#include <unordered_map>
#include "common/ge_inner_error_codes.h"
#include "framework/common/debug/ge_log.h"
#include "ge_runtime/task_info.h"

namespace ge {
namespace model_runner {
class Task;
class ModelContext;
using TASK_CREATOR_FUN = std::function<std::shared_ptr<Task>(const ModelContext &, std::shared_ptr<TaskInfo>)>;

// Singleton factory mapping TaskInfoType -> creator for the matching Task
// subclass. Creators are installed at static-initialization time through the
// REGISTER_TASK macro (nested Register helper below).
class TaskFactory {
 private:
  TaskFactory() {}
  ~TaskFactory() {}

  // Installs |func| as the creator for |type|. A duplicate registration is
  // logged and then overwrites the previous creator (last one wins).
  void RegisterCreator(const TaskInfoType &type, const TASK_CREATOR_FUN &func) {
    if (creator_map_.find(type) != creator_map_.end()) {
      GELOGW("Creator type %d already exist", static_cast<int32_t>(type));
    }
    creator_map_[type] = func;
  }

  std::map<TaskInfoType, TASK_CREATOR_FUN> creator_map_;

 public:
  static TaskFactory &GetInstance() {
    static TaskFactory instance;
    return instance;
  }

  // Creates the Task matching |task_info|'s type; returns nullptr when
  // task_info is null or no creator is registered for its type.
  std::shared_ptr<Task> Create(const ModelContext &model_context, std::shared_ptr<TaskInfo> &task_info) const {
    if (task_info == nullptr) {
      GELOGE(FAILED, "task_info is null.");
      return nullptr;
    }

    auto iter = creator_map_.find(task_info->type());
    if (iter == creator_map_.end()) {
      // Log typo fixed: "Unknow" -> "Unknown".
      GELOGE(FAILED, "Unknown task type %d", static_cast<int32_t>(task_info->type()));
      return nullptr;
    }
    return iter->second(model_context, task_info);
  }

  // RAII helper: constructing one registers a creator with the singleton.
  class Register {
   public:
    Register(const TaskInfoType &type, const TASK_CREATOR_FUN &func) {
      // Log typo fixed: "regist" -> "register".
      GELOGI("register type %d", static_cast<int32_t>(type));
      TaskFactory::GetInstance().RegisterCreator(type, func);
    }

    ~Register() {}
  };
};

// Registers task_clazz as the creator for the given TaskInfoType: defines a
// file-scope TaskFactory::Register object whose constructor installs a lambda
// that downcasts the generic TaskInfo to the concrete task_info_clazz and
// constructs the task.
#define REGISTER_TASK(type, task_clazz, task_info_clazz) \
TaskFactory::Register g_##task_clazz##_register( \
type, [](const ModelContext &model_context, const std::shared_ptr<TaskInfo> &task_info) -> std::shared_ptr<Task> { \
std::shared_ptr<task_info_clazz> concrete_task_info = std::static_pointer_cast<task_info_clazz>(task_info); \
return std::make_shared<task_clazz>(model_context, concrete_task_info); \
});

} // namespace model_runner
} // namespace ge
#endif // GE_GE_RUNTIME_TASK_TASK_FACTORY_H_

+ 0
- 112
ge/ge_runtime/task/tbe_task.cc View File

@@ -1,112 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ge_runtime/task/tbe_task.h"
#include <vector>
#include "ge_runtime/task/task_factory.h"

namespace ge {
namespace model_runner {
// Binds the TBE task to its submission stream. A single-stream model always
// uses stream 0; otherwise the task's stream id indexes the stream list.
TbeTask::TbeTask(const ModelContext &model_context, const std::shared_ptr<TbeTaskInfo> &task_info)
    : TaskRepeater<TbeTaskInfo>(model_context, task_info),
      task_info_(task_info),
      stream_(nullptr),
      stub_func_(nullptr),
      args_(nullptr) {
  if (task_info_ == nullptr) {
    GELOGW("task_info_ is null!");
    return;
  }

  auto streams = model_context.stream_list();
  const size_t stream_count = streams.size();
  if (stream_count == 1) {
    stream_ = streams[0];
  } else if (task_info->stream_id() < stream_count) {
    stream_ = streams[task_info->stream_id()];
  } else {
    GELOGE(PARAM_INVALID, "Index: %u >= stream_list.size(): %zu.", task_info->stream_id(), stream_count);
  }
}

// Releases the device-side argument buffer allocated in Distribute().
TbeTask::~TbeTask() {
  if (args_ == nullptr) {
    return;
  }
  rtError_t rt_ret = rtFree(args_);
  if (rt_ret != RT_ERROR_NONE) {
    GELOGE(RT_FAILED, "rtFree fwkOpBuf failed! ret: 0x%X.", rt_ret);
  }
  args_ = nullptr;
}

// Launches the TBE kernel: resolves the stub function by name, packs all
// input/output/workspace addresses into a contiguous device buffer, then
// calls rtKernelLaunchWithFlag. Returns false on any failure. The device
// buffer (args_) is owned by this task and freed in the destructor.
bool TbeTask::Distribute() {
GELOGI("InitTbeTask start.");
if (stream_ == nullptr) {
GELOGE(PARAM_INVALID, "stream_ is null!");
return false;
}
// Get stub_func
if (task_info_->stub_func().empty()) {
GELOGE(PARAM_INVALID, "kernel_info->stub_func is empty!");
return false;
}

rtError_t rt_ret = rtGetFunctionByName(const_cast<char *>(task_info_->stub_func().c_str()), &stub_func_);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "rtGetFunctionByName failed, ret: %d", static_cast<int32_t>(rt_ret));
stub_func_ = nullptr;
return false;
}
GELOGI("TbeTask: stub_func = %s [%p].", task_info_->stub_func().c_str(), stub_func_);

// Get args
// Argument layout expected by the kernel: inputs, then outputs, then
// workspaces, each entry one device pointer.
std::vector<void *> tensor_device_addrs;
tensor_device_addrs.insert(tensor_device_addrs.end(), task_info_->input_data_addrs().begin(),
task_info_->input_data_addrs().end());
tensor_device_addrs.insert(tensor_device_addrs.end(), task_info_->output_data_addrs().begin(),
task_info_->output_data_addrs().end());
tensor_device_addrs.insert(tensor_device_addrs.end(), task_info_->workspace_addrs().begin(),
task_info_->workspace_addrs().end());
auto args_size = static_cast<uint32_t>(tensor_device_addrs.size() * sizeof(void *));

rt_ret = rtMalloc(&args_, args_size, RT_MEMORY_HBM);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "rtMalloc failed, ret: %d", static_cast<int32_t>(rt_ret));
return false;
}
GE_PRINT_DYNAMIC_MEMORY(rtMalloc, "task args data.", args_size)

// Copy the host-side pointer array into the freshly allocated device buffer.
rt_ret = rtMemcpy(args_, args_size, reinterpret_cast<void *>(tensor_device_addrs.data()), args_size,
RT_MEMCPY_HOST_TO_DEVICE);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "rtMemcpy fail, ret 0x%X.", rt_ret);
// args_ is intentionally not freed here; the destructor releases it.
return false;
}

GELOGI("DistributeTbeTask start.");
auto dump_flag = task_info_->dump_flag() ? RT_KERNEL_DUMPFLAG : RT_KERNEL_DEFAULT;
rt_ret = rtKernelLaunchWithFlag(stub_func_, task_info_->block_dim(), args_, args_size, nullptr, stream_, dump_flag);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api rtKernelLaunch failed, ret: 0x%X", rt_ret);
return false;
}
GELOGI("[DataDump] task name:%s, dump_flag:%d", task_info_->op_name().c_str(), dump_flag);
return true;
}

REGISTER_TASK(TaskInfoType::TBE, TbeTask, TbeTaskInfo);
} // namespace model_runner
} // namespace ge

+ 0
- 46
ge/ge_runtime/task/tbe_task.h View File

@@ -1,46 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_GE_RUNTIME_TASK_TBE_TASK_H_
#define GE_GE_RUNTIME_TASK_TBE_TASK_H_

#include <memory>
#include "ge_runtime/task/task.h"

namespace ge {
namespace model_runner {
// Task wrapping a TBE kernel launch (see tbe_task.cc for the launch logic).
class TbeTask : public TaskRepeater<TbeTaskInfo> {
public:
TbeTask(const ModelContext &model_context, const std::shared_ptr<TbeTaskInfo> &task_info);

~TbeTask() override;

bool Distribute() override;

// Device buffer holding the packed kernel arguments; owned, freed in dtor.
void *Args() override { return args_; }

// NOTE(review): dereferences task_info_ without a null check, although the
// constructor tolerates a null task_info — confirm callers never hit this.
std::string task_name() const override { return task_info_->op_name(); }

private:
std::shared_ptr<TbeTaskInfo> task_info_;
void *stream_;
void *stub_func_;
void *args_;
};
} // namespace model_runner
} // namespace ge

#endif // GE_GE_RUNTIME_TASK_TBE_TASK_H_

+ 0
- 113
inc/framework/ge_runtime/davinci_model.h View File

@@ -1,113 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_FRAMEWORK_GE_RUNTIME_DAVINCI_MODEL_H_
#define INC_FRAMEWORK_GE_RUNTIME_DAVINCI_MODEL_H_

#include <memory>
#include <vector>

#include "ge_runtime/op_info.h"
#include "ge_runtime/task_info.h"

namespace ge {
namespace model_runner {
// Immutable description of a built model: its tasks, op-info lists, memory
// sizes/bases and stream/event configuration. Pure data holder; it performs
// no runtime calls. Non-copyable.
class DavinciModel {
 public:
  DavinciModel(const std::vector<std::shared_ptr<TaskInfo>> &task_info_list,
               const std::vector<std::shared_ptr<OpInfo>> &data_info_list,
               const std::vector<std::shared_ptr<OpInfo>> &output_info_list,
               const std::vector<std::shared_ptr<OpInfo>> &constant_info_list,
               const std::vector<model_runner::OpInfoPtr> &variable_info_list,
               const std::vector<uint32_t> &wait_active_stream_list,
               const std::vector<uint32_t> &force_copy_stream_list, uint64_t mem_size = 0, uint64_t weight_size = 0,
               uint64_t var_size = 0, uintptr_t logic_mem_base = 0, uintptr_t logic_weight_base = 0,
               uintptr_t logic_var_base = 0, uint32_t stream_num = 0, uint32_t batch_num = 0, uint32_t event_num = 0,
               int32_t priority = 0)
      : task_info_list_(task_info_list),
        data_info_list_(data_info_list),
        output_info_list_(output_info_list),
        constant_info_list_(constant_info_list),
        variable_info_list_(variable_info_list),
        wait_active_stream_list_(wait_active_stream_list),
        force_copy_stream_list_(force_copy_stream_list),
        mem_size_(mem_size),
        weight_size_(weight_size),
        var_size_(var_size),
        logic_mem_base_(logic_mem_base),
        logic_weight_base_(logic_weight_base),
        logic_var_base_(logic_var_base),
        stream_num_(stream_num),
        batch_num_(batch_num),
        event_num_(event_num),
        priority_(priority) {}
  ~DavinciModel() {}

  // Sizes (bytes) of the feature-map, weight and variable memory regions.
  uint64_t GetMemSize() const { return mem_size_; }
  uint64_t GetWeightSize() const { return weight_size_; }
  uint64_t GetVarSize() const { return var_size_; }

  // Logical base addresses recorded at build time.
  uintptr_t GetLogicMemBase() const { return logic_mem_base_; }
  uintptr_t GetLogicWeightBase() const { return logic_weight_base_; }
  uintptr_t GetLogicVarBase() const { return logic_var_base_; }

  uint32_t GetStreamNum() const { return stream_num_; }
  uint32_t GetBatchNum() const { return batch_num_; }
  uint32_t GetEventNum() const { return event_num_; }

  const std::vector<uint32_t> &GetWaitActiveStreams() const { return wait_active_stream_list_; }
  const std::vector<uint32_t> &GetForceCopyStreams() const { return force_copy_stream_list_; }

  int32_t GetPriority() const { return priority_; }

  const std::vector<std::shared_ptr<TaskInfo>> &GetTaskInfoList() const { return task_info_list_; }
  const std::vector<std::shared_ptr<OpInfo>> &GetDataInfoList() const { return data_info_list_; }
  const std::vector<std::shared_ptr<OpInfo>> &GetOutputInfoList() const { return output_info_list_; }
  // BUG FIX: previously returned output_info_list_ (copy-paste error); now
  // returns the constant list the constructor actually stored.
  const std::vector<std::shared_ptr<OpInfo>> &GetConstantInfoList() const { return constant_info_list_; }
  const std::vector<model_runner::OpInfoPtr> &GetVariableInfoList() const { return variable_info_list_; }

 private:
  std::vector<std::shared_ptr<TaskInfo>> task_info_list_;
  std::vector<std::shared_ptr<OpInfo>> data_info_list_;
  std::vector<std::shared_ptr<OpInfo>> output_info_list_;
  std::vector<std::shared_ptr<OpInfo>> constant_info_list_;
  std::vector<model_runner::OpInfoPtr> variable_info_list_;

  std::vector<uint32_t> wait_active_stream_list_;
  std::vector<uint32_t> force_copy_stream_list_;

  uint64_t mem_size_;
  uint64_t weight_size_;
  uint64_t var_size_;

  uintptr_t logic_mem_base_;
  uintptr_t logic_weight_base_;
  uintptr_t logic_var_base_;

  uint32_t stream_num_;
  uint32_t batch_num_;
  uint32_t event_num_;

  int32_t priority_;

  // Disable copy constructor and assignment operator.
  DavinciModel &operator=(const DavinciModel &) = delete;
  DavinciModel(const DavinciModel &) = delete;
};
} // namespace model_runner
} // namespace ge

#endif // INC_FRAMEWORK_GE_RUNTIME_DAVINCI_MODEL_H_

+ 0
- 68
inc/framework/ge_runtime/model_runner.h View File

@@ -1,68 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_FRAMEWORK_GE_RUNTIME_MODEL_RUNNER_H_
#define INC_FRAMEWORK_GE_RUNTIME_MODEL_RUNNER_H_

#include <memory>
#include <unordered_map>
#include <vector>

#include "common/ge_inner_error_codes.h"
#include "common/ge_types.h"
#include "ge_runtime/davinci_model.h"

namespace ge {
namespace model_runner {
class RuntimeModel;
using RuntimeInfo = std::tuple<uint32_t, uint32_t, void *>;
// Singleton facade over RuntimeModel instances, keyed by model id: load,
// distribute, run, query and unload Davinci models.
class ModelRunner {
public:
static ModelRunner &Instance();

// Creates and loads a RuntimeModel for |davinci_model| under |model_id|.
bool LoadDavinciModel(uint32_t device_id, uint64_t session_id, uint32_t model_id,
std::shared_ptr<DavinciModel> davinci_model, std::shared_ptr<ModelListener> listener);

// Distributes all tasks of the model to the device runtime.
bool DistributeTask(uint32_t model_id);

bool LoadModelComplete(uint32_t model_id);

// Task/stream ids assigned by the runtime during load.
const std::vector<uint32_t> &GetTaskIdList(uint32_t model_id) const;

const std::vector<uint32_t> &GetStreamIdList(uint32_t model_id) const;

const std::map<std::string, std::shared_ptr<RuntimeInfo>> &GetRuntimeInfoMap(uint32_t model_id) const;

void *GetModelHandle(uint32_t model_id) const;

bool UnloadModel(uint32_t model_id);

bool RunModel(uint32_t model_id, const InputData &input_data, OutputData *output_data);

bool GetInputOutputDescInfo(uint32_t model_id, bool zero_copy, std::vector<InputOutputDescInfo> *input_desc,
std::vector<InputOutputDescInfo> *output_desc, std::vector<uint32_t> *input_format,
std::vector<uint32_t> *output_format);

private:
ModelRunner() = default;
~ModelRunner() = default;

// model_id -> loaded runtime model.
std::unordered_map<uint32_t, std::shared_ptr<RuntimeModel>> runtime_models_;
};
} // namespace model_runner
} // namespace ge

#endif // INC_FRAMEWORK_GE_RUNTIME_MODEL_RUNNER_H_

+ 0
- 72
inc/framework/ge_runtime/op_info.h View File

@@ -1,72 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_FRAMEWORK_GE_RUNTIME_OP_INFO_H_
#define INC_FRAMEWORK_GE_RUNTIME_OP_INFO_H_

#include <memory>
#include <string>
#include <vector>

namespace ge {
namespace model_runner {
// Shape and metadata of one tensor.
struct TensorInfo {
  // Number of elements implied by dims (product of all dimensions).
  // An empty shape yields 0, not 1.
  int64_t GetShapeSize() const {
    if (dims.empty()) {
      return 0;
    }
    int64_t total = 1;
    for (size_t i = 0; i < dims.size(); ++i) {
      total *= dims[i];
    }
    return total;
  }

  // Dimension at |index|, or 0 when index is out of range.
  int64_t GetDim(uint32_t index) { return index < dims.size() ? dims[index] : 0; }

  std::vector<int64_t> dims;
  uint32_t datatype;
  uint32_t format;
  uint32_t real_dim_cnt;
  uint32_t size;
  bool is_output;
};

// Metadata for one op in the model: identity, I/O addresses and tensor
// descriptions. Plain data; populated by the model builder.
struct OpInfo {
uint32_t index;
std::string name;
std::string type;
bool var_is_broadcast;
std::vector<uintptr_t> input_addrs;
std::vector<uintptr_t> output_addrs;
std::vector<TensorInfo> input_tensors;
std::vector<TensorInfo> output_tensors;
std::vector<TensorInfo> weight_tensors;
std::vector<std::string> src_name;
std::vector<int64_t> src_index;
std::string weight_data;  // raw weight bytes, presumably — confirm with producer
};

using TensorInfoPtr = std::shared_ptr<TensorInfo>;
using OpInfoPtr = std::shared_ptr<OpInfo>;
} // namespace model_runner
} // namespace ge
#endif // INC_FRAMEWORK_GE_RUNTIME_OP_INFO_H_

+ 0
- 405
inc/framework/ge_runtime/task_info.h View File

@@ -1,405 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_FRAMEWORK_GE_RUNTIME_TASK_INFO_H_
#define INC_FRAMEWORK_GE_RUNTIME_TASK_INFO_H_

#include <stdint.h>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "cce/taskdown_api.h"

namespace ge {
namespace model_runner {
// Kinds of tasks a generated model can contain. Values are stable ids used
// by the task factory registry.
enum TaskInfoType {
CCE = 0,
TBE,
AICPU,
LABEL_SET,
LABEL_SWITCH,
LABEL_GOTO,
EVENT_RECORD,
EVENT_WAIT,
FUSION_START,
FUSION_END,
HCCL,
PROFILER_TRACE,
MEMCPY_ASYNC,
STREAM_SWITCH,
STREAM_ACTIVE,
// Insert new task type here
REVSERVED = 23  // (sic) spelling kept: the identifier is part of the public interface
};

// Base class for all task descriptions: identifies the op, the stream the
// task runs on, its kind, and whether data dump is enabled for it.
class TaskInfo {
public:
virtual ~TaskInfo() {}
uint32_t stream_id() const { return stream_id_; }
TaskInfoType type() const { return type_; }
std::string op_name() const { return op_name_; }
bool dump_flag() const { return dump_flag_; }

protected:
// Constructible only through concrete subclasses.
TaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, bool dump_flag)
: op_name_(op_name), stream_id_(stream_id), type_(type), dump_flag_(dump_flag) {}

private:
std::string op_name_;
uint32_t stream_id_;
TaskInfoType type_;
bool dump_flag_;
};

// Task info for a CCE kernel: op context, stub function name, launch
// configuration and serialized argument/flow-table blobs.
class CceTaskInfo : public TaskInfo {
public:
CceTaskInfo(const std::string &op_name, uint32_t stream_id, const cce::ccOpContext &ctx, const std::string &stub_func,
uint32_t block_dim, const std::vector<uint8_t> &args, uint32_t args_size,
const std::vector<uint8_t> &sm_desc, const std::vector<uint8_t> &flow_table,
const std::vector<uint8_t> &args_offset, bool is_flowtable)
: TaskInfo(op_name, stream_id, TaskInfoType::CCE, false),
ctx_(ctx),
stub_func_(stub_func),
block_dim_(block_dim),
args_(args),
args_size_(args_size),
sm_desc_(sm_desc),
flow_table_(flow_table),
args_offset_(args_offset),
is_flowtable_(is_flowtable) {}
~CceTaskInfo() override {}

// NOTE: returns the context by value (copy), unlike the by-reference blob
// accessors below.
cce::ccOpContext cc_context() const { return ctx_; }
std::string stub_func() const { return stub_func_; }
uint32_t block_dim() const { return block_dim_; }
const std::vector<uint8_t> &args() const { return args_; }
uint32_t args_size() const { return args_size_; }
const std::vector<uint8_t> &sm_desc() const { return sm_desc_; }
const std::vector<uint8_t> &flow_table() const { return flow_table_; }
const std::vector<uint8_t> &args_offset() const { return args_offset_; }
bool is_flowtable() const { return is_flowtable_; }

private:
cce::ccOpContext ctx_;
std::string stub_func_;
uint32_t block_dim_;
std::vector<uint8_t> args_;
uint32_t args_size_;
std::vector<uint8_t> sm_desc_;
std::vector<uint8_t> flow_table_;
std::vector<uint8_t> args_offset_;
bool is_flowtable_;
};

// Task info for a TBE kernel: stub function, launch configuration, compiled
// binary and the device addresses of inputs/outputs/workspaces.
class TbeTaskInfo : public TaskInfo {
public:
TbeTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string &stub_func, uint32_t block_dim,
const std::vector<uint8_t> &args, uint32_t args_size, const std::vector<uint8_t> &sm_desc, void *binary,
uint32_t binary_size, const std::vector<uint8_t> &meta_data, const std::vector<void *> &input_data_addrs,
const std::vector<void *> &output_data_addrs, const std::vector<void *> &workspace_addrs, bool dump_flag)
: TaskInfo(op_name, stream_id, TaskInfoType::TBE, dump_flag),
stub_func_(stub_func),
block_dim_(block_dim),
args_(args),
args_size_(args_size),
sm_desc_(sm_desc),
binary_(binary),
binary_size_(binary_size),
meta_data_(meta_data),
input_data_addrs_(input_data_addrs),
output_data_addrs_(output_data_addrs),
workspace_addrs_(workspace_addrs) {}
~TbeTaskInfo() override {}

const std::string &stub_func() const { return stub_func_; }
uint32_t block_dim() const { return block_dim_; }
const std::vector<uint8_t> &args() const { return args_; }
uint32_t args_size() const { return args_size_; }
const std::vector<uint8_t> &sm_desc() const { return sm_desc_; }
void *binary() const { return binary_; }
uint32_t binary_size() const { return binary_size_; }
const std::vector<uint8_t> &meta_data() const { return meta_data_; }
const std::vector<void *> &input_data_addrs() const { return input_data_addrs_; }
const std::vector<void *> &output_data_addrs() const { return output_data_addrs_; }
const std::vector<void *> &workspace_addrs() const { return workspace_addrs_; }

// The only mutator: lets the loader attach the compiled binary after
// construction.
void SetBinary(void *binary, uint32_t binary_size) {
binary_ = binary;
binary_size_ = binary_size;
}

private:
std::string stub_func_;
uint32_t block_dim_;
std::vector<uint8_t> args_;
uint32_t args_size_;
std::vector<uint8_t> sm_desc_;
void *binary_;
uint32_t binary_size_;
std::vector<uint8_t> meta_data_;
std::vector<void *> input_data_addrs_;
std::vector<void *> output_data_addrs_;
std::vector<void *> workspace_addrs_;
};

// Task info for an AI CPU kernel: the shared object / kernel to invoke plus
// the serialized node definition, extension info and I/O device addresses.
class AicpuTaskInfo : public TaskInfo {
 public:
  // CONSISTENCY FIX: |so_name| was declared as unqualified `string`; it is
  // now std::string like every other parameter in this header (same type,
  // no behavior change).
  AicpuTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string &so_name,
                const std::string &kernel_name, const std::string &node_def, const std::string &ext_info,
                const std::vector<void *> &input_data_addrs, const std::vector<void *> &output_data_addrs,
                bool dump_flag)
      : TaskInfo(op_name, stream_id, TaskInfoType::AICPU, dump_flag),
        so_name_(so_name),
        kernel_name_(kernel_name),
        node_def_(node_def),
        ext_info_(ext_info),
        input_data_addrs_(input_data_addrs),
        output_data_addrs_(output_data_addrs) {}
  ~AicpuTaskInfo() override {}

  const std::string &so_name() const { return so_name_; }
  const std::string &kernel_name() const { return kernel_name_; }
  const std::string &node_def() const { return node_def_; }
  const std::vector<void *> &input_data_addrs() const { return input_data_addrs_; }
  const std::vector<void *> &output_data_addrs() const { return output_data_addrs_; }
  const std::string &ext_info() const { return ext_info_; }

 private:
  std::string so_name_;
  std::string kernel_name_;
  std::string node_def_;
  std::string ext_info_;
  std::vector<void *> input_data_addrs_;
  std::vector<void *> output_data_addrs_;
};

// Defines a label (jump target) on the task's stream.
class LabelSetTaskInfo : public TaskInfo {
public:
LabelSetTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_id)
: TaskInfo(op_name, stream_id, TaskInfoType::LABEL_SET, false), label_id_(label_id) {}
~LabelSetTaskInfo() override {}
uint32_t label_id() const { return label_id_; }

private:
uint32_t label_id_;
};

// Unconditional jump to the label with the given id.
class LabelGotoTaskInfo : public TaskInfo {
public:
LabelGotoTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_id)
: TaskInfo(op_name, stream_id, TaskInfoType::LABEL_GOTO, false), label_id_(label_id) {}
~LabelGotoTaskInfo() override {}
uint32_t label_id() const { return label_id_; }

private:
uint32_t label_id_;
};

// Multi-way branch among |label_list| selected by the value at |cond|
// (presumably a device address read at execution time — confirm in the
// consuming task).
class LabelSwitchTaskInfo : public TaskInfo {
public:
LabelSwitchTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_size,
const std::vector<uint32_t> &label_list, void *cond)
: TaskInfo(op_name, stream_id, TaskInfoType::LABEL_SWITCH, false),
label_size_(label_size),
label_list_(label_list),
cond_(cond) {}
~LabelSwitchTaskInfo() override {}
uint32_t label_size() const { return label_size_; }
const std::vector<uint32_t> &label_list() const { return label_list_; }
void *cond() const { return cond_; }

private:
uint32_t label_size_;
std::vector<uint32_t> label_list_;
void *cond_;
};

// Common base for event record/wait task infos. event_id_ is protected (not
// private) so subclasses can read it directly.
class EventTaskInfo : public TaskInfo {
public:
uint32_t event_id() const { return event_id_; }

protected:
EventTaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, uint32_t event_id)
: TaskInfo(op_name, stream_id, type, false), event_id_(event_id) {}
~EventTaskInfo() override {}

uint32_t event_id_;
};

// Records (signals) the event with the given id on the task's stream.
class EventRecordTaskInfo : public EventTaskInfo {
public:
EventRecordTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t event_id)
: EventTaskInfo(op_name, stream_id, TaskInfoType::EVENT_RECORD, event_id) {}
~EventRecordTaskInfo() override {}
};

// Blocks the task's stream until the event with the given id is recorded.
class EventWaitTaskInfo : public EventTaskInfo {
public:
EventWaitTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t event_id)
: EventTaskInfo(op_name, stream_id, TaskInfoType::EVENT_WAIT, event_id) {}
~EventWaitTaskInfo() override {}
};

// Marks the start of a fusion region; carries no payload beyond the base.
class FusionStartTaskInfo : public TaskInfo {
public:
explicit FusionStartTaskInfo(const std::string &op_name, uint32_t stream_id)
: TaskInfo(op_name, stream_id, TaskInfoType::FUSION_START, false) {}
~FusionStartTaskInfo() override {}
};

// Marks the end of a fusion region; carries no payload beyond the base.
class FusionEndTaskInfo : public TaskInfo {
public:
explicit FusionEndTaskInfo(const std::string &op_name, uint32_t stream_id)
: TaskInfo(op_name, stream_id, TaskInfoType::FUSION_END, false) {}
~FusionEndTaskInfo() override {}
};

// Task info for an HCCL collective-communication op: operation type, data
// addresses, stream/workspace requirements and group configuration.
class HcclTaskInfo : public TaskInfo {
 public:
  // FIX: |hccl_type| was taken by value (`const std::string hccl_type`),
  // inconsistent with |group| and causing a needless copy; it is now passed
  // by const reference. Source-compatible for all callers.
  HcclTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string &hccl_type, void *input_data_addr,
               void *output_data_addr, int64_t workspace_size, int64_t hccl_stream_num,
               const std::vector<uint8_t> &private_def, void *ops_kernel_store, int32_t count, int64_t root_id,
               int64_t op_type, int64_t data_type, const std::string &group, bool dump_flag)
      : TaskInfo(op_name, stream_id, TaskInfoType::HCCL, dump_flag),
        hccl_type_(hccl_type),
        input_data_addr_(input_data_addr),
        output_data_addr_(output_data_addr),
        workspace_size_(workspace_size),
        hccl_stream_num_(hccl_stream_num),
        private_def_(private_def),
        ops_kernel_store_(ops_kernel_store),
        count_(count),
        root_id_(root_id),
        op_type_(op_type),
        data_type_(data_type),
        group_(group) {}
  ~HcclTaskInfo() override {}

  const std::string &hccl_type() const { return hccl_type_; }
  void *input_data_addr() const { return input_data_addr_; }
  void *output_data_addr() const { return output_data_addr_; }
  int64_t workspace_size() const { return workspace_size_; }
  int64_t hccl_stream_num() const { return hccl_stream_num_; }
  const std::vector<uint8_t> &private_def() const { return private_def_; }
  void *ops_kernel_store() const { return ops_kernel_store_; }
  int32_t count() const { return count_; }
  int64_t root_id() const { return root_id_; }
  int64_t op_type() const { return op_type_; }
  int64_t data_type() const { return data_type_; }
  const std::string &group() const { return group_; }

 private:
  std::string hccl_type_;
  void *input_data_addr_;
  void *output_data_addr_;
  int64_t workspace_size_;
  int64_t hccl_stream_num_;
  std::vector<uint8_t> private_def_;
  void *ops_kernel_store_;
  int32_t count_;
  int64_t root_id_;
  int64_t op_type_;
  int64_t data_type_;
  std::string group_;
};

// Descriptor for a profiler trace-point task.
// NOTE(review): `flat` is presumably a long-standing typo for `flag`; the
// name is kept because it is part of the public accessor interface.
class ProfilerTraceTaskInfo : public TaskInfo {
 public:
  ProfilerTraceTaskInfo(const std::string &op_name, uint32_t stream_id, uint64_t log_id, bool notify, uint32_t flat)
      : TaskInfo(op_name, stream_id, TaskInfoType::PROFILER_TRACE, false), log_id_(log_id), notify_(notify),
        flat_(flat) {}
  ~ProfilerTraceTaskInfo() override = default;

  uint64_t log_id() const { return log_id_; }
  bool notify() const { return notify_; }
  uint32_t flat() const { return flat_; }

 private:
  uint64_t log_id_;
  bool notify_;
  uint32_t flat_;
};

// Descriptor for an asynchronous memcpy task (parameters mirror rtMemcpyAsync:
// destination, destination capacity, source, byte count, and copy kind).
class MemcpyAsyncTaskInfo : public TaskInfo {
 public:
  MemcpyAsyncTaskInfo(const std::string &op_name, uint32_t stream_id, void *dst, uint64_t dst_max, void *src,
                      uint64_t count, uint32_t kind, bool dump_flag)
      : TaskInfo(op_name, stream_id, TaskInfoType::MEMCPY_ASYNC, dump_flag),
        dst_(dst),
        dst_max_(dst_max),
        src_(src),
        count_(count),
        kind_(kind) {}
  ~MemcpyAsyncTaskInfo() override {}

  void *dst() const { return dst_; }
  // Capacity of the destination buffer in bytes.
  uint64_t dst_max() const { return dst_max_; }
  void *src() const { return src_; }
  // Number of bytes to copy.
  uint64_t count() const { return count_; }
  uint32_t kind() const { return kind_; }

 private:
  void *dst_;
  uint64_t dst_max_;
  void *src_;
  uint64_t count_;
  // Was int32_t: the constructor parameter and getter are both uint32_t, so
  // storing into a signed member round-tripped through an implicit
  // signed/unsigned conversion. Use uint32_t end-to-end.
  uint32_t kind_;
};

// Descriptor for a stream-switch task: compares the value at input_addr
// against the value at value_addr under condition `cond` and, on success,
// diverts execution to true_stream_id.
class StreamSwitchTaskInfo : public TaskInfo {
 public:
  StreamSwitchTaskInfo(const std::string &op_name, uint32_t stream_id, int64_t true_stream_id, void *input_addr,
                       void *value_addr, int64_t cond, int64_t data_type)
      : TaskInfo(op_name, stream_id, TaskInfoType::STREAM_SWITCH, false), true_stream_id_(true_stream_id),
        input_addr_(input_addr), value_addr_(value_addr), cond_(cond), data_type_(data_type) {}
  ~StreamSwitchTaskInfo() override = default;

  int64_t true_stream_id() const { return true_stream_id_; }
  void *input_addr() const { return input_addr_; }
  void *value_addr() const { return value_addr_; }
  // Comparison condition code; semantics defined by the runtime, not here.
  int64_t cond() const { return cond_; }
  int64_t data_type() const { return data_type_; }

 private:
  int64_t true_stream_id_;
  void *input_addr_;
  void *value_addr_;
  int64_t cond_;
  int64_t data_type_;
};

// Descriptor for a stream-activate task: records which stream to activate.
class StreamActiveTaskInfo : public TaskInfo {
 public:
  StreamActiveTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t active_stream_id)
      : TaskInfo(op_name, stream_id, TaskInfoType::STREAM_ACTIVE, false),
        active_stream_id_(active_stream_id) {}
  ~StreamActiveTaskInfo() override = default;

  uint32_t active_stream_id() const { return active_stream_id_; }

 private:
  uint32_t active_stream_id_;
};
} // namespace model_runner
} // namespace ge

#endif // INC_FRAMEWORK_GE_RUNTIME_TASK_INFO_H_

Loading…
Cancel
Save