Browse Source

Modify feature on execute.

tags/v1.2.0
zhangxiaokun 3 years ago
parent
commit
0eb67dad8e
2 changed files with 29 additions and 3 deletions
  1. +27
    -0
      ge/graph/load/new_model_manager/davinci_model.cc
  2. +2
    -3
      ge/graph/load/new_model_manager/davinci_model.h

+ 27
- 0
ge/graph/load/new_model_manager/davinci_model.cc View File

@@ -108,6 +108,7 @@ std::mutex DavinciModel::tvm_bin_mutex_;
DavinciModel::DavinciModel(int32_t priority, const std::shared_ptr<ModelListener> &listener) DavinciModel::DavinciModel(int32_t priority, const std::shared_ptr<ModelListener> &listener)
: weights_mem_base_(nullptr), : weights_mem_base_(nullptr),
var_mem_base_(nullptr), var_mem_base_(nullptr),
fixed_mem_base_(0),
mem_base_(nullptr), mem_base_(nullptr),
is_inner_mem_base_(false), is_inner_mem_base_(false),
is_inner_weight_base_(false), is_inner_weight_base_(false),
@@ -670,6 +671,7 @@ Status DavinciModel::Init(void *dev_ptr, size_t mem_size, void *weight_ptr, size
data_inputer_ = new (std::nothrow) DataInputer(); data_inputer_ = new (std::nothrow) DataInputer();
GE_CHK_BOOL_RET_STATUS(data_inputer_ != nullptr, MEMALLOC_FAILED, "data_inputer_ is nullptr."); GE_CHK_BOOL_RET_STATUS(data_inputer_ != nullptr, MEMALLOC_FAILED, "data_inputer_ is nullptr.");
} }
fixed_mem_base_ = reinterpret_cast<uintptr_t>(mem_base_);
GE_TIMESTAMP_END(InitModelMem, "GraphLoader::InitModelMem"); GE_TIMESTAMP_END(InitModelMem, "GraphLoader::InitModelMem");


for (const ge::NodePtr &node : compute_graph->GetDirectNode()) { for (const ge::NodePtr &node : compute_graph->GetDirectNode()) {
@@ -2828,7 +2830,32 @@ Status DavinciModel::CreateKnownZeroCopyMap(const vector<void *> &inputs, const
return SUCCESS; return SUCCESS;
} }


void DavinciModel::SetTotalIOAddrs(const vector<void *> &io_addrs) {
if (fixed_mem_base_ == reinterpret_cast<uintptr_t>(mem_base_)) {
total_io_addrs_.insert(total_io_addrs_.end(), io_addrs.begin(), io_addrs.end());
return;
}

for (size_t i = 0; i < io_addrs.size(); ++i) {
uintptr_t addr = reinterpret_cast<uintptr_t>(io_addrs[i]);
if ((fixed_mem_base_ <= addr) && (addr < fixed_mem_base_ + runtime_param_.mem_size)) {
total_io_addrs_.emplace_back(mem_base_ + (addr - fixed_mem_base_));
} else {
total_io_addrs_.emplace_back(io_addrs[i]);
}
}
}

Status DavinciModel::UpdateKnownZeroCopyAddr(vector<void *> &total_io_addrs) { Status DavinciModel::UpdateKnownZeroCopyAddr(vector<void *> &total_io_addrs) {
if (fixed_mem_base_ != reinterpret_cast<uintptr_t>(mem_base_)) {
for (size_t i = 0; i < total_io_addrs.size(); ++i) {
uintptr_t addr = reinterpret_cast<uintptr_t>(total_io_addrs[i]);
if ((fixed_mem_base_ <= addr) && (addr < fixed_mem_base_ + runtime_param_.mem_size)) {
total_io_addrs[i] = mem_base_ + (addr - fixed_mem_base_);
}
}
}

for (size_t i = 0; i < total_io_addrs.size(); ++i) { for (size_t i = 0; i < total_io_addrs.size(); ++i) {
auto it_in = knonw_input_data_info_.find(total_io_addrs[i]); auto it_in = knonw_input_data_info_.find(total_io_addrs[i]);
if (it_in != knonw_input_data_info_.end()) { if (it_in != knonw_input_data_info_.end()) {


+ 2
- 3
ge/graph/load/new_model_manager/davinci_model.h View File

@@ -503,9 +503,7 @@ class DavinciModel {
void *cur_args = static_cast<char *>(args_) + offset; void *cur_args = static_cast<char *>(args_) + offset;
return cur_args; return cur_args;
} }
void SetTotalIOAddrs(const vector<void *> &io_addrs) {
total_io_addrs_.insert(total_io_addrs_.end(), io_addrs.begin(), io_addrs.end());
}
void SetTotalIOAddrs(const vector<void *> &io_addrs);
void SetHybridArgsSize(uint32_t args_size) { total_hybrid_args_size_ += args_size; } void SetHybridArgsSize(uint32_t args_size) { total_hybrid_args_size_ += args_size; }
uint32_t GetHybridArgsSize() { uint32_t GetHybridArgsSize() {
return total_hybrid_args_size_; return total_hybrid_args_size_;
@@ -555,6 +553,7 @@ class DavinciModel {
uint8_t *weights_mem_base_; uint8_t *weights_mem_base_;
uint8_t *var_mem_base_; uint8_t *var_mem_base_;
// memory address of model // memory address of model
uintptr_t fixed_mem_base_; // Initial of mem_base_, keep forever.
uint8_t *mem_base_; uint8_t *mem_base_;
uint8_t *p2p_mem_base_; uint8_t *p2p_mem_base_;
bool is_inner_mem_base_; bool is_inner_mem_base_;


Loading…
Cancel
Save