| @@ -177,37 +177,44 @@ class AicpuTaskInfo : public TaskInfo { | |||
| std::vector<void *> output_data_addrs_; | |||
| }; | |||
| class LabelTaskInfo : public TaskInfo { | |||
| class LabelSetTaskInfo : public TaskInfo { | |||
| public: | |||
| LabelSetTaskInfo(uint32_t stream_id, uint32_t label_id) | |||
| : TaskInfo(stream_id, TaskInfoType::LABEL_SET), label_id_(label_id) {} | |||
| ~LabelSetTaskInfo() override {} | |||
| uint32_t label_id() const { return label_id_; } | |||
| protected: | |||
| LabelTaskInfo(uint32_t stream_id, TaskInfoType type, uint32_t label_id) | |||
| : TaskInfo(stream_id, type), label_id_(label_id) {} | |||
| virtual ~LabelTaskInfo() override {} | |||
| private: | |||
| uint32_t label_id_; | |||
| }; | |||
| class LabelSetTaskInfo : public LabelTaskInfo { | |||
| class LabelGotoTaskInfo : public TaskInfo { | |||
| public: | |||
| LabelSetTaskInfo(uint32_t stream_id, uint32_t label_id) | |||
| : LabelTaskInfo(stream_id, TaskInfoType::LABEL_SET, label_id) {} | |||
| ~LabelSetTaskInfo() override {} | |||
| LabelGotoTaskInfo(uint32_t stream_id, uint32_t label_id) | |||
| : TaskInfo(stream_id, TaskInfoType::LABEL_GOTO), label_id_(label_id) {} | |||
| ~LabelGotoTaskInfo() override {} | |||
| uint32_t label_id() const { return label_id_; } | |||
| private: | |||
| uint32_t label_id_; | |||
| }; | |||
| class LabelSwitchTaskInfo : public LabelTaskInfo { | |||
| class LabelSwitchTaskInfo : public TaskInfo { | |||
| public: | |||
| LabelSwitchTaskInfo(uint32_t stream_id, uint32_t label_id) | |||
| : LabelTaskInfo(stream_id, TaskInfoType::LABEL_SWITCH, label_id) {} | |||
| LabelSwitchTaskInfo(uint32_t stream_id, uint32_t label_size, const std::vector<uint32_t> &label_list, void *cond) | |||
| : TaskInfo(stream_id, TaskInfoType::LABEL_SWITCH), | |||
| label_size_(label_size), | |||
| label_list_(label_list), | |||
| cond_(cond) {} | |||
| ~LabelSwitchTaskInfo() override {} | |||
| }; | |||
| uint32_t label_size() { return label_size_; }; | |||
| const std::vector<uint32_t> &label_list() { return label_list_; }; | |||
| void *cond() { return cond_; }; | |||
| class LabelGotoTaskInfo : public LabelTaskInfo { | |||
| public: | |||
| LabelGotoTaskInfo(uint32_t stream_id, uint32_t label_id) | |||
| : LabelTaskInfo(stream_id, TaskInfoType::LABEL_GOTO, label_id) {} | |||
| ~LabelGotoTaskInfo() override {} | |||
| private: | |||
| uint32_t label_size_; | |||
| std::vector<uint32_t> label_list_; | |||
| void *cond_; | |||
| }; | |||
| class EventTaskInfo : public TaskInfo { | |||
| @@ -116,23 +116,34 @@ bool RuntimeModel::InitEvent(uint32_t event_num) { | |||
| return true; | |||
| } | |||
| bool RuntimeModel::InitLabel(uint32_t batch_num) { | |||
| GELOGI("batch number:%u.", batch_num); | |||
| for (uint32_t i = 0; (batch_num != 0 && i <= batch_num); ++i) { | |||
| rtLabel_t rt_lLabel = nullptr; | |||
| rtError_t rt_ret = rtLabelCreate(&rt_lLabel); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api rtLabelCreate failed, i; %u; ret: 0x%X", i, rt_ret); | |||
| return false; | |||
| bool RuntimeModel::InitLabel(std::shared_ptr<DavinciModel> &davinci_model) { | |||
| GELOGI("batch number:%u.", davinci_model->GetBatchNum()); | |||
| label_list_.resize(davinci_model->GetBatchNum()); | |||
| for (auto &task_info : davinci_model->GetTaskInfoList()) { | |||
| if (task_info == nullptr) { | |||
| GELOGE(PARAM_INVALID, "task_info is null."); | |||
| continue; | |||
| } | |||
| if (task_info->type() != TaskInfoType::LABEL_SET) { | |||
| continue; | |||
| } | |||
| auto label_set_task_info = std::static_pointer_cast<LabelSetTaskInfo>(task_info); | |||
| if (rt_lLabel == nullptr) { | |||
| GELOGE(RT_FAILED, "rtLabel is nullptr!"); | |||
| if (label_set_task_info->stream_id() >= stream_list_.size()) { | |||
| GELOGE(PARAM_INVALID, "Invalid stream id."); | |||
| return false; | |||
| } | |||
| label_list_.emplace_back(rt_lLabel); | |||
| rtLabel_t rt_label = nullptr; | |||
| rtError_t rt_ret = rtLabelCreateEx(&rt_label, stream_list_[label_set_task_info->stream_id()]); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api rtLabelCreate failed, ret: 0x%X", rt_ret); | |||
| return false; | |||
| } | |||
| label_list_[label_set_task_info->label_id()] = rt_label; | |||
| } | |||
| return true; | |||
| } | |||
| @@ -164,7 +175,7 @@ bool RuntimeModel::InitResource(std::shared_ptr<DavinciModel> &davinci_model) { | |||
| return false; | |||
| } | |||
| if (!InitLabel(davinci_model->GetBatchNum())) { | |||
| if (!InitLabel(davinci_model)) { | |||
| return false; | |||
| } | |||
| @@ -48,7 +48,7 @@ class RuntimeModel { | |||
| bool LoadTask(); | |||
| bool InitStream(std::shared_ptr<DavinciModel> &davinci_model); | |||
| bool InitEvent(uint32_t event_num); | |||
| bool InitLabel(uint32_t batch_num); | |||
| bool InitLabel(std::shared_ptr<DavinciModel> &davinci_model); | |||
| bool InitDataInfo(std::shared_ptr<DavinciModel> &davinci_model); | |||
| bool InitOutputInfo(std::shared_ptr<DavinciModel> &davinci_model); | |||
| bool InitConstantInfo(std::shared_ptr<DavinciModel> &davinci_model); | |||
| @@ -0,0 +1,70 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ge_runtime/task/label_goto_task.h" | |||
| #include "ge_runtime/task/task_factory.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| LabelGotoTask::LabelGotoTask(const ModelContext &model_context, const std::shared_ptr<LabelGotoTaskInfo> &task_info) | |||
| : TaskRepeater<LabelGotoTaskInfo>(model_context, task_info), | |||
| task_info_(task_info), | |||
| stream_(nullptr), | |||
| label_(nullptr) { | |||
| if (task_info_ == nullptr) { | |||
| GELOGW("task_info_ is null!"); | |||
| return; | |||
| } | |||
| auto stream_list = model_context.stream_list(); | |||
| auto label_list = model_context.label_list(); | |||
| uint32_t stream_id = task_info->stream_id(); | |||
| uint32_t label_id = task_info->label_id(); | |||
| GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); | |||
| GELOGI("Label list size:%zu, label id:%u.", label_list.size(), label_id); | |||
| if (stream_id >= stream_list.size() || label_id >= label_list.size()) { | |||
| GELOGW("Stream/Label id invalid."); | |||
| return; | |||
| } | |||
| stream_ = stream_list[stream_id]; | |||
| label_ = label_list[label_id]; | |||
| } | |||
| LabelGotoTask::~LabelGotoTask() {} | |||
| bool LabelGotoTask::Distribute() { | |||
| GELOGI("LabelGotoTask Distribute start."); | |||
| if (stream_ == nullptr) { | |||
| GELOGE(PARAM_INVALID, "stream is null!"); | |||
| return false; | |||
| } | |||
| if (label_ == nullptr) { | |||
| GELOGE(PARAM_INVALID, "label is null!"); | |||
| return false; | |||
| } | |||
| rtError_t rt_ret = rtLabelGotoEx(label_, stream_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
| return false; | |||
| } | |||
| GELOGI("DistributeTask end."); | |||
| return true; | |||
| } | |||
| REGISTER_TASK(TaskInfoType::LABEL_GOTO, LabelGotoTask, LabelGotoTaskInfo); | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| @@ -0,0 +1,41 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ | |||
| #define GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ | |||
| #include <memory> | |||
| #include "ge_runtime/task/task.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| class LabelGotoTask : public TaskRepeater<LabelGotoTaskInfo> { | |||
| public: | |||
| LabelGotoTask(const ModelContext &model_context, const std::shared_ptr<LabelGotoTaskInfo> &task_info); | |||
| ~LabelGotoTask() override; | |||
| bool Distribute() override; | |||
| private: | |||
| std::shared_ptr<LabelGotoTaskInfo> task_info_; | |||
| void *stream_; | |||
| void *label_; | |||
| }; | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| #endif // GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ | |||
| @@ -0,0 +1,70 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ge_runtime/task/label_set_task.h" | |||
| #include "ge_runtime/task/task_factory.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| LabelSetTask::LabelSetTask(const ModelContext &model_context, const std::shared_ptr<LabelSetTaskInfo> &task_info) | |||
| : TaskRepeater<LabelSetTaskInfo>(model_context, task_info), | |||
| task_info_(task_info), | |||
| stream_(nullptr), | |||
| label_(nullptr) { | |||
| if (task_info_ == nullptr) { | |||
| GELOGW("task_info_ is null!"); | |||
| return; | |||
| } | |||
| auto stream_list = model_context.stream_list(); | |||
| auto label_list = model_context.label_list(); | |||
| uint32_t stream_id = task_info->stream_id(); | |||
| uint32_t label_id = task_info->label_id(); | |||
| GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); | |||
| GELOGI("Label list size:%zu, label id:%u.", label_list.size(), label_id); | |||
| if (stream_id >= stream_list.size() || label_id >= label_list.size()) { | |||
| GELOGW("Stream/Label id invalid."); | |||
| return; | |||
| } | |||
| stream_ = stream_list[stream_id]; | |||
| label_ = label_list[label_id]; | |||
| } | |||
| LabelSetTask::~LabelSetTask() {} | |||
| bool LabelSetTask::Distribute() { | |||
| GELOGI("LabelSetTask Distribute start."); | |||
| if (stream_ == nullptr) { | |||
| GELOGE(PARAM_INVALID, "stream is null!"); | |||
| return false; | |||
| } | |||
| if (label_ == nullptr) { | |||
| GELOGE(PARAM_INVALID, "label is null!"); | |||
| return false; | |||
| } | |||
| rtError_t rt_ret = rtLabelSet(label_, stream_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
| return false; | |||
| } | |||
| GELOGI("DistributeTask end."); | |||
| return true; | |||
| } | |||
| REGISTER_TASK(TaskInfoType::LABEL_SET, LabelSetTask, LabelSetTaskInfo); | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| @@ -0,0 +1,41 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_ | |||
| #define GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_ | |||
| #include <memory> | |||
| #include "ge_runtime/task/task.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| class LabelSetTask : public TaskRepeater<LabelSetTaskInfo> { | |||
| public: | |||
| LabelSetTask(const ModelContext &model_context, const std::shared_ptr<LabelSetTaskInfo> &task_info); | |||
| ~LabelSetTask() override; | |||
| bool Distribute() override; | |||
| private: | |||
| std::shared_ptr<LabelSetTaskInfo> task_info_; | |||
| void *stream_; | |||
| void *label_; | |||
| }; | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| #endif // GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_ | |||
| @@ -0,0 +1,131 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ge_runtime/task/label_switch_task.h" | |||
| #include "ge_runtime/task/task_factory.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| LabelSwitchTask::LabelSwitchTask(const ModelContext &model_context, | |||
| const std::shared_ptr<LabelSwitchTaskInfo> &task_info) | |||
| : TaskRepeater<LabelSwitchTaskInfo>(model_context, task_info), | |||
| task_info_(task_info), | |||
| stream_(nullptr), | |||
| all_label_resource_(), | |||
| label_info_(nullptr) { | |||
| if (task_info_ == nullptr) { | |||
| GELOGW("task_info_ is null!"); | |||
| return; | |||
| } | |||
| all_label_resource_ = model_context.label_list(); | |||
| auto stream_list = model_context.stream_list(); | |||
| uint32_t stream_id = task_info->stream_id(); | |||
| GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); | |||
| if (stream_id >= stream_list.size()) { | |||
| GELOGW("Stream id invalid."); | |||
| return; | |||
| } | |||
| stream_ = stream_list[stream_id]; | |||
| } | |||
| LabelSwitchTask::~LabelSwitchTask() { | |||
| if (label_info_ != nullptr) { | |||
| rtError_t rt_ret = rtFree(label_info_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "rtFree fwkOpBuf failed! ret: 0x%X.", rt_ret); | |||
| } | |||
| label_info_ = nullptr; | |||
| } | |||
| } | |||
| bool LabelSwitchTask::Distribute() { | |||
| GELOGI("LabelSwitchTask Distribute start."); | |||
| if (!CheckParamValid()) { | |||
| return false; | |||
| } | |||
| const std::vector<uint32_t> &label_index_list = task_info_->label_list(); | |||
| std::vector<void *> label_list(task_info_->label_size(), nullptr); | |||
| for (size_t i = 0; i < task_info_->label_size(); ++i) { | |||
| uint32_t label_index = label_index_list[i]; | |||
| if (label_index >= all_label_resource_.size()) { | |||
| GELOGE(PARAM_INVALID, "label %zu index is %u, but there are %zu labels in total.", i, label_index, | |||
| all_label_resource_.size()); | |||
| return false; | |||
| } | |||
| label_list[i] = all_label_resource_[label_index]; | |||
| GELOGI("Case %zu: label id %zu.", i, label_index); | |||
| } | |||
| uint32_t label_info_size = sizeof(rtLabelDevInfo) * task_info_->label_size(); | |||
| rtError_t rt_ret = rtMalloc(&label_info_, label_info_size, RT_MEMORY_HBM); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
| return false; | |||
| } | |||
| rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info_, label_info_size); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
| return false; | |||
| } | |||
| rt_ret = rtLabelSwitchByIndex(task_info_->cond(), label_list.size(), label_info_, stream_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
| return false; | |||
| } | |||
| GELOGI("DistributeTask end."); | |||
| return true; | |||
| } | |||
| bool LabelSwitchTask::CheckParamValid() { | |||
| if (stream_ == nullptr) { | |||
| GELOGE(PARAM_INVALID, "stream is null!"); | |||
| return false; | |||
| } | |||
| if (task_info_->label_list().empty()) { | |||
| GELOGE(PARAM_INVALID, "label_list is empty."); | |||
| return false; | |||
| } | |||
| if (task_info_->label_size() != task_info_->label_list().size()) { | |||
| GELOGE(PARAM_INVALID, "label_list size %zu but label_size is %u.", task_info_->label_list().size(), | |||
| task_info_->label_size()); | |||
| return false; | |||
| } | |||
| if (task_info_->label_size() >= UINT32_MAX / sizeof(rtLabelDevInfo)) { | |||
| GELOGE(PARAM_INVALID, "label_size %u will overflow.", task_info_->label_size()); | |||
| return false; | |||
| } | |||
| if (label_info_ != nullptr) { | |||
| GELOGE(PARAM_INVALID, "label_info_ has dirty data."); | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| REGISTER_TASK(TaskInfoType::LABEL_SWITCH, LabelSwitchTask, LabelSwitchTaskInfo); | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| @@ -0,0 +1,44 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_ | |||
| #define GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_ | |||
| #include <memory> | |||
| #include "ge_runtime/task/task.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| class LabelSwitchTask : public TaskRepeater<LabelSwitchTaskInfo> { | |||
| public: | |||
| LabelSwitchTask(const ModelContext &model_context, const std::shared_ptr<LabelSwitchTaskInfo> &task_info); | |||
| ~LabelSwitchTask() override; | |||
| bool Distribute() override; | |||
| private: | |||
| bool CheckParamValid(); | |||
| std::shared_ptr<LabelSwitchTaskInfo> task_info_; | |||
| void *stream_; | |||
| std::vector<void *> all_label_resource_; | |||
| void *label_info_; | |||
| }; | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| #endif // GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_ | |||