Browse Source

!37 Add GetStreamIdList in ge_runtime

Merge pull request !37 from caifubi/master
tags/v0.5.0-beta
mindspore-ci-bot Gitee 4 years ago
parent
commit
c54db4343f
4 changed files with 17 additions and 0 deletions
  1. +2
    -0
      inc/framework/ge_runtime/model_runner.h
  2. +11
    -0
      src/ge/ge_runtime/model_runner.cc
  3. +2
    -0
      src/ge/ge_runtime/runtime_model.cc
  4. +2
    -0
      src/ge/ge_runtime/runtime_model.h

+ 2
- 0
inc/framework/ge_runtime/model_runner.h View File

@@ -38,6 +38,8 @@ class ModelRunner {

const std::vector<uint32_t> &GetTaskIdList(uint32_t model_id) const;

const std::vector<uint32_t> &GetStreamIdList(uint32_t model_id) const;

bool UnloadModel(uint32_t model_id);

bool RunModel(uint32_t model_id, const InputData &input_data, OutputData *output_data);


+ 11
- 0
src/ge/ge_runtime/model_runner.cc View File

@@ -60,6 +60,17 @@ const std::vector<uint32_t> &ModelRunner::GetTaskIdList(uint32_t model_id) const
return model_iter->second->GetTaskIdList();
}

const std::vector<uint32_t> &ModelRunner::GetStreamIdList(uint32_t model_id) const {
auto model_iter = runtime_models_.find(model_id);
if (model_iter == runtime_models_.end()) {
GELOGE(PARAM_INVALID, "Model id %u not found.", model_id);
static const std::vector<uint32_t> empty_ret;
return empty_ret;
}

return model_iter->second->GetStreamIdList();
}

bool ModelRunner::UnloadModel(uint32_t model_id) {
auto iter = runtime_models_.find(model_id);
if (iter != runtime_models_.end()) {


+ 2
- 0
src/ge/ge_runtime/runtime_model.cc View File

@@ -220,6 +220,7 @@ bool RuntimeModel::LoadTask() {
return false;
}
task_id_list_.push_back(task_id);
stream_id_list_.push_back(stream_id);
}
if (task_list_.empty()) {
GELOGE(FAILED, "Task list is empty");
@@ -507,5 +508,6 @@ void RuntimeModel::CreateOutput(uint32_t index, const OpInfo &op_info, InputOutp

const std::vector<uint32_t> &RuntimeModel::GetTaskIdList() const { return task_id_list_; }

const std::vector<uint32_t> &RuntimeModel::GetStreamIdList() const { return stream_id_list_; }
} // namespace model_runner
} // namespace ge

+ 2
- 0
src/ge/ge_runtime/runtime_model.h View File

@@ -36,6 +36,7 @@ class RuntimeModel {

bool Load(uint32_t device_id, uint64_t session_id, std::shared_ptr<DavinciModel> &davinci_model);
const std::vector<uint32_t> &GetTaskIdList() const;
const std::vector<uint32_t> &GetStreamIdList() const;
bool Run();
bool CopyInputData(const InputData &input_data);
bool GetInputOutputDescInfo(bool zero_copy, std::vector<InputOutputDescInfo> *input_desc,
@@ -77,6 +78,7 @@ class RuntimeModel {
std::vector<std::shared_ptr<OpInfo>> constant_info_list_{};

std::vector<uint32_t> task_id_list_{};
std::vector<uint32_t> stream_id_list_{};
};

} // namespace model_runner


Loading…
Cancel
Save