| @@ -177,37 +177,44 @@ class AicpuTaskInfo : public TaskInfo { | |||||
| std::vector<void *> output_data_addrs_; | std::vector<void *> output_data_addrs_; | ||||
| }; | }; | ||||
| class LabelTaskInfo : public TaskInfo { | |||||
| class LabelSetTaskInfo : public TaskInfo { | |||||
| public: | public: | ||||
| LabelSetTaskInfo(uint32_t stream_id, uint32_t label_id) | |||||
| : TaskInfo(stream_id, TaskInfoType::LABEL_SET), label_id_(label_id) {} | |||||
| ~LabelSetTaskInfo() override {} | |||||
| uint32_t label_id() const { return label_id_; } | uint32_t label_id() const { return label_id_; } | ||||
| protected: | |||||
| LabelTaskInfo(uint32_t stream_id, TaskInfoType type, uint32_t label_id) | |||||
| : TaskInfo(stream_id, type), label_id_(label_id) {} | |||||
| virtual ~LabelTaskInfo() override {} | |||||
| private: | |||||
| uint32_t label_id_; | uint32_t label_id_; | ||||
| }; | }; | ||||
| class LabelSetTaskInfo : public LabelTaskInfo { | |||||
| class LabelGotoTaskInfo : public TaskInfo { | |||||
| public: | public: | ||||
| LabelSetTaskInfo(uint32_t stream_id, uint32_t label_id) | |||||
| : LabelTaskInfo(stream_id, TaskInfoType::LABEL_SET, label_id) {} | |||||
| ~LabelSetTaskInfo() override {} | |||||
| LabelGotoTaskInfo(uint32_t stream_id, uint32_t label_id) | |||||
| : TaskInfo(stream_id, TaskInfoType::LABEL_GOTO), label_id_(label_id) {} | |||||
| ~LabelGotoTaskInfo() override {} | |||||
| uint32_t label_id() const { return label_id_; } | |||||
| private: | |||||
| uint32_t label_id_; | |||||
| }; | }; | ||||
| class LabelSwitchTaskInfo : public LabelTaskInfo { | |||||
| class LabelSwitchTaskInfo : public TaskInfo { | |||||
| public: | public: | ||||
| LabelSwitchTaskInfo(uint32_t stream_id, uint32_t label_id) | |||||
| : LabelTaskInfo(stream_id, TaskInfoType::LABEL_SWITCH, label_id) {} | |||||
| LabelSwitchTaskInfo(uint32_t stream_id, uint32_t label_size, const std::vector<uint32_t> &label_list, void *cond) | |||||
| : TaskInfo(stream_id, TaskInfoType::LABEL_SWITCH), | |||||
| label_size_(label_size), | |||||
| label_list_(label_list), | |||||
| cond_(cond) {} | |||||
| ~LabelSwitchTaskInfo() override {} | ~LabelSwitchTaskInfo() override {} | ||||
| }; | |||||
| uint32_t label_size() { return label_size_; }; | |||||
| const std::vector<uint32_t> &label_list() { return label_list_; }; | |||||
| void *cond() { return cond_; }; | |||||
| class LabelGotoTaskInfo : public LabelTaskInfo { | |||||
| public: | |||||
| LabelGotoTaskInfo(uint32_t stream_id, uint32_t label_id) | |||||
| : LabelTaskInfo(stream_id, TaskInfoType::LABEL_GOTO, label_id) {} | |||||
| ~LabelGotoTaskInfo() override {} | |||||
| private: | |||||
| uint32_t label_size_; | |||||
| std::vector<uint32_t> label_list_; | |||||
| void *cond_; | |||||
| }; | }; | ||||
| class EventTaskInfo : public TaskInfo { | class EventTaskInfo : public TaskInfo { | ||||
| @@ -116,23 +116,34 @@ bool RuntimeModel::InitEvent(uint32_t event_num) { | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool RuntimeModel::InitLabel(uint32_t batch_num) { | |||||
| GELOGI("batch number:%u.", batch_num); | |||||
| for (uint32_t i = 0; (batch_num != 0 && i <= batch_num); ++i) { | |||||
| rtLabel_t rt_lLabel = nullptr; | |||||
| rtError_t rt_ret = rtLabelCreate(&rt_lLabel); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api rtLabelCreate failed, i; %u; ret: 0x%X", i, rt_ret); | |||||
| return false; | |||||
| bool RuntimeModel::InitLabel(std::shared_ptr<DavinciModel> &davinci_model) { | |||||
| GELOGI("batch number:%u.", davinci_model->GetBatchNum()); | |||||
| label_list_.resize(davinci_model->GetBatchNum()); | |||||
| for (auto &task_info : davinci_model->GetTaskInfoList()) { | |||||
| if (task_info == nullptr) { | |||||
| GELOGE(PARAM_INVALID, "task_info is null."); | |||||
| continue; | |||||
| } | |||||
| if (task_info->type() != TaskInfoType::LABEL_SET) { | |||||
| continue; | |||||
| } | } | ||||
| auto label_set_task_info = std::static_pointer_cast<LabelSetTaskInfo>(task_info); | |||||
| if (rt_lLabel == nullptr) { | |||||
| GELOGE(RT_FAILED, "rtLabel is nullptr!"); | |||||
| if (label_set_task_info->stream_id() >= stream_list_.size()) { | |||||
| GELOGE(PARAM_INVALID, "Invalid stream id."); | |||||
| return false; | return false; | ||||
| } | } | ||||
| label_list_.emplace_back(rt_lLabel); | |||||
| rtLabel_t rt_label = nullptr; | |||||
| rtError_t rt_ret = rtLabelCreateEx(&rt_label, stream_list_[label_set_task_info->stream_id()]); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api rtLabelCreate failed, ret: 0x%X", rt_ret); | |||||
| return false; | |||||
| } | |||||
| label_list_[label_set_task_info->label_id()] = rt_label; | |||||
| } | } | ||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -164,7 +175,7 @@ bool RuntimeModel::InitResource(std::shared_ptr<DavinciModel> &davinci_model) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| if (!InitLabel(davinci_model->GetBatchNum())) { | |||||
| if (!InitLabel(davinci_model)) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -48,7 +48,7 @@ class RuntimeModel { | |||||
| bool LoadTask(); | bool LoadTask(); | ||||
| bool InitStream(std::shared_ptr<DavinciModel> &davinci_model); | bool InitStream(std::shared_ptr<DavinciModel> &davinci_model); | ||||
| bool InitEvent(uint32_t event_num); | bool InitEvent(uint32_t event_num); | ||||
| bool InitLabel(uint32_t batch_num); | |||||
| bool InitLabel(std::shared_ptr<DavinciModel> &davinci_model); | |||||
| bool InitDataInfo(std::shared_ptr<DavinciModel> &davinci_model); | bool InitDataInfo(std::shared_ptr<DavinciModel> &davinci_model); | ||||
| bool InitOutputInfo(std::shared_ptr<DavinciModel> &davinci_model); | bool InitOutputInfo(std::shared_ptr<DavinciModel> &davinci_model); | ||||
| bool InitConstantInfo(std::shared_ptr<DavinciModel> &davinci_model); | bool InitConstantInfo(std::shared_ptr<DavinciModel> &davinci_model); | ||||
| @@ -0,0 +1,70 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ge_runtime/task/label_goto_task.h" | |||||
| #include "ge_runtime/task/task_factory.h" | |||||
| namespace ge { | |||||
| namespace model_runner { | |||||
| LabelGotoTask::LabelGotoTask(const ModelContext &model_context, const std::shared_ptr<LabelGotoTaskInfo> &task_info) | |||||
| : TaskRepeater<LabelGotoTaskInfo>(model_context, task_info), | |||||
| task_info_(task_info), | |||||
| stream_(nullptr), | |||||
| label_(nullptr) { | |||||
| if (task_info_ == nullptr) { | |||||
| GELOGW("task_info_ is null!"); | |||||
| return; | |||||
| } | |||||
| auto stream_list = model_context.stream_list(); | |||||
| auto label_list = model_context.label_list(); | |||||
| uint32_t stream_id = task_info->stream_id(); | |||||
| uint32_t label_id = task_info->label_id(); | |||||
| GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); | |||||
| GELOGI("Label list size:%zu, label id:%u.", label_list.size(), label_id); | |||||
| if (stream_id >= stream_list.size() || label_id >= label_list.size()) { | |||||
| GELOGW("Stream/Label id invalid."); | |||||
| return; | |||||
| } | |||||
| stream_ = stream_list[stream_id]; | |||||
| label_ = label_list[label_id]; | |||||
| } | |||||
| LabelGotoTask::~LabelGotoTask() {} | |||||
| bool LabelGotoTask::Distribute() { | |||||
| GELOGI("LabelGotoTask Distribute start."); | |||||
| if (stream_ == nullptr) { | |||||
| GELOGE(PARAM_INVALID, "stream is null!"); | |||||
| return false; | |||||
| } | |||||
| if (label_ == nullptr) { | |||||
| GELOGE(PARAM_INVALID, "label is null!"); | |||||
| return false; | |||||
| } | |||||
| rtError_t rt_ret = rtLabelGotoEx(label_, stream_); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
| return false; | |||||
| } | |||||
| GELOGI("DistributeTask end."); | |||||
| return true; | |||||
| } | |||||
| REGISTER_TASK(TaskInfoType::LABEL_GOTO, LabelGotoTask, LabelGotoTaskInfo); | |||||
| } // namespace model_runner | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,41 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ | |||||
| #define GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ | |||||
| #include <memory> | |||||
| #include "ge_runtime/task/task.h" | |||||
| namespace ge { | |||||
| namespace model_runner { | |||||
| class LabelGotoTask : public TaskRepeater<LabelGotoTaskInfo> { | |||||
| public: | |||||
| LabelGotoTask(const ModelContext &model_context, const std::shared_ptr<LabelGotoTaskInfo> &task_info); | |||||
| ~LabelGotoTask() override; | |||||
| bool Distribute() override; | |||||
| private: | |||||
| std::shared_ptr<LabelGotoTaskInfo> task_info_; | |||||
| void *stream_; | |||||
| void *label_; | |||||
| }; | |||||
| } // namespace model_runner | |||||
| } // namespace ge | |||||
| #endif // GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ | |||||
| @@ -0,0 +1,70 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ge_runtime/task/label_set_task.h" | |||||
| #include "ge_runtime/task/task_factory.h" | |||||
| namespace ge { | |||||
| namespace model_runner { | |||||
| LabelSetTask::LabelSetTask(const ModelContext &model_context, const std::shared_ptr<LabelSetTaskInfo> &task_info) | |||||
| : TaskRepeater<LabelSetTaskInfo>(model_context, task_info), | |||||
| task_info_(task_info), | |||||
| stream_(nullptr), | |||||
| label_(nullptr) { | |||||
| if (task_info_ == nullptr) { | |||||
| GELOGW("task_info_ is null!"); | |||||
| return; | |||||
| } | |||||
| auto stream_list = model_context.stream_list(); | |||||
| auto label_list = model_context.label_list(); | |||||
| uint32_t stream_id = task_info->stream_id(); | |||||
| uint32_t label_id = task_info->label_id(); | |||||
| GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); | |||||
| GELOGI("Label list size:%zu, label id:%u.", label_list.size(), label_id); | |||||
| if (stream_id >= stream_list.size() || label_id >= label_list.size()) { | |||||
| GELOGW("Stream/Label id invalid."); | |||||
| return; | |||||
| } | |||||
| stream_ = stream_list[stream_id]; | |||||
| label_ = label_list[label_id]; | |||||
| } | |||||
| LabelSetTask::~LabelSetTask() {} | |||||
| bool LabelSetTask::Distribute() { | |||||
| GELOGI("LabelSetTask Distribute start."); | |||||
| if (stream_ == nullptr) { | |||||
| GELOGE(PARAM_INVALID, "stream is null!"); | |||||
| return false; | |||||
| } | |||||
| if (label_ == nullptr) { | |||||
| GELOGE(PARAM_INVALID, "label is null!"); | |||||
| return false; | |||||
| } | |||||
| rtError_t rt_ret = rtLabelSet(label_, stream_); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
| return false; | |||||
| } | |||||
| GELOGI("DistributeTask end."); | |||||
| return true; | |||||
| } | |||||
| REGISTER_TASK(TaskInfoType::LABEL_SET, LabelSetTask, LabelSetTaskInfo); | |||||
| } // namespace model_runner | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,41 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_ | |||||
| #define GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_ | |||||
| #include <memory> | |||||
| #include "ge_runtime/task/task.h" | |||||
| namespace ge { | |||||
| namespace model_runner { | |||||
| class LabelSetTask : public TaskRepeater<LabelSetTaskInfo> { | |||||
| public: | |||||
| LabelSetTask(const ModelContext &model_context, const std::shared_ptr<LabelSetTaskInfo> &task_info); | |||||
| ~LabelSetTask() override; | |||||
| bool Distribute() override; | |||||
| private: | |||||
| std::shared_ptr<LabelSetTaskInfo> task_info_; | |||||
| void *stream_; | |||||
| void *label_; | |||||
| }; | |||||
| } // namespace model_runner | |||||
| } // namespace ge | |||||
| #endif // GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_ | |||||
| @@ -0,0 +1,131 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ge_runtime/task/label_switch_task.h" | |||||
| #include "ge_runtime/task/task_factory.h" | |||||
| namespace ge { | |||||
| namespace model_runner { | |||||
| LabelSwitchTask::LabelSwitchTask(const ModelContext &model_context, | |||||
| const std::shared_ptr<LabelSwitchTaskInfo> &task_info) | |||||
| : TaskRepeater<LabelSwitchTaskInfo>(model_context, task_info), | |||||
| task_info_(task_info), | |||||
| stream_(nullptr), | |||||
| all_label_resource_(), | |||||
| label_info_(nullptr) { | |||||
| if (task_info_ == nullptr) { | |||||
| GELOGW("task_info_ is null!"); | |||||
| return; | |||||
| } | |||||
| all_label_resource_ = model_context.label_list(); | |||||
| auto stream_list = model_context.stream_list(); | |||||
| uint32_t stream_id = task_info->stream_id(); | |||||
| GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); | |||||
| if (stream_id >= stream_list.size()) { | |||||
| GELOGW("Stream id invalid."); | |||||
| return; | |||||
| } | |||||
| stream_ = stream_list[stream_id]; | |||||
| } | |||||
| LabelSwitchTask::~LabelSwitchTask() { | |||||
| if (label_info_ != nullptr) { | |||||
| rtError_t rt_ret = rtFree(label_info_); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "rtFree fwkOpBuf failed! ret: 0x%X.", rt_ret); | |||||
| } | |||||
| label_info_ = nullptr; | |||||
| } | |||||
| } | |||||
| bool LabelSwitchTask::Distribute() { | |||||
| GELOGI("LabelSwitchTask Distribute start."); | |||||
| if (!CheckParamValid()) { | |||||
| return false; | |||||
| } | |||||
| const std::vector<uint32_t> &label_index_list = task_info_->label_list(); | |||||
| std::vector<void *> label_list(task_info_->label_size(), nullptr); | |||||
| for (size_t i = 0; i < task_info_->label_size(); ++i) { | |||||
| uint32_t label_index = label_index_list[i]; | |||||
| if (label_index >= all_label_resource_.size()) { | |||||
| GELOGE(PARAM_INVALID, "label %zu index is %u, but there are %zu labels in total.", i, label_index, | |||||
| all_label_resource_.size()); | |||||
| return false; | |||||
| } | |||||
| label_list[i] = all_label_resource_[label_index]; | |||||
| GELOGI("Case %zu: label id %zu.", i, label_index); | |||||
| } | |||||
| uint32_t label_info_size = sizeof(rtLabelDevInfo) * task_info_->label_size(); | |||||
| rtError_t rt_ret = rtMalloc(&label_info_, label_info_size, RT_MEMORY_HBM); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
| return false; | |||||
| } | |||||
| rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info_, label_info_size); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
| return false; | |||||
| } | |||||
| rt_ret = rtLabelSwitchByIndex(task_info_->cond(), label_list.size(), label_info_, stream_); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
| return false; | |||||
| } | |||||
| GELOGI("DistributeTask end."); | |||||
| return true; | |||||
| } | |||||
| bool LabelSwitchTask::CheckParamValid() { | |||||
| if (stream_ == nullptr) { | |||||
| GELOGE(PARAM_INVALID, "stream is null!"); | |||||
| return false; | |||||
| } | |||||
| if (task_info_->label_list().empty()) { | |||||
| GELOGE(PARAM_INVALID, "label_list is empty."); | |||||
| return false; | |||||
| } | |||||
| if (task_info_->label_size() != task_info_->label_list().size()) { | |||||
| GELOGE(PARAM_INVALID, "label_list size %zu but label_size is %u.", task_info_->label_list().size(), | |||||
| task_info_->label_size()); | |||||
| return false; | |||||
| } | |||||
| if (task_info_->label_size() >= UINT32_MAX / sizeof(rtLabelDevInfo)) { | |||||
| GELOGE(PARAM_INVALID, "label_size %u will overflow.", task_info_->label_size()); | |||||
| return false; | |||||
| } | |||||
| if (label_info_ != nullptr) { | |||||
| GELOGE(PARAM_INVALID, "label_info_ has dirty data."); | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| REGISTER_TASK(TaskInfoType::LABEL_SWITCH, LabelSwitchTask, LabelSwitchTaskInfo); | |||||
| } // namespace model_runner | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,44 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_ | |||||
| #define GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_ | |||||
| #include <memory> | |||||
| #include "ge_runtime/task/task.h" | |||||
| namespace ge { | |||||
| namespace model_runner { | |||||
| class LabelSwitchTask : public TaskRepeater<LabelSwitchTaskInfo> { | |||||
| public: | |||||
| LabelSwitchTask(const ModelContext &model_context, const std::shared_ptr<LabelSwitchTaskInfo> &task_info); | |||||
| ~LabelSwitchTask() override; | |||||
| bool Distribute() override; | |||||
| private: | |||||
| bool CheckParamValid(); | |||||
| std::shared_ptr<LabelSwitchTaskInfo> task_info_; | |||||
| void *stream_; | |||||
| std::vector<void *> all_label_resource_; | |||||
| void *label_info_; | |||||
| }; | |||||
| } // namespace model_runner | |||||
| } // namespace ge | |||||
| #endif // GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_ | |||||