@@ -87,6 +87,7 @@ const uint32_t kDumpL1FusionOpMByteSize = 2097152; // 2 * 1024 * 1024 | |||||
const uint32_t kDumpFlagOfL1Fusion = 0; | const uint32_t kDumpFlagOfL1Fusion = 0; | ||||
const char *const kDefaultBatchLable = "Batch_default"; | const char *const kDefaultBatchLable = "Batch_default"; | ||||
const char *const kGetDynamicDimsName = "ascend_mbatch_get_dynamic_dims_node"; | const char *const kGetDynamicDimsName = "ascend_mbatch_get_dynamic_dims_node"; | ||||
const char *const kMultiBatchNodePostfix = "_ascend_mbatch_batch_"; | |||||
const int32_t kInvalidStream = -1; | const int32_t kInvalidStream = -1; | ||||
const uint32_t kEndOfSequence = 0x0704000a; | const uint32_t kEndOfSequence = 0x0704000a; | ||||
const uint32_t kEndOfSequenceNew = 507005; | const uint32_t kEndOfSequenceNew = 507005; | ||||
@@ -867,6 +868,10 @@ Status DavinciModel::InitNodes(const ComputeGraphPtr &compute_graph) { | |||||
GELOGE(PARAM_INVALID, "NetOutput init failed, Name: %s", op_desc->GetName().c_str()); | GELOGE(PARAM_INVALID, "NetOutput init failed, Name: %s", op_desc->GetName().c_str()); | ||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
if (InitRealSizeAndShapeInfo(compute_graph, node) != SUCCESS) { | |||||
GELOGE(PARAM_INVALID, "Init real size and shape failed, Name: %s", op_desc->GetName().c_str()); | |||||
return PARAM_INVALID; | |||||
} | |||||
continue; | continue; | ||||
} | } | ||||
@@ -1143,16 +1148,24 @@ Status DavinciModel::InitNetOutput(const ComputeGraphPtr &graph, const NodePtr & | |||||
real_virtual_addrs_.insert(real_addr); | real_virtual_addrs_.insert(real_addr); | ||||
} | } | ||||
} | } | ||||
return SUCCESS; | |||||
} | |||||
Status DavinciModel::InitRealSizeAndShapeInfo(const ComputeGraphPtr &compute_graph, const NodePtr &node) { | |||||
if (node->GetName().find(kMultiBatchNodePostfix) != string::npos) { | |||||
GELOGD("No need to get size and shape of netoutput in subgraph."); | |||||
return SUCCESS; | |||||
} | |||||
GELOGD("Start init real size and shape info of %s.", node->GetName().c_str()); | |||||
GetAllGearsInfo(node); | GetAllGearsInfo(node); | ||||
if (is_getnext_sink_dynamic_) { | if (is_getnext_sink_dynamic_) { | ||||
GE_IF_BOOL_EXEC(GetGetDynamicDimsNodeInfo(node) != SUCCESS, | GE_IF_BOOL_EXEC(GetGetDynamicDimsNodeInfo(node) != SUCCESS, | ||||
GELOGE(PARAM_INVALID, "Failed to get info of getdynamicdims node."); return PARAM_INVALID;); | GELOGE(PARAM_INVALID, "Failed to get info of getdynamicdims node."); return PARAM_INVALID;); | ||||
} | } | ||||
if (is_online_infer_dynamic_) { | if (is_online_infer_dynamic_) { | ||||
GE_IF_BOOL_EXEC(GetGearAndRealOutSizeInfo(input_count, node) != SUCCESS, | |||||
GE_IF_BOOL_EXEC(GetGearAndRealOutSizeInfo(compute_graph, node) != SUCCESS, | |||||
GELOGE(PARAM_INVALID, "Failed to get gear and real out size info."); return PARAM_INVALID;); | GELOGE(PARAM_INVALID, "Failed to get gear and real out size info."); return PARAM_INVALID;); | ||||
GE_IF_BOOL_EXEC(GetGearAndRealOutShapeInfo(input_count, op_desc) != SUCCESS, | |||||
GE_IF_BOOL_EXEC(GetGearAndRealOutShapeInfo(compute_graph, node) != SUCCESS, | |||||
GELOGE(PARAM_INVALID, "Failed to get gear and real out shape info."); return PARAM_INVALID;); | GELOGE(PARAM_INVALID, "Failed to get gear and real out shape info."); return PARAM_INVALID;); | ||||
} | } | ||||
@@ -1171,7 +1184,7 @@ void DavinciModel::GetAllGearsInfo(const NodePtr &node) { | |||||
if (shape_str.empty()) { | if (shape_str.empty()) { | ||||
continue; | continue; | ||||
} | } | ||||
std::vector<int64_t> gear_info; | |||||
std::vector<int32_t> gear_info; | |||||
std::vector<std::string> dims = ge::StringUtils::Split(shape_str, ','); | std::vector<std::string> dims = ge::StringUtils::Split(shape_str, ','); | ||||
for (const auto &dim : dims) { | for (const auto &dim : dims) { | ||||
if (dim.empty()) { | if (dim.empty()) { | ||||
@@ -1187,6 +1200,7 @@ void DavinciModel::GetAllGearsInfo(const NodePtr &node) { | |||||
} | } | ||||
} | } | ||||
} | } | ||||
Status DavinciModel::GetGetDynamicDimsNodeInfo(const NodePtr &node) { | Status DavinciModel::GetGetDynamicDimsNodeInfo(const NodePtr &node) { | ||||
GE_CHECK_NOTNULL(node->GetOpDesc()); | GE_CHECK_NOTNULL(node->GetOpDesc()); | ||||
size_t input_count = node->GetAllInDataAnchors().size(); | size_t input_count = node->GetAllInDataAnchors().size(); | ||||
@@ -1224,11 +1238,11 @@ Status DavinciModel::GetGetDynamicDimsNodeInfo(const NodePtr &node) { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status DavinciModel::GetGearAndRealOutSizeInfo(size_t input_count, const NodePtr &node) { | |||||
GELOGD("Start get gear and real output size info of %s, input count is %zu.", node->GetName().c_str(), input_count); | |||||
Status DavinciModel::GetGearAndRealOutSizeInfo(const ComputeGraphPtr &graph, const NodePtr &node) { | |||||
GELOGD("Start get gear and real output size info of %s.", node->GetName().c_str()); | |||||
merge_nodes_gear_and_real_out_size_info_.clear(); | merge_nodes_gear_and_real_out_size_info_.clear(); | ||||
for (size_t idx = 0; idx < input_count; ++idx) { | |||||
auto in_anchor = node->GetAllInDataAnchors().at(idx); | |||||
size_t idx = 0; | |||||
for (const auto &in_anchor : node->GetAllInDataAnchors()) { | |||||
auto peer_out_anchor = in_anchor->GetPeerOutAnchor(); | auto peer_out_anchor = in_anchor->GetPeerOutAnchor(); | ||||
if (peer_out_anchor == nullptr) { | if (peer_out_anchor == nullptr) { | ||||
continue; | continue; | ||||
@@ -1236,89 +1250,106 @@ Status DavinciModel::GetGearAndRealOutSizeInfo(size_t input_count, const NodePtr | |||||
auto peer_node = peer_out_anchor->GetOwnerNode(); | auto peer_node = peer_out_anchor->GetOwnerNode(); | ||||
auto op_desc = peer_node->GetOpDesc(); | auto op_desc = peer_node->GetOpDesc(); | ||||
GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
if ((peer_node->GetType() == MERGE) && (op_desc->HasAttr(ATTR_INSERT_BY_MBATCH))) { | |||||
if (GetRealOutputSizeOfMerge(idx, peer_node) != SUCCESS) { | |||||
if ((peer_node->GetType() == CASE) && (op_desc->HasAttr(ATTR_INSERT_BY_MBATCH))) { | |||||
if (GetRealOutputSizeOfCase(graph, idx, peer_node) != SUCCESS) { | |||||
GELOGE(PARAM_INVALID, "Get real output size of %s failed.", peer_node->GetName().c_str()); | GELOGE(PARAM_INVALID, "Get real output size of %s failed.", peer_node->GetName().c_str()); | ||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
} | } | ||||
idx++; | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status DavinciModel::GetRealOutputSizeOfMerge(size_t input_index, const NodePtr &merge_node) { | |||||
GELOGD("Start get output size of %s, which is %zu input to netoutput.", merge_node->GetName().c_str(), input_index); | |||||
std::map<vector<int64_t>, int64_t> gear_and_real_out_size_info; | |||||
for (auto &in_anchor : merge_node->GetAllInDataAnchors()) { | |||||
auto peer_out_anchor = in_anchor->GetPeerOutAnchor(); | |||||
if (peer_out_anchor == nullptr) { | |||||
continue; | |||||
} | |||||
auto in_node = peer_out_anchor->GetOwnerNode(); | |||||
GELOGD("Input node of merge is %s.", in_node->GetName().c_str()); | |||||
auto op_desc = in_node->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(op_desc); | |||||
string batch_label; | |||||
if (AttrUtils::GetStr(op_desc, ATTR_NAME_BATCH_LABEL, batch_label)) { | |||||
size_t batch_index = static_cast<size_t>(stoi(batch_label.substr(batch_label.rfind('_') + 1))); | |||||
GELOGD("Batch index of %s is %zu.", op_desc->GetName().c_str(), batch_index); | |||||
if (batch_index > all_gears_info_.size()) { | |||||
GELOGE(PARAM_INVALID, "The value of ATTR_NAME_BATCH_LABEL is invalid."); | |||||
return PARAM_INVALID; | |||||
} | |||||
const vector<int64_t> output_size_list = ModelUtils::GetOutputSize(op_desc); | |||||
int output_index = ge::AnchorUtils::GetIdx(peer_out_anchor); | |||||
auto tensor_desc = op_desc->GetOutputDescPtr(output_index); | |||||
GE_CHECK_NOTNULL(tensor_desc); | |||||
int64_t data_size = 0; | |||||
if (TensorUtils::GetTensorSizeInBytes(*tensor_desc, data_size) != GRAPH_SUCCESS) { | |||||
GELOGE(FAILED, "Get tensor size in bytes failed."); | |||||
return FAILED; | |||||
Status DavinciModel::GetRealOutputSizeOfCase(const ComputeGraphPtr &graph, size_t input_index, | |||||
const NodePtr &case_node) { | |||||
GELOGD("Start get output size of %s, which is %zu input to netoutput.", case_node->GetName().c_str(), input_index); | |||||
const auto &func_desc = case_node->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(func_desc); | |||||
std::map<vector<int32_t>, int64_t> gear_and_real_out_size_info; | |||||
for (const auto &name : func_desc->GetSubgraphInstanceNames()) { | |||||
const auto &subgraph = graph->GetSubgraph(name); | |||||
if (subgraph == nullptr) { | |||||
GELOGE(GE_GRAPH_EMPTY_SUBGRAPH, "Subgraph not found, name: %s.", name.c_str()); | |||||
return GE_GRAPH_EMPTY_SUBGRAPH; | |||||
} | |||||
for (auto &node : subgraph->GetDirectNode()) { | |||||
if (node->GetType() == NETOUTPUT) { | |||||
auto op_desc = node->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(op_desc); | |||||
string batch_label; | |||||
if (AttrUtils::GetStr(op_desc, ATTR_NAME_BATCH_LABEL, batch_label)) { | |||||
size_t batch_index = static_cast<size_t>(stoi(batch_label.substr(batch_label.rfind('_') + 1))); | |||||
GELOGD("Batch index of %s is %zu.", op_desc->GetName().c_str(), batch_index); | |||||
if (batch_index > all_gears_info_.size()) { | |||||
GELOGE(PARAM_INVALID, "The value of ATTR_NAME_BATCH_LABEL is invalid."); | |||||
return PARAM_INVALID; | |||||
} | |||||
const vector<int64_t> input_size_list = ModelUtils::GetInputSize(op_desc); | |||||
auto tensor_desc = op_desc->GetInputDescPtr(input_index); | |||||
GE_CHECK_NOTNULL(tensor_desc); | |||||
int64_t data_size = 0; | |||||
if (TensorUtils::GetTensorSizeInBytes(*tensor_desc, data_size) != GRAPH_SUCCESS) { | |||||
GELOGE(FAILED, "Get tensor size in bytes failed."); | |||||
return FAILED; | |||||
} | |||||
gear_and_real_out_size_info[all_gears_info_[batch_index]] = data_size; | |||||
GELOGD("Get real gear index is: %zu, gear info is %s, size is %ld, tensor size is %ld", | |||||
batch_index, formats::JoinToString(all_gears_info_[batch_index]).c_str(), | |||||
input_size_list[input_index], data_size); | |||||
} | |||||
break; | |||||
} | } | ||||
gear_and_real_out_size_info[all_gears_info_[batch_index]] = data_size; | |||||
GELOGD("Get real gear index is: %zu, gear info is %s, size is %ld, tensor size is %ld", | |||||
batch_index, formats::JoinToString(all_gears_info_[batch_index]).c_str(), | |||||
output_size_list[output_index], data_size); | |||||
} | } | ||||
} | } | ||||
merge_nodes_gear_and_real_out_size_info_[input_index] = gear_and_real_out_size_info; | merge_nodes_gear_and_real_out_size_info_[input_index] = gear_and_real_out_size_info; | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status DavinciModel::GetGearAndRealOutShapeInfo(size_t input_count, const OpDescPtr &op_desc) { | |||||
GELOGD("Start to get dynamic output dims of %s.", op_desc->GetName().c_str()); | |||||
Status DavinciModel::GetGearAndRealOutShapeInfo(const ComputeGraphPtr &graph, const NodePtr &node) { | |||||
GELOGD("Start to get dynamic output dims of %s.", node->GetName().c_str()); | |||||
merge_nodes_gear_and_real_out_shape_info_.clear(); | merge_nodes_gear_and_real_out_shape_info_.clear(); | ||||
std::vector<std::string> dynamic_output_shape_info; | |||||
if (!AttrUtils::GetListStr(op_desc, ATTR_NAME_DYNAMIC_OUTPUT_DIMS, dynamic_output_shape_info)) { | |||||
GELOGD("Can not get dynamic output dims attr"); | |||||
return SUCCESS; | |||||
} | |||||
GELOGI("Dynamic output shape info is %s", formats::JoinToString(dynamic_output_shape_info).c_str()); | |||||
std::vector<vector<int64_t>> dynamic_output_shape; | |||||
ParseDynamicOutShape(dynamic_output_shape_info, dynamic_output_shape); | |||||
// idx: input_index to netoutput | |||||
for (size_t idx = 0; idx < input_count; ++idx) { | |||||
std::map<vector<int64_t>, vector<int64_t>> gear_and_real_out_shape_info; | |||||
for (auto &it : dynamic_output_shape) { | |||||
auto gear_index = static_cast<size_t>(it[0]); | |||||
if (gear_index > all_gears_info_.size()) { | |||||
GELOGE(PARAM_INVALID, "The value of cur index: %zu is invalid.", static_cast<size_t>(it[0])); | |||||
return PARAM_INVALID; | |||||
size_t idx = 0; | |||||
for (const auto &in_anchor : node->GetAllInDataAnchors()) { | |||||
auto peer_out_anchor = in_anchor->GetPeerOutAnchor(); | |||||
if (peer_out_anchor == nullptr) { | |||||
continue; | |||||
} | |||||
auto peer_node = peer_out_anchor->GetOwnerNode(); | |||||
auto op_desc = peer_node->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(op_desc); | |||||
if ((peer_node->GetType() == CASE) && (op_desc->HasAttr(ATTR_INSERT_BY_MBATCH))) { | |||||
std::vector<std::string> dynamic_output_shape_info; | |||||
if (!AttrUtils::GetListStr(node->GetOpDesc(), ATTR_NAME_DYNAMIC_OUTPUT_DIMS, dynamic_output_shape_info)) { | |||||
GELOGD("Can not get dynamic output dims attr from %s.", node->GetName().c_str()); | |||||
return SUCCESS; | |||||
} | } | ||||
GELOGI("Dynamic output shape info is %s", formats::JoinToString(dynamic_output_shape_info).c_str()); | |||||
std::vector<vector<int64_t>> dynamic_output_shape; | |||||
ParseDynamicOutShape(dynamic_output_shape_info, dynamic_output_shape); | |||||
std::map<vector<int32_t>, vector<int64_t>> gear_and_real_out_shape_info; | |||||
for (auto &it : dynamic_output_shape) { | |||||
auto gear_index = static_cast<size_t>(it[0]); | |||||
if (gear_index > all_gears_info_.size()) { | |||||
GELOGE(PARAM_INVALID, "The value of cur index: %zu is invalid.", static_cast<size_t>(it[0])); | |||||
return PARAM_INVALID; | |||||
} | |||||
if (static_cast<size_t>(it[1]) == idx) { | |||||
vector<int64_t> output_shape; | |||||
for (size_t i = 2; i < it.size(); ++i) { | |||||
output_shape.emplace_back(it[i]); | |||||
if (static_cast<size_t>(it[1]) == idx) { | |||||
vector<int64_t> output_shape; | |||||
for (size_t i = 2; i < it.size(); ++i) { | |||||
output_shape.emplace_back(it[i]); | |||||
} | |||||
gear_and_real_out_shape_info[all_gears_info_[gear_index]] = output_shape; | |||||
GELOGD("Get real gear index is: %zu, gear info is %s, output shape is %s.", | |||||
gear_index, formats::JoinToString(all_gears_info_[gear_index]).c_str(), | |||||
formats::JoinToString(output_shape).c_str()); | |||||
} | } | ||||
gear_and_real_out_shape_info[all_gears_info_[gear_index]] = output_shape; | |||||
GELOGD("Get real gear index is: %zu, gear info is %s, output shape is %s.", | |||||
gear_index, formats::JoinToString(all_gears_info_[gear_index]).c_str(), | |||||
formats::JoinToString(output_shape).c_str()); | |||||
} | } | ||||
merge_nodes_gear_and_real_out_shape_info_[idx] = gear_and_real_out_shape_info; | |||||
} | } | ||||
merge_nodes_gear_and_real_out_shape_info_[idx] = gear_and_real_out_shape_info; | |||||
idx++; | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -1962,7 +1993,7 @@ void DavinciModel::CreateOutput(uint32_t index, const OpDescPtr &op_desc, InputO | |||||
uint32_t &format_result) { | uint32_t &format_result) { | ||||
/// netoutput input tensor desc | /// netoutput input tensor desc | ||||
GE_IF_BOOL_EXEC(op_desc->GetInputDescPtr(index) == nullptr, GELOGE(FAILED, "OpDesc GetInputDescPtr is nullptr"); | GE_IF_BOOL_EXEC(op_desc->GetInputDescPtr(index) == nullptr, GELOGE(FAILED, "OpDesc GetInputDescPtr is nullptr"); | ||||
return ); | |||||
return); | |||||
Format format = op_desc->GetInputDescPtr(index)->GetFormat(); | Format format = op_desc->GetInputDescPtr(index)->GetFormat(); | ||||
GeShape shape = op_desc->GetInputDescPtr(index)->GetShape(); | GeShape shape = op_desc->GetInputDescPtr(index)->GetShape(); | ||||
DataType data_type = op_desc->GetInputDescPtr(index)->GetDataType(); | DataType data_type = op_desc->GetInputDescPtr(index)->GetDataType(); | ||||
@@ -2567,7 +2598,7 @@ Status DavinciModel::ReturnResult(uint32_t data_id, const bool rslt_flg, const b | |||||
GELOGD("Reinit cur dynamic dims when getnext sink dynamic."); | GELOGD("Reinit cur dynamic dims when getnext sink dynamic."); | ||||
cur_dynamic_dims_.clear(); | cur_dynamic_dims_.clear(); | ||||
cur_dynamic_dims_.resize(shape_of_cur_dynamic_dims_); | cur_dynamic_dims_.resize(shape_of_cur_dynamic_dims_); | ||||
auto ret = rtMemcpy(cur_dynamic_dims_.data(), shape_of_cur_dynamic_dims_ * sizeof(int64_t), | |||||
auto ret = rtMemcpy(cur_dynamic_dims_.data(), shape_of_cur_dynamic_dims_ * sizeof(int32_t), | |||||
netoutput_last_input_addr_, netoutput_last_input_size_, RT_MEMCPY_DEVICE_TO_HOST); | netoutput_last_input_addr_, netoutput_last_input_size_, RT_MEMCPY_DEVICE_TO_HOST); | ||||
GE_CHK_RT_RET(ret); | GE_CHK_RT_RET(ret); | ||||
} | } | ||||
@@ -2668,11 +2699,11 @@ void *DavinciModel::Run(DavinciModel *model) { | |||||
GE_IF_BOOL_EXEC(current_data.blobs.empty(), break); | GE_IF_BOOL_EXEC(current_data.blobs.empty(), break); | ||||
auto shape_data_buffer_data = current_data.blobs.back().data; | auto shape_data_buffer_data = current_data.blobs.back().data; | ||||
auto shape_data_buffer_length = current_data.blobs.back().length; | auto shape_data_buffer_length = current_data.blobs.back().length; | ||||
model->cur_dynamic_dims_.assign(reinterpret_cast<int64_t *>(shape_data_buffer_data), | |||||
reinterpret_cast<int64_t *>(shape_data_buffer_data) + | |||||
shape_data_buffer_length / sizeof(int64_t)); | |||||
model->cur_dynamic_dims_.assign(reinterpret_cast<int32_t *>(shape_data_buffer_data), | |||||
reinterpret_cast<int32_t *>(shape_data_buffer_data) + | |||||
shape_data_buffer_length / sizeof(int32_t)); | |||||
GELOGD("Data: cur dynamic dims is %s", formats::JoinToString(model->cur_dynamic_dims_).c_str()); | GELOGD("Data: cur dynamic dims is %s", formats::JoinToString(model->cur_dynamic_dims_).c_str()); | ||||
delete[] reinterpret_cast<int64_t *>(current_data.blobs.back().data); | |||||
delete[] reinterpret_cast<int32_t *>(current_data.blobs.back().data); | |||||
current_data.blobs.pop_back(); | current_data.blobs.pop_back(); | ||||
} | } | ||||
GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingModelExecuteOn(), model->SetProfileTime(MODEL_PRE_PROC_END)); | GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingModelExecuteOn(), model->SetProfileTime(MODEL_PRE_PROC_END)); | ||||
@@ -864,11 +864,13 @@ class DavinciModel { | |||||
void ParseDynamicOutShape(const vector<string> &str_info, vector<vector<int64_t>> &vec_info); | void ParseDynamicOutShape(const vector<string> &str_info, vector<vector<int64_t>> &vec_info); | ||||
bool IsGetNextSinkDynamic(const OpDescPtr &op_desc); | bool IsGetNextSinkDynamic(const OpDescPtr &op_desc); | ||||
Status InitRealSizeAndShapeInfo(const ComputeGraphPtr &compute_graph, const NodePtr &node); | |||||
void GetAllGearsInfo(const NodePtr &node); | void GetAllGearsInfo(const NodePtr &node); | ||||
Status GetGetDynamicDimsNodeInfo(const NodePtr &node); | Status GetGetDynamicDimsNodeInfo(const NodePtr &node); | ||||
Status GetGearAndRealOutSizeInfo(size_t input_count, const NodePtr &node); | |||||
Status GetRealOutputSizeOfMerge(size_t input_index, const NodePtr &merge_node); | |||||
Status GetGearAndRealOutShapeInfo(size_t input_count, const OpDescPtr &op_desc); | |||||
Status GetGearAndRealOutSizeInfo(const ComputeGraphPtr &graph, const NodePtr &node); | |||||
Status GetRealOutputSizeOfCase(const ComputeGraphPtr &graph, size_t input_index, const NodePtr &case_node); | |||||
Status GetGearAndRealOutShapeInfo(const ComputeGraphPtr &graph, const NodePtr &node); | |||||
bool is_weight_mem_has_inited_; | bool is_weight_mem_has_inited_; | ||||
bool is_feature_map_mem_has_inited_; | bool is_feature_map_mem_has_inited_; | ||||
@@ -1021,15 +1023,15 @@ class DavinciModel { | |||||
bool is_new_model_desc_{false}; | bool is_new_model_desc_{false}; | ||||
bool is_online_infer_dynamic_ = false; | bool is_online_infer_dynamic_ = false; | ||||
bool is_getnext_sink_dynamic_ = false; | bool is_getnext_sink_dynamic_ = false; | ||||
vector<int64_t> cur_dynamic_dims_; | |||||
vector<int32_t> cur_dynamic_dims_; | |||||
void *netoutput_last_input_addr_ = nullptr; | void *netoutput_last_input_addr_ = nullptr; | ||||
int64_t netoutput_last_input_size_ = 0; | int64_t netoutput_last_input_size_ = 0; | ||||
size_t shape_of_cur_dynamic_dims_ = 0; | size_t shape_of_cur_dynamic_dims_ = 0; | ||||
// key: input_index: input is merge node; value: each gear info and each output size | // key: input_index: input is merge node; value: each gear info and each output size | ||||
map<size_t, map<vector<int64_t>, int64_t>> merge_nodes_gear_and_real_out_size_info_; | |||||
map<size_t, map<vector<int32_t>, int64_t>> merge_nodes_gear_and_real_out_size_info_; | |||||
// key: input_index: input is merge node; value: each gear info and each output shape | // key: input_index: input is merge node; value: each gear info and each output shape | ||||
map<size_t, map<vector<int64_t>, vector<int64_t>>> merge_nodes_gear_and_real_out_shape_info_; | |||||
vector<vector<int64_t>> all_gears_info_; | |||||
map<size_t, map<vector<int32_t>, vector<int64_t>>> merge_nodes_gear_and_real_out_shape_info_; | |||||
vector<vector<int32_t>> all_gears_info_; | |||||
multimap<uint32_t, uint32_t> op_id_map_; | multimap<uint32_t, uint32_t> op_id_map_; | ||||
vector<ProfileInfo> profile_list_; | vector<ProfileInfo> profile_list_; | ||||
@@ -460,8 +460,8 @@ Status ModelManager::DataInput(const InputData &input_data, OutputData &output_d | |||||
Status ModelManager::GetCurDynamicDims(const vector<vector<int64_t>> &user_real_input_dims, | Status ModelManager::GetCurDynamicDims(const vector<vector<int64_t>> &user_real_input_dims, | ||||
const vector<pair<string, vector<int64_t>>> &user_input_dims, | const vector<pair<string, vector<int64_t>>> &user_input_dims, | ||||
vector<int64_t> &cur_dynamic_dims) { | |||||
GELOGD(" Start get cur dynamic dims."); | |||||
vector<int32_t> &cur_dynamic_dims) { | |||||
GELOGD("Start get cur dynamic dims."); | |||||
if (user_real_input_dims.size() != user_input_dims.size()) { | if (user_real_input_dims.size() != user_input_dims.size()) { | ||||
GELOGE(INTERNAL_ERROR, | GELOGE(INTERNAL_ERROR, | ||||
"The input count of user: %zu should be equal to the data count of graph: %zu", | "The input count of user: %zu should be equal to the data count of graph: %zu", | ||||
@@ -478,7 +478,7 @@ Status ModelManager::GetCurDynamicDims(const vector<vector<int64_t>> &user_real_ | |||||
} | } | ||||
for (size_t j = 0; j < user_input_dims.at(i).second.size(); ++j) { | for (size_t j = 0; j < user_input_dims.at(i).second.size(); ++j) { | ||||
if (user_input_dims.at(i).second.at(j) < 0) { | if (user_input_dims.at(i).second.at(j) < 0) { | ||||
cur_dynamic_dims.emplace_back(user_real_input_dims[i][j]); | |||||
cur_dynamic_dims.emplace_back(static_cast<int32_t>(user_real_input_dims[i][j])); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -523,7 +523,7 @@ Status ModelManager::DataInputTensor(uint32_t model_id, const std::vector<InputT | |||||
input_data.blobs.push_back(data); | input_data.blobs.push_back(data); | ||||
} | } | ||||
if (!GetLocalOmgContext().user_input_dims.empty() && GetLocalOmgContext().need_multi_batch) { | if (!GetLocalOmgContext().user_input_dims.empty() && GetLocalOmgContext().need_multi_batch) { | ||||
std::vector<int64_t> cur_dynamic_dims; | |||||
std::vector<int32_t> cur_dynamic_dims; | |||||
if (!GetLocalOmgContext().user_real_input_dims.empty()) { | if (!GetLocalOmgContext().user_real_input_dims.empty()) { | ||||
if (GetCurDynamicDims(GetLocalOmgContext().user_real_input_dims, GetLocalOmgContext().user_input_dims, | if (GetCurDynamicDims(GetLocalOmgContext().user_real_input_dims, GetLocalOmgContext().user_input_dims, | ||||
cur_dynamic_dims) != SUCCESS) { | cur_dynamic_dims) != SUCCESS) { | ||||
@@ -531,9 +531,9 @@ Status ModelManager::DataInputTensor(uint32_t model_id, const std::vector<InputT | |||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
DataBuffer data; | DataBuffer data; | ||||
data.data = new(std::nothrow) int64_t[cur_dynamic_dims.size()]; | |||||
data.data = new(std::nothrow) int32_t[cur_dynamic_dims.size()]; | |||||
GE_CHECK_NOTNULL(data.data); | GE_CHECK_NOTNULL(data.data); | ||||
uint64_t length = static_cast<uint64_t>(cur_dynamic_dims.size() * sizeof(int64_t)); | |||||
uint32_t length = static_cast<uint32_t>(cur_dynamic_dims.size() * sizeof(int32_t)); | |||||
GE_CHK_BOOL_EXEC(memcpy_s(data.data, length, cur_dynamic_dims.data(), length) == EOK, return INTERNAL_ERROR, | GE_CHK_BOOL_EXEC(memcpy_s(data.data, length, cur_dynamic_dims.data(), length) == EOK, return INTERNAL_ERROR, | ||||
"Failed to memcpy data."); | "Failed to memcpy data."); | ||||
data.length = length; | data.length = length; | ||||
@@ -126,14 +126,14 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { | |||||
/// | /// | ||||
/// @ingroup domi_ome | /// @ingroup domi_ome | ||||
/// @brief Get cur_dynamic_dims for all input. | /// @brief Get cur_dynamic_dims for all input. | ||||
/// @param [in] vector<vector<uint64_t>> &user_real_input_dims: dims info of all user_inputs. | |||||
/// @param [in] vector<vector<int64_t>> &user_real_input_dims: dims info of all user_inputs. | |||||
/// @param [in] vector<pair<string, vector<int64_t>>> &user_input_dims: key:name. value:dynamic dims from option. | /// @param [in] vector<pair<string, vector<int64_t>>> &user_input_dims: key:name. value:dynamic dims from option. | ||||
/// @param [out] vector<uint64_t> &cur_dynamic_dims: real dims gather, where the index of -1. | |||||
/// @param [out] vector<int32_t> &cur_dynamic_dims: real dims gather, where the index of -1. | |||||
/// @return 0: SUCCESS / others: INTERNAL_ERROR | /// @return 0: SUCCESS / others: INTERNAL_ERROR | ||||
/// | /// | ||||
Status GetCurDynamicDims(const vector<vector<int64_t>> &user_real_input_dims, | Status GetCurDynamicDims(const vector<vector<int64_t>> &user_real_input_dims, | ||||
const vector<pair<string, vector<int64_t>>> &user_input_dims, | const vector<pair<string, vector<int64_t>>> &user_input_dims, | ||||
vector<int64_t> &cur_dynamic_dims); | |||||
vector<int32_t> &cur_dynamic_dims); | |||||
/// | /// | ||||
/// @ingroup domi_ome | /// @ingroup domi_ome | ||||
@@ -145,7 +145,9 @@ Status HcclTaskInfo::SetFollowStream(const ge::ConstOpDescPtr &op_desc, DavinciM | |||||
} else { | } else { | ||||
GELOGI("need to reuse follow stream and create new follow stream."); | GELOGI("need to reuse follow stream and create new follow stream."); | ||||
size_t created_stream_num = follow_stream_usage.size(); | size_t created_stream_num = follow_stream_usage.size(); | ||||
hccl_stream_list_ = follow_stream_usage; | |||||
for (const auto &stream : follow_stream_usage) { | |||||
hccl_stream_list_.emplace_back(stream); | |||||
} | |||||
ret = CreateStream(hccl_stream_num - created_stream_num, davinci_model, main_stream_id); | ret = CreateStream(hccl_stream_num - created_stream_num, davinci_model, main_stream_id); | ||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(RT_FAILED, "Create hccl stream failed."); | GELOGE(RT_FAILED, "Create hccl stream failed."); | ||||
@@ -2780,8 +2780,10 @@ Status GraphManager::ParseInputsDims(const std::vector<InputTensorInfo> &input_t | |||||
if (!GetLocalOmgContext().dynamic_node_type.empty()) { | if (!GetLocalOmgContext().dynamic_node_type.empty()) { | ||||
vector<NodePtr> data_nodes; | vector<NodePtr> data_nodes; | ||||
vector<NodePtr> getnext_nosink_nodes; | vector<NodePtr> getnext_nosink_nodes; | ||||
data_nodes = compute_graph_->TryGetExtAttr(kExtAttrDataNodes, data_nodes); | |||||
getnext_nosink_nodes = compute_graph_->TryGetExtAttr(kExtAttrGetNextNoSink, getnext_nosink_nodes); | |||||
data_nodes = GetLocalOmgContext().data_nodes; | |||||
getnext_nosink_nodes = GetLocalOmgContext().getnext_nosink_nodes; | |||||
GELOGD("Data nodes count is %zu, getnext nosink nodes count is %zu.", data_nodes.size(), | |||||
getnext_nosink_nodes.size()); | |||||
if (GetLocalOmgContext().dynamic_node_type == DATA) { | if (GetLocalOmgContext().dynamic_node_type == DATA) { | ||||
if (getnext_nosink_nodes.empty()) { | if (getnext_nosink_nodes.empty()) { | ||||
// just data or data+getnext_sink | // just data or data+getnext_sink | ||||
@@ -26,6 +26,10 @@ | |||||
namespace ge { | namespace ge { | ||||
namespace { | namespace { | ||||
std::set<std::string> un_compute_attrs = { | |||||
{ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES}, | |||||
}; | |||||
std::string GetCseKey(const NodePtr &node) { | std::string GetCseKey(const NodePtr &node) { | ||||
std::stringstream ss; | std::stringstream ss; | ||||
ss << node->GetType() << "-data-inputs-"; | ss << node->GetType() << "-data-inputs-"; | ||||
@@ -49,7 +53,7 @@ std::string GetCseKey(const NodePtr &node) { | |||||
ss << name << "-"; | ss << name << "-"; | ||||
} | } | ||||
ss << "attrs-" << AttrUtils::GetAllAttrsStr(node->GetOpDesc()); | |||||
ss << "attrs-" << AttrUtils::GetAttrsStrAfterRid(node->GetOpDesc(), un_compute_attrs); | |||||
return ss.str(); | return ss.str(); | ||||
} | } | ||||
@@ -25,31 +25,65 @@ | |||||
#include "graph/utils/tensor_utils.h" | #include "graph/utils/tensor_utils.h" | ||||
#include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
#include "register/op_registry.h" | #include "register/op_registry.h" | ||||
#include "graph/common/omg_util.h" | |||||
namespace ge { | namespace ge { | ||||
namespace { | namespace { | ||||
constexpr uint8_t kDataInIndex = 0; | constexpr uint8_t kDataInIndex = 0; | ||||
constexpr uint8_t kDataOutIndex = 0; | constexpr uint8_t kDataOutIndex = 0; | ||||
constexpr uint8_t kCaseArgIndex = 1; | constexpr uint8_t kCaseArgIndex = 1; | ||||
const int kDivisionConst = 2; | |||||
const size_t kNumOfGetnextNode = 1; | |||||
const std::string kMultiBatchCaseNode = "ascend_mbatch_shape_case"; | const std::string kMultiBatchCaseNode = "ascend_mbatch_shape_case"; | ||||
const std::string kMultiBatchDataNode = "ascend_mbatch_shape_data"; | const std::string kMultiBatchDataNode = "ascend_mbatch_shape_data"; | ||||
const std::string kMultiBatchGetDynamicDimsNode = "ascend_mbatch_get_dynamic_dims_node"; | |||||
const std::string kMultiBatchConstNode = "ascend_mbatch_shape_const"; | const std::string kMultiBatchConstNode = "ascend_mbatch_shape_const"; | ||||
const std::string kMultiBatchMapIndexNode = "ascend_mbatch_shape_mapindex"; | const std::string kMultiBatchMapIndexNode = "ascend_mbatch_shape_mapindex"; | ||||
const std::string kMultiBatchNodePostfix = "_ascend_mbatch_batch_"; | const std::string kMultiBatchNodePostfix = "_ascend_mbatch_batch_"; | ||||
const char *const kGetNextName = "IteratorV2"; | |||||
} // namespace | } // namespace | ||||
inline bool IsGetNextType(const NodePtr &node) { | |||||
std::string original_type; | |||||
GE_IF_BOOL_EXEC(GetOriginalType(node, original_type) != SUCCESS, | |||||
GELOGW("Get original type failed."); return false); | |||||
return (original_type == kGetNextName); | |||||
} | |||||
Status MultiBatchClonePass::Run(ComputeGraphPtr graph) { | Status MultiBatchClonePass::Run(ComputeGraphPtr graph) { | ||||
GE_IF_BOOL_EXEC(graph == nullptr, GELOGE(FAILED, "Original graph is nullptr"); return FAILED); | |||||
if (graph->GetParentGraph() != nullptr) { | if (graph->GetParentGraph() != nullptr) { | ||||
GELOGD("Subgraph %s skip the MultiBatchClonePass", graph->GetName().c_str()); | GELOGD("Subgraph %s skip the MultiBatchClonePass", graph->GetName().c_str()); | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
if (!GetLocalOmgContext().need_multi_batch) { | |||||
GELOGI("No need to process_multi for no_train graph."); | |||||
return SUCCESS; | |||||
} | |||||
std::vector<NodePtr> data_nodes; | |||||
std::vector<NodePtr> getnext_nosink_nodes; | |||||
std::vector<NodePtr> getnext_sink_nodes; | |||||
if (multibatch::CheckSequenceOfOptions(graph, data_nodes, getnext_nosink_nodes, getnext_sink_nodes) != SUCCESS) { | |||||
GELOGE(PARAM_INVALID, "[Train_Dynamic] CheckSequenceOfOptions failed."); | |||||
return PARAM_INVALID; | |||||
} | |||||
if (multibatch::UpdateNameOfInputShape(graph, data_nodes, getnext_nosink_nodes, getnext_sink_nodes) != SUCCESS) { | |||||
GELOGE(PARAM_INVALID, "[Train_Dynamic] UpdateNameForInputShapeOfOption failed."); | |||||
return PARAM_INVALID; | |||||
} | |||||
if (multibatch::DeleteIdentityInsertByAdapter(graph) != SUCCESS) { | |||||
GELOGE(PARAM_INVALID, "[Train_Dynamic] DeleteIdentityInsertByAdapter failed."); | |||||
return PARAM_INVALID; | |||||
} | |||||
if (!multibatch::InitDynamicParams(batch_shapes_)) { | if (!multibatch::InitDynamicParams(batch_shapes_)) { | ||||
GELOGD("There is no multi-batch options, no need clone multi-batch graph"); | GELOGD("There is no multi-batch options, no need clone multi-batch graph"); | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
if (multibatch::CheckNegativeCountOfOptions(batch_shapes_) != SUCCESS) { | |||||
GELOGE(PARAM_INVALID, "[Train_Dynamic] Input_shape and dynamic_dims should set correct params."); | |||||
return PARAM_INVALID; | |||||
} | |||||
GELOGD("Begin to run Multi-batch clone on graph: %s", graph->GetName().c_str()); | GELOGD("Begin to run Multi-batch clone on graph: %s", graph->GetName().c_str()); | ||||
GE_CHK_STATUS_RET(multibatch::CheckDynamicParams(batch_shapes_), "Invalid multi-batch param"); | GE_CHK_STATUS_RET(multibatch::CheckDynamicParams(batch_shapes_), "Invalid multi-batch param"); | ||||
if (CollectIoNodes(graph) != SUCCESS) { | if (CollectIoNodes(graph) != SUCCESS) { | ||||
@@ -66,21 +100,14 @@ Status MultiBatchClonePass::Run(ComputeGraphPtr graph) { | |||||
(void)AttrUtils::GetStr(graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id_); | (void)AttrUtils::GetStr(graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id_); | ||||
ComputeGraphPtr branch = MakeShared<ComputeGraph>(graph->GetName()); | ComputeGraphPtr branch = MakeShared<ComputeGraph>(graph->GetName()); | ||||
if (branch == nullptr) { | |||||
GELOGE(OUT_OF_MEMORY, "Create multi-batch graph failed"); | |||||
return OUT_OF_MEMORY; | |||||
} | |||||
GE_IF_BOOL_EXEC(branch == nullptr, GELOGE(OUT_OF_MEMORY, "Create multi batch graph failed"); return OUT_OF_MEMORY); | |||||
(void)AttrUtils::SetStr(branch, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id_); | (void)AttrUtils::SetStr(branch, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id_); | ||||
graph->InValid(); // Will modify, need topological again. | graph->InValid(); // Will modify, need topological again. | ||||
graph->Swap(*branch); | graph->Swap(*branch); | ||||
if (CreateRootGraph(graph) != SUCCESS) { | |||||
return FAILED; | |||||
} | |||||
if (CreateSubgraphs(graph, branch) != SUCCESS) { | |||||
return FAILED; | |||||
} | |||||
GE_CHK_STATUS_RET(CreateRootGraph(graph), "Construct root graph failed."); | |||||
GE_CHK_STATUS_RET(CreateOriGraph(branch), "Construct original graph failed.") | |||||
GE_CHK_STATUS_RET(CreateSubgraphs(graph, branch), "Construct subgraph failed."); | |||||
GE_CHK_STATUS_RET(PruneDirectOutput(graph), "Prune direct output failed"); | GE_CHK_STATUS_RET(PruneDirectOutput(graph), "Prune direct output failed"); | ||||
GELOGD("MultiBatchClonePass Leave"); | GELOGD("MultiBatchClonePass Leave"); | ||||
@@ -95,9 +122,13 @@ Status MultiBatchClonePass::Run(ComputeGraphPtr graph) { | |||||
/// | /// | ||||
Status MultiBatchClonePass::CollectIoNodes(const ComputeGraphPtr &graph) { | Status MultiBatchClonePass::CollectIoNodes(const ComputeGraphPtr &graph) { | ||||
for (const auto &node : graph->GetDirectNode()) { | for (const auto &node : graph->GetDirectNode()) { | ||||
if (!GetLocalOmgContext().dynamic_node_type.empty() && IsGetNextType(node)) { | |||||
all_data_nodes_.emplace_back(node); | |||||
GE_CHK_STATUS_RET(InitParamsOfGetNext(node), "Init params of %s failed.", node->GetName().c_str()); | |||||
} | |||||
if (node->GetType() == DATA) { | if (node->GetType() == DATA) { | ||||
all_data_nodes_.emplace_back(node); | all_data_nodes_.emplace_back(node); | ||||
} else if (node->GetType() == CONSTANT) { | |||||
} else if (node->GetType() == CONSTANT || node->GetType() == CONSTANTOP) { | |||||
all_const_nodes_.emplace_back(node); | all_const_nodes_.emplace_back(node); | ||||
} else if (node->GetType() == NETOUTPUT) { | } else if (node->GetType() == NETOUTPUT) { | ||||
all_output_nodes_.emplace_back(node); | all_output_nodes_.emplace_back(node); | ||||
@@ -114,10 +145,16 @@ Status MultiBatchClonePass::CollectIoNodes(const ComputeGraphPtr &graph) { | |||||
} | } | ||||
int64_t data_index = 0; | int64_t data_index = 0; | ||||
size_t getnext_node_count = 0; | |||||
for (size_t i = 0; i < all_data_nodes_.size(); ++i) { | for (size_t i = 0; i < all_data_nodes_.size(); ++i) { | ||||
if (IsGetNextType(all_data_nodes_[i])) { | |||||
// just one getnext node in graph | |||||
getnext_node_count++; | |||||
continue; | |||||
} | |||||
const auto &op_desc = all_data_nodes_[i]->GetOpDesc(); | const auto &op_desc = all_data_nodes_[i]->GetOpDesc(); | ||||
if (!AttrUtils::GetInt(op_desc, ATTR_NAME_INDEX, data_index)) { | if (!AttrUtils::GetInt(op_desc, ATTR_NAME_INDEX, data_index)) { | ||||
(void)AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, i); | |||||
(void)AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, i - getnext_node_count); | |||||
} | } | ||||
} | } | ||||
@@ -133,7 +170,43 @@ Status MultiBatchClonePass::CollectIoNodes(const ComputeGraphPtr &graph) { | |||||
"Remove edge failed"); | "Remove edge failed"); | ||||
} | } | ||||
} | } | ||||
GELOGD("Data count is %zu, const count is %zu, getnext count is %zu, output count is %zu, direct out count is %zu.", | |||||
all_data_nodes_.size(), all_const_nodes_.size(), getnext_node_count, all_output_nodes_.size(), | |||||
direct_output_.size()); | |||||
return SUCCESS; | |||||
} | |||||
// Initializes pass state derived from the GetNext node: how many data tensors it
// produces, whether any of them has dynamic dims (getnext-sink dynamic pattern),
// and which const nodes it drives via control edges.
// @param [in] node: the GetNext node collected from the root graph.
// @return SUCCESS / others: FAILED
Status MultiBatchClonePass::InitParamsOfGetNext(const NodePtr &node) {
  // Reset state so a re-run of the pass starts from a clean slate.
  data_count_from_getnext_ = 0;
  getnext_sink_dynamic_dims_ = false;
  GE_CHECK_NOTNULL(node->GetOpDesc());
  data_count_from_getnext_ = node->GetOpDesc()->GetOutputsSize();
  if (GetLocalOmgContext().dynamic_node_type == GETNEXT) {
    // Outputs come in two halves; only the first half are real data tensors.
    data_count_from_getnext_ = data_count_from_getnext_ / kDivisionConst;
    for (size_t i = 0; i < data_count_from_getnext_; ++i) {
      GeTensorDesc output_desc = node->GetOpDesc()->GetOutputDesc(i);
      GELOGD("The %zu data shape from getnext sink is %s.", i,
             formats::JoinToString(output_desc.GetShape().GetDims()).c_str());
      const auto &dims = output_desc.GetShape().GetDims();
      const bool is_static = std::all_of(dims.begin(), dims.end(), [](int64_t dim) { return dim >= 0; });
      if (is_static) {
        GELOGD("The %zu data from %s is static.", i, node->GetName().c_str());
      } else {
        getnext_sink_dynamic_dims_ = true;
        GELOGD("Dynamic dims in the pattern of getnext sink.");
      }
    }
  }
  // Remember const nodes fed by GetNext control edges; ChangeConstToData consults
  // this set so those nodes keep their const type.
  const auto out_control_anchor = node->GetOutControlAnchor();
  if (out_control_anchor != nullptr) {
    for (const auto &peer_in_control_anchor : out_control_anchor->GetPeerInControlAnchors()) {
      NodePtr next_node = peer_in_control_anchor->GetOwnerNode();
      GE_CHECK_NOTNULL(next_node);
      if (next_node->GetType() == CONSTANTOP) {
        out_control_nodes_.insert(next_node);
        GELOGD("Control edge: %s connect with %s.", node->GetName().c_str(), next_node->GetName().c_str());
      }
    }
  }
  return SUCCESS;
}
@@ -144,7 +217,11 @@ Status MultiBatchClonePass::CollectIoNodes(const ComputeGraphPtr &graph) { | |||||
/// @return 0: SUCCESS / others: FAILED | /// @return 0: SUCCESS / others: FAILED | ||||
/// | /// | ||||
Status MultiBatchClonePass::CreateRootGraph(const ComputeGraphPtr &graph) { | Status MultiBatchClonePass::CreateRootGraph(const ComputeGraphPtr &graph) { | ||||
GELOGD("Start create root graph of %s.", graph->GetName().c_str()); | |||||
uint32_t input_num = all_data_nodes_.size() + all_const_nodes_.size(); | uint32_t input_num = all_data_nodes_.size() + all_const_nodes_.size(); | ||||
if (data_count_from_getnext_ != 0) { | |||||
input_num = input_num + data_count_from_getnext_ - kNumOfGetnextNode; | |||||
} | |||||
uint32_t output_num = all_output_nodes_[0]->GetAllInDataAnchorsSize(); | uint32_t output_num = all_output_nodes_[0]->GetAllInDataAnchorsSize(); | ||||
OpDescBuilder op_builder(kMultiBatchCaseNode, CASE); | OpDescBuilder op_builder(kMultiBatchCaseNode, CASE); | ||||
@@ -185,6 +262,10 @@ Status MultiBatchClonePass::CreateRootGraph(const ComputeGraphPtr &graph) { | |||||
op_desc->GetName().c_str()); | op_desc->GetName().c_str()); | ||||
return FAILED; | return FAILED; | ||||
} | } | ||||
if (!AttrUtils::SetBool(op_desc, ATTR_INSERT_BY_MBATCH, true)) { | |||||
GELOGE(INTERNAL_ERROR, "Failed to add insert attr on case node %s", op_desc->GetName().c_str()); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
GE_CHK_STATUS_RET(multibatch::StampDynamicType(op_desc), "Set dynamic type failed"); | GE_CHK_STATUS_RET(multibatch::StampDynamicType(op_desc), "Set dynamic type failed"); | ||||
GE_CHK_STATUS_RET(CreateIndexNode(graph), "Create index node failed"); | GE_CHK_STATUS_RET(CreateIndexNode(graph), "Create index node failed"); | ||||
@@ -202,7 +283,7 @@ Status MultiBatchClonePass::CreateRootGraph(const ComputeGraphPtr &graph) { | |||||
/// @param [in] NodePtr node: index data node. | /// @param [in] NodePtr node: index data node. | ||||
/// @return 0: SUCCESS / others: FAILED | /// @return 0: SUCCESS / others: FAILED | ||||
/// | /// | ||||
Status MultiBatchClonePass::CreateIndexDataNode(const ComputeGraphPtr &graph, NodePtr &node) { | |||||
Status MultiBatchClonePass::CreateIndexDataNode(const ComputeGraphPtr &graph, NodePtr &shape_node) { | |||||
const OpDescPtr data_desc = MakeShared<OpDesc>(kMultiBatchDataNode, DATA); | const OpDescPtr data_desc = MakeShared<OpDesc>(kMultiBatchDataNode, DATA); | ||||
if (data_desc == nullptr) { | if (data_desc == nullptr) { | ||||
GELOGE(OUT_OF_MEMORY, "Create multi-batch data node failed"); | GELOGE(OUT_OF_MEMORY, "Create multi-batch data node failed"); | ||||
@@ -220,11 +301,12 @@ Status MultiBatchClonePass::CreateIndexDataNode(const ComputeGraphPtr &graph, No | |||||
} | } | ||||
size_t data_index = all_data_nodes_.size(); | size_t data_index = all_data_nodes_.size(); | ||||
data_index = data_count_from_getnext_ != 0 ? data_index - kNumOfGetnextNode : data_index; | |||||
(void)AttrUtils::SetInt(data_desc, ATTR_NAME_INDEX, data_index); | (void)AttrUtils::SetInt(data_desc, ATTR_NAME_INDEX, data_index); | ||||
(void)AttrUtils::SetBool(data_desc, ATTR_INSERT_BY_MBATCH, true); | (void)AttrUtils::SetBool(data_desc, ATTR_INSERT_BY_MBATCH, true); | ||||
node = graph->AddNode(data_desc); | |||||
if (node == nullptr) { | |||||
shape_node = graph->AddNode(data_desc); | |||||
if (shape_node == nullptr) { | |||||
GELOGE(OUT_OF_MEMORY, "Create multi-batch data node failed"); | GELOGE(OUT_OF_MEMORY, "Create multi-batch data node failed"); | ||||
return OUT_OF_MEMORY; | return OUT_OF_MEMORY; | ||||
} | } | ||||
@@ -286,15 +368,19 @@ Status MultiBatchClonePass::CreateIndexConstNode(const ComputeGraphPtr &graph, N | |||||
/// @return 0: SUCCESS / others: FAILED | /// @return 0: SUCCESS / others: FAILED | ||||
/// | /// | ||||
Status MultiBatchClonePass::CreateIndexNode(const ComputeGraphPtr &graph) { | Status MultiBatchClonePass::CreateIndexNode(const ComputeGraphPtr &graph) { | ||||
// Data --> MapIndex --> Case | |||||
NodePtr data_node; | |||||
GE_CHK_STATUS_RET(CreateIndexDataNode(graph, data_node), "Create data node failed"); | |||||
// Data/GetDynamicDims --> MapIndex --> Case | |||||
if (!getnext_sink_dynamic_dims_) { | |||||
GE_CHK_STATUS_RET(CreateIndexDataNode(graph, shape_node_), "Create data node failed"); | |||||
} else { | |||||
GE_CHK_STATUS_RET(CreateGetDynamicDimsNode(graph, shape_node_), "Create get dynamic dims node failed"); | |||||
} | |||||
NodePtr const_node; | NodePtr const_node; | ||||
GE_CHK_STATUS_RET(CreateIndexConstNode(graph, const_node), "Create const node failed"); | GE_CHK_STATUS_RET(CreateIndexConstNode(graph, const_node), "Create const node failed"); | ||||
GELOGD("Shape node name is %s, type is %s, const node name is %s.", shape_node_->GetName().c_str(), | |||||
shape_node_->GetType().c_str(), const_node->GetName().c_str()); | |||||
OpDescBuilder op_builder(kMultiBatchMapIndexNode, "MapIndex"); | OpDescBuilder op_builder(kMultiBatchMapIndexNode, "MapIndex"); | ||||
op_builder.AddInput("x", data_node->GetOpDesc()->GetOutputDesc(0)) | |||||
op_builder.AddInput("x", shape_node_->GetOpDesc()->GetOutputDesc(0)) | |||||
.AddInput("data_seq", const_node->GetOpDesc()->GetOutputDesc(0)) | .AddInput("data_seq", const_node->GetOpDesc()->GetOutputDesc(0)) | ||||
.AddOutput("y", GeTensorDesc(GeShape(), FORMAT_ND, DT_INT32)); | .AddOutput("y", GeTensorDesc(GeShape(), FORMAT_ND, DT_INT32)); | ||||
@@ -309,8 +395,10 @@ Status MultiBatchClonePass::CreateIndexNode(const ComputeGraphPtr &graph) { | |||||
return OUT_OF_MEMORY; | return OUT_OF_MEMORY; | ||||
} | } | ||||
if (GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), index_node->GetInDataAnchor(0)) != GRAPH_SUCCESS) { | |||||
GELOGE(FAILED, "Failed to add edge between node:%s to MapIndex:%s", data_node->GetName().c_str(), | |||||
GE_CHK_STATUS_RET(AddAttrForGetDynamicDims(shape_node_), "Failed to add attr for %s.", | |||||
shape_node_->GetName().c_str()); | |||||
if (GraphUtils::AddEdge(shape_node_->GetOutDataAnchor(0), index_node->GetInDataAnchor(0)) != GRAPH_SUCCESS) { | |||||
GELOGE(FAILED, "Failed to add edge between node:%s to MapIndex:%s", shape_node_->GetName().c_str(), | |||||
index_node->GetName().c_str()); | index_node->GetName().c_str()); | ||||
return FAILED; | return FAILED; | ||||
} | } | ||||
@@ -328,6 +416,120 @@ Status MultiBatchClonePass::CreateIndexNode(const ComputeGraphPtr &graph) { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
// Creates the GetDynamicDims node for the root graph. Its inputs are the shape
// tensors of each user input; its single output is the gear info consumed by MapIndex.
// @param [in] graph: the root graph.
// @param [out] shape_node: the created GetDynamicDims node.
// @return SUCCESS / others: FAILED
Status MultiBatchClonePass::CreateGetDynamicDimsNode(const ComputeGraphPtr &graph, NodePtr &shape_node) {
  const OpDescPtr data_desc = MakeShared<OpDesc>(kMultiBatchGetDynamicDimsNode, GETDYNAMICDIMS);
  if (data_desc == nullptr) {
    GELOGE(OUT_OF_MEMORY, "Create multi-batch get dynamic dims node failed");
    return OUT_OF_MEMORY;
  }

  // Input of GetDynamicDims is shape_of_each_data, output is gear_info.
  const auto &user_input_dims = GetLocalOmgContext().user_input_dims;
  for (size_t i = 0; i < user_input_dims.size(); ++i) {
    const auto &dims = user_input_dims.at(i).second;
    // A shape of [0] marks a const input transferred by the adapter: add an
    // input desc without a GeShape for it.
    GeTensorDesc tensor_desc;
    if (dims.size() == 1 && dims.at(0) == 0) {
      tensor_desc.SetFormat(FORMAT_ND);
      tensor_desc.SetDataType(DT_INT32);
    } else {
      tensor_desc = GeTensorDesc(GeShape({static_cast<int32_t>(dims.size())}), FORMAT_ND, DT_INT32);
    }
    auto ret = data_desc->AddInputDesc(tensor_desc);
    GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to add input desc for created data");
                    return FAILED);
  }

  // The output length equals the rank of one gear entry.
  GeTensorDesc output_desc(GeShape({static_cast<int32_t>(batch_shapes_.at(0).size())}), FORMAT_ND, DT_INT32);
  auto ret = data_desc->AddOutputDesc(output_desc);
  GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to add output desc for created data");
                  return FAILED);

  (void)AttrUtils::SetBool(data_desc, ATTR_INSERT_BY_MBATCH, true);
  shape_node = graph->AddNode(data_desc);
  if (shape_node == nullptr) {
    GELOGE(OUT_OF_MEMORY, "Create multi-batch dynamic dims node failed");
    return OUT_OF_MEMORY;
  }
  return SUCCESS;
}
// Stamps the GetDynamicDims node with the getnext-sink data count and a flattened
// description of every user input shape. Does nothing unless the getnext-sink
// dynamic pattern was detected.
// @param [in] shape_node: the GetDynamicDims node.
// @return SUCCESS / others: FAILED
Status MultiBatchClonePass::AddAttrForGetDynamicDims(const NodePtr &shape_node) {
  if (!getnext_sink_dynamic_dims_) {
    GELOGD("No need to add attr when not insert get dynamic dims node.");
    return SUCCESS;
  }
  GELOGD("Add attr for :%s, type is %s:", shape_node->GetName().c_str(), shape_node->GetType().c_str());
  if (!AttrUtils::SetInt(shape_node->GetOpDesc(), ATTR_GETNEXT_SINK_DATA_COUNT, data_count_from_getnext_)) {
    GELOGE(INTERNAL_ERROR, "set ATTR_GETNEXT_SINK_DATA_COUNT failed");
    return INTERNAL_ERROR;
  }
  // Encode each input as [rank, d0, d1, ...]; a const input ([0] shape) is a lone 0.
  vector<int64_t> shape_info;
  for (const auto &user_input_dim : GetLocalOmgContext().user_input_dims) {
    const auto &dims = user_input_dim.second;
    if (dims.size() == 1 && dims.at(0) == 0) {
      shape_info.emplace_back(0);
      continue;
    }
    shape_info.emplace_back(dims.size());
    for (const auto dim : dims) {
      shape_info.emplace_back(dim);
    }
  }
  if (!AttrUtils::SetListInt(shape_node->GetOpDesc(), ATTR_GETNEXT_SINK_SHAPE_INFO, shape_info)) {
    GELOGE(INTERNAL_ERROR, "set ATTR_GETNEXT_SINK_SHAPE_INFO failed");
    return INTERNAL_ERROR;
  }
  return SUCCESS;
}
// Wires the shape outputs of GetNext (its second half of outputs) to the inputs
// of the GetDynamicDims node, in order.
// @param [in] getnext_node: the GetNext node in the root graph.
// @param [in] shape_node: the GetDynamicDims node.
// @return SUCCESS / others: FAILED
Status MultiBatchClonePass::LinkGetNextToGetDynamicDims(const NodePtr &getnext_node, const NodePtr &shape_node) {
  GELOGD("Start relink shape anchor of %s to %s.", getnext_node->GetName().c_str(), shape_node->GetName().c_str());
  const size_t total_out_count = getnext_node->GetAllOutDataAnchors().size();
  const size_t data_count = total_out_count / kDivisionConst;
  size_t input_index = 0;
  for (size_t out_index = data_count; out_index < total_out_count; ++out_index, ++input_index) {
    GELOGD("Start add %s of %zu out_anchor to %s of %zu in_anchor.", getnext_node->GetName().c_str(), out_index,
           shape_node->GetName().c_str(), input_index);
    auto out_data_anchor = getnext_node->GetOutDataAnchor(out_index);
    auto ret = GraphUtils::AddEdge(out_data_anchor, shape_node->GetInDataAnchor(input_index));
    GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to link getnext %s to getdynamicdims %s",
                                                 getnext_node->GetName().c_str(), shape_node->GetName().c_str());
                    return INTERNAL_ERROR);
  }
  return SUCCESS;
}
// Records the gear info on NetOutput and, in the getnext-sink dynamic pattern,
// appends an extra NetOutput input fed by the GetDynamicDims node.
// @param [in] output_node: the NetOutput node of the root graph.
// @return SUCCESS / others: FAILED
Status MultiBatchClonePass::LinkGetDynamicDimsToNetOutput(const NodePtr &output_node) {
  if (!GetLocalOmgContext().dynamic_node_type.empty()) {
    if (!AttrUtils::SetStr(output_node->GetOpDesc(), ATTR_ALL_GEARS_INFO, GetLocalOmgContext().dynamic_dims)) {
      GELOGE(INTERNAL_ERROR, "Failed to set all gears info attr on netoutput %s.", output_node->GetName().c_str());
      return INTERNAL_ERROR;
    }
  }
  if (!getnext_sink_dynamic_dims_) {
    return SUCCESS;
  }
  GELOGD("Start link %s to %s.", shape_node_->GetName().c_str(), output_node->GetName().c_str());
  // Grow NetOutput by one input anchor and feed it from GetDynamicDims output 0.
  size_t input_index = output_node->GetAllInDataAnchors().size();
  if (NodeUtils::AppendInputAnchor(output_node, input_index + 1) != GRAPH_SUCCESS) {
    GELOGE(INTERNAL_ERROR, "Append input anchor of %s of %zu failed.", output_node->GetName().c_str(), input_index);
    return INTERNAL_ERROR;
  }
  auto ret = GraphUtils::AddEdge(shape_node_->GetOutDataAnchor(kDataOutIndex),
                                 output_node->GetInDataAnchor(input_index));
  GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to link netoutput %s to getdynamicdims %s",
                                               output_node->GetName().c_str(), shape_node_->GetName().c_str());
                  return INTERNAL_ERROR);
  if (!AttrUtils::SetBool(output_node->GetOpDesc(), ATTR_GETNEXT_SINK_DYNMAIC, true)) {
    GELOGE(INTERNAL_ERROR, "Failed to set getnext sink dynamic attr on netoutput %s.",
           output_node->GetName().c_str());
    return INTERNAL_ERROR;
  }
  return SUCCESS;
}
/// | /// | ||||
/// @ingroup ge | /// @ingroup ge | ||||
/// @brief Create input node for root graph. | /// @brief Create input node for root graph. | ||||
@@ -337,8 +539,10 @@ Status MultiBatchClonePass::CreateIndexNode(const ComputeGraphPtr &graph) { | |||||
Status MultiBatchClonePass::CreateInputNode(const ComputeGraphPtr &graph) { | Status MultiBatchClonePass::CreateInputNode(const ComputeGraphPtr &graph) { | ||||
// Data --> Case | // Data --> Case | ||||
std::vector<NodePtr> all_data_nodes; | std::vector<NodePtr> all_data_nodes; | ||||
const size_t arg_index = kCaseArgIndex; | |||||
for (size_t i = 0; i < all_data_nodes_.size(); ++i) { | |||||
size_t case_input_index = kCaseArgIndex; | |||||
NodePtr getnext_node = nullptr; | |||||
size_t input_index_of_getnext = 0; | |||||
for (size_t i = 0; i < all_data_nodes_.size(); ++i, ++case_input_index) { | |||||
const auto &node = all_data_nodes_[i]; | const auto &node = all_data_nodes_[i]; | ||||
const OpDescPtr op_desc = AttrUtils::CopyOpDesc(node->GetOpDesc()); | const OpDescPtr op_desc = AttrUtils::CopyOpDesc(node->GetOpDesc()); | ||||
if (op_desc == nullptr) { | if (op_desc == nullptr) { | ||||
@@ -353,22 +557,60 @@ Status MultiBatchClonePass::CreateInputNode(const ComputeGraphPtr &graph) { | |||||
op_desc->SetName(node->GetName()); | op_desc->SetName(node->GetName()); | ||||
const NodePtr &data = graph->AddNode(op_desc); | const NodePtr &data = graph->AddNode(op_desc); | ||||
GE_CHK_BOOL_EXEC(data != nullptr, return FAILED, "Add node[%s] to graph failed", op_desc->GetName().c_str()); | GE_CHK_BOOL_EXEC(data != nullptr, return FAILED, "Add node[%s] to graph failed", op_desc->GetName().c_str()); | ||||
if (GraphUtils::AddEdge(data->GetOutDataAnchor(0), case_node_->GetInDataAnchor(arg_index + i)) != GRAPH_SUCCESS) { | |||||
GELOGE(FAILED, "Failed to add edge between Data:%s to Case:%s", | |||||
data->GetName().c_str(), case_node_->GetName().c_str()); | |||||
return FAILED; | |||||
if (IsGetNextType(node)) { | |||||
getnext_node = data; | |||||
input_index_of_getnext = case_input_index; | |||||
case_input_index = case_input_index + data_count_from_getnext_; | |||||
continue; | |||||
} else { | |||||
if (GraphUtils::AddEdge(data->GetOutDataAnchor(0), case_node_->GetInDataAnchor(case_input_index)) != | |||||
GRAPH_SUCCESS) { | |||||
GELOGE(FAILED, "Failed to add edge between Data:%s to Case:%s", data->GetName().c_str(), | |||||
case_node_->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
} | } | ||||
if (SetMaxShapeToData(data) != SUCCESS) { | |||||
if (SetMaxShape(data) != SUCCESS) { | |||||
GELOGE(FAILED, "Set max shape of %s failed.", data->GetName().c_str()); | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
all_data_nodes.emplace_back(data); | all_data_nodes.emplace_back(data); | ||||
} | } | ||||
if (getnext_node != nullptr) { | |||||
if (LinkEdgeForGetNext(getnext_node, input_index_of_getnext) != SUCCESS) { | |||||
GELOGE(FAILED, "Failed to link edge for %s.", getnext_node->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
if (SetMaxShape(getnext_node) != SUCCESS) { | |||||
GELOGE(FAILED, "Set max shape of %s failed.", getnext_node->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
all_data_nodes.emplace_back(getnext_node); | |||||
} | |||||
all_data_nodes_.swap(all_data_nodes); | all_data_nodes_.swap(all_data_nodes); | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
// Connects each data output of GetNext to a consecutive input of the Case node,
// advancing the caller's input index, and in the getnext-sink dynamic pattern
// also links the shape outputs to GetDynamicDims.
// @param [in] getnext_node: the GetNext node in the root graph.
// @param [in,out] case_input_index: first Case input index to use; advanced past the consumed slots.
// @return SUCCESS / others: FAILED
Status MultiBatchClonePass::LinkEdgeForGetNext(const NodePtr &getnext_node, size_t &case_input_index) {
  GELOGD("Start link edge for %s, which is the %zu input of %s.", getnext_node->GetName().c_str(),
         case_input_index, case_node_->GetName().c_str());
  for (size_t out_index = 0; out_index < data_count_from_getnext_; ++out_index, ++case_input_index) {
    const auto add_ret = GraphUtils::AddEdge(getnext_node->GetOutDataAnchor(out_index),
                                             case_node_->GetInDataAnchor(case_input_index));
    if (add_ret != GRAPH_SUCCESS) {
      GELOGE(FAILED, "Failed to add data edge between %zu Data:%s to %zu Case:%s", out_index,
             getnext_node->GetName().c_str(), case_input_index, case_node_->GetName().c_str());
      return FAILED;
    }
  }
  if (getnext_sink_dynamic_dims_) {
    GE_CHK_STATUS_RET(LinkGetNextToGetDynamicDims(getnext_node, shape_node_), "Failed to add link for %s.",
                      shape_node_->GetName().c_str());
  }
  return SUCCESS;
}
/// | /// | ||||
/// @ingroup ge | /// @ingroup ge | ||||
/// @brief Create Const node for root graph. | /// @brief Create Const node for root graph. | ||||
@@ -378,7 +620,11 @@ Status MultiBatchClonePass::CreateInputNode(const ComputeGraphPtr &graph) { | |||||
Status MultiBatchClonePass::CreateConstNode(const ComputeGraphPtr &graph) { | Status MultiBatchClonePass::CreateConstNode(const ComputeGraphPtr &graph) { | ||||
// Const --> Case | // Const --> Case | ||||
std::vector<NodePtr> all_const_nodes; | std::vector<NodePtr> all_const_nodes; | ||||
const size_t arg_index = kCaseArgIndex + all_data_nodes_.size(); | |||||
size_t arg_index = kCaseArgIndex + all_data_nodes_.size(); | |||||
if (data_count_from_getnext_ != 0) { | |||||
arg_index = arg_index + data_count_from_getnext_ - kNumOfGetnextNode; | |||||
} | |||||
for (size_t i = 0; i < all_const_nodes_.size(); ++i) { | for (size_t i = 0; i < all_const_nodes_.size(); ++i) { | ||||
const auto &node = all_const_nodes_[i]; | const auto &node = all_const_nodes_[i]; | ||||
const OpDescPtr op_desc = AttrUtils::CopyOpDesc(node->GetOpDesc()); | const OpDescPtr op_desc = AttrUtils::CopyOpDesc(node->GetOpDesc()); | ||||
@@ -395,15 +641,33 @@ Status MultiBatchClonePass::CreateConstNode(const ComputeGraphPtr &graph) { | |||||
const NodePtr &data = graph->AddNode(op_desc); | const NodePtr &data = graph->AddNode(op_desc); | ||||
GE_CHK_BOOL_EXEC(data != nullptr, return FAILED, "Add node[%s] to graph failed", op_desc->GetName().c_str()); | GE_CHK_BOOL_EXEC(data != nullptr, return FAILED, "Add node[%s] to graph failed", op_desc->GetName().c_str()); | ||||
if (GraphUtils::AddEdge(data->GetOutDataAnchor(0), case_node_->GetInDataAnchor(arg_index + i)) != GRAPH_SUCCESS) { | if (GraphUtils::AddEdge(data->GetOutDataAnchor(0), case_node_->GetInDataAnchor(arg_index + i)) != GRAPH_SUCCESS) { | ||||
GELOGE(FAILED, "Failed to add edge between Const:%s to Case:%s", | |||||
data->GetName().c_str(), case_node_->GetName().c_str()); | |||||
GELOGE(FAILED, "Failed to add edge between Const:%s to Case:%s", data->GetName().c_str(), | |||||
case_node_->GetName().c_str()); | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
all_const_nodes.emplace_back(data); | all_const_nodes.emplace_back(data); | ||||
} | } | ||||
ChangeConstToData(); | |||||
all_const_nodes_.swap(all_const_nodes); | |||||
return SUCCESS; | |||||
} | |||||
void MultiBatchClonePass::ChangeConstToData() { | |||||
size_t data_index = all_data_nodes_.size(); | size_t data_index = all_data_nodes_.size(); | ||||
if (data_count_from_getnext_ != 0) { | |||||
data_index = data_index + data_count_from_getnext_ - kNumOfGetnextNode; | |||||
} | |||||
for (size_t i = 0; i < all_const_nodes_.size(); ++i, ++data_index) { // Trans subgraph Const to Data. | for (size_t i = 0; i < all_const_nodes_.size(); ++i, ++data_index) { // Trans subgraph Const to Data. | ||||
auto &const_node = all_const_nodes_[i]; | |||||
bool need_change_type = true; | |||||
if (out_control_nodes_.find(const_node) != out_control_nodes_.end()) { | |||||
GELOGD("No need to change %s to data type.", const_node->GetName().c_str()); | |||||
need_change_type = false; | |||||
break; | |||||
} | |||||
if (!need_change_type) { | |||||
continue; | |||||
} | |||||
const OpDescPtr &op_desc = all_const_nodes_[i]->GetOpDesc(); | const OpDescPtr &op_desc = all_const_nodes_[i]->GetOpDesc(); | ||||
op_desc->SetType(DATA); | op_desc->SetType(DATA); | ||||
(void)op_desc->DelAttr(ATTR_NAME_WEIGHTS); // Delete weight. | (void)op_desc->DelAttr(ATTR_NAME_WEIGHTS); // Delete weight. | ||||
@@ -413,9 +677,6 @@ Status MultiBatchClonePass::CreateConstNode(const ComputeGraphPtr &graph) { | |||||
(void)AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, data_index); | (void)AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, data_index); | ||||
(void)NodeUtils::AppendInputAnchor(all_const_nodes_[i], 1); | (void)NodeUtils::AppendInputAnchor(all_const_nodes_[i], 1); | ||||
} | } | ||||
all_const_nodes_.swap(all_const_nodes); | |||||
return SUCCESS; | |||||
} | } | ||||
/// | /// | ||||
@@ -461,7 +722,8 @@ Status MultiBatchClonePass::CreateOutputNode(const ComputeGraphPtr &graph) { | |||||
} | } | ||||
} | } | ||||
} | } | ||||
GE_CHK_STATUS_RET(LinkGetDynamicDimsToNetOutput(node), "Failed to add edge between %s to netoutput: %s.", | |||||
shape_node_->GetName().c_str(), output->GetName().c_str()); | |||||
all_output_nodes_.clear(); | all_output_nodes_.clear(); | ||||
all_output_nodes_.emplace_back(node); | all_output_nodes_.emplace_back(node); | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -473,34 +735,69 @@ Status MultiBatchClonePass::CreateOutputNode(const ComputeGraphPtr &graph) { | |||||
/// @param [in] const NodePtr &data: data in Root/Case graph. | /// @param [in] const NodePtr &data: data in Root/Case graph. | ||||
/// @return 0: SUCCESS / others: FAILED | /// @return 0: SUCCESS / others: FAILED | ||||
/// | /// | ||||
Status MultiBatchClonePass::SetMaxShapeToData(const NodePtr &data) { | |||||
auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape(); | |||||
auto data_name = data->GetName(); | |||||
// Updates the max (gear) shape on a root-graph input node. A plain Data node has
// a single relevant output (index 0); a GetNext node carries one data tensor per
// output, so every data output is updated.
// @param [in] data: data or GetNext node in the root graph.
// @return SUCCESS / others: FAILED
Status MultiBatchClonePass::SetMaxShape(const NodePtr &data) {
  GELOGD("Start set max shape for %s.", data->GetName().c_str());
  const size_t output_count = IsGetNextType(data) ? data_count_from_getnext_ : 1;
  for (size_t out_anchor_index = 0; out_anchor_index < output_count; ++out_anchor_index) {
    if (SetMaxShapeToData(data, out_anchor_index) != SUCCESS) {
      GELOGE(PARAM_INVALID, "Failed to update max shape of %s.", data->GetName().c_str());
      return PARAM_INVALID;
    }
  }
  return SUCCESS;
}
Status MultiBatchClonePass::SetMaxShapeToData(const NodePtr &node, size_t out_anchor_index) { | |||||
GELOGD("Start update max shape of %s, %zu output.", node->GetName().c_str(), out_anchor_index); | |||||
auto data_shape = NodeUtils::GetOutputDesc(*node, out_anchor_index).GetShape(); | |||||
string data_name = node->GetName(); | |||||
if (IsGetNextType(node)) { | |||||
data_name.append("_").append(std::to_string(out_anchor_index)); | |||||
} | |||||
GELOGD("Update max shape of %s, shape dims is %s.", data_name.c_str(), | |||||
formats::JoinToString(data_shape.GetDims()).c_str()); | |||||
const auto &dims = data_shape.GetDims(); | const auto &dims = data_shape.GetDims(); | ||||
if (std::all_of(dims.begin(), dims.end(), [](int64_t val) { return val >= 0; })) { | |||||
return SUCCESS; | |||||
if (!IsGetNextType(node)) { | |||||
if (std::all_of(dims.begin(), dims.end(), [](int64_t val) { return val >= 0; })) { | |||||
GELOGD("No need to do anything for static data."); | |||||
return SUCCESS; | |||||
} | |||||
} else { | |||||
if (std::all_of(dims.begin(), dims.end(), [](int64_t val) { return val >= 0; })) { | |||||
if (getnext_sink_dynamic_dims_) { | |||||
// need to update shape of Shape_node when getnext node has dynamic data | |||||
GE_CHK_STATUS_RET(UpdateShapeOfShapeNode(node, out_anchor_index), "Failed to update shape of shape node"); | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
} | } | ||||
(void)AttrUtils::SetListInt(data->GetOpDesc(), ATTR_MBATCH_ORIGIN_INPUT_DIMS, data_shape.GetDims()); | |||||
(void)AttrUtils::SetListInt(node->GetOpDesc(), ATTR_MBATCH_ORIGIN_INPUT_DIMS, data_shape.GetDims()); | |||||
GeTensorDesc tensor(NodeUtils::GetOutputDesc(*data, kDataOutIndex)); | |||||
GeTensorDesc tensor(NodeUtils::GetOutputDesc(*node, kDataOutIndex)); | |||||
std::vector<std::string> input_dims_str; | std::vector<std::string> input_dims_str; | ||||
for (size_t i = 0; i < batch_shapes_.size(); ++i) { | for (size_t i = 0; i < batch_shapes_.size(); ++i) { | ||||
auto shape = data_shape; | auto shape = data_shape; | ||||
auto ret = multibatch::CalcShape(data_to_dynamic_info_.at(data_name).at(i), shape); | auto ret = multibatch::CalcShape(data_to_dynamic_info_.at(data_name).at(i), shape); | ||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(ret, "Failed to calculate the shape for data node %s, the shape may not match", data->GetName().c_str()); | |||||
GELOGE(ret, "Failed to calculate the shape for data node %s, the shape may not match", node->GetName().c_str()); | |||||
return ret; | return ret; | ||||
} | } | ||||
tensor.SetShape(shape); | tensor.SetShape(shape); | ||||
int64_t tensor_size = 0; | int64_t tensor_size = 0; | ||||
(void)TensorUtils::GetTensorSizeInBytes(tensor, tensor_size); | (void)TensorUtils::GetTensorSizeInBytes(tensor, tensor_size); | ||||
string input_str = TypeUtils::FormatToSerialString(tensor.GetFormat()) + ":" + | string input_str = TypeUtils::FormatToSerialString(tensor.GetFormat()) + ":" + | ||||
TypeUtils::DataTypeToSerialString(tensor.GetDataType()) + ":" + data->GetName() + ":" + | |||||
TypeUtils::DataTypeToSerialString(tensor.GetDataType()) + ":" + node->GetName() + ":" + | |||||
std::to_string(tensor_size) + ":" + std::to_string(tensor.GetShape().GetDimNum()) + ":" + | std::to_string(tensor_size) + ":" + std::to_string(tensor.GetShape().GetDimNum()) + ":" + | ||||
formats::JoinToString(tensor.GetShape().GetDims()); | formats::JoinToString(tensor.GetShape().GetDims()); | ||||
input_dims_str.emplace_back(input_str); | input_dims_str.emplace_back(input_str); | ||||
} | } | ||||
(void)AttrUtils::SetListStr(data->GetOpDesc(), "_all_origin_gears_inputs", input_dims_str); | |||||
(void)AttrUtils::SetListStr(node->GetOpDesc(), "_all_origin_gears_inputs", input_dims_str); | |||||
size_t max_shape_index = 0; | size_t max_shape_index = 0; | ||||
int64_t max_size = 0; | int64_t max_size = 0; | ||||
@@ -519,18 +816,72 @@ Status MultiBatchClonePass::SetMaxShapeToData(const NodePtr &data) { | |||||
max_shape_index = i; | max_shape_index = i; | ||||
} | } | ||||
} | } | ||||
return SetShapeToData(data_to_dynamic_info_.at(data_name).at(max_shape_index), node, data_shape, out_anchor_index); | |||||
} | |||||
return SetShapeToData(data_to_dynamic_info_.at(data_name).at(max_shape_index), data, data_shape); | |||||
/// | |||||
/// @ingroup ge | |||||
/// @brief Set max shape to Data/GetNext node in root graph. | |||||
/// @param [in] const std::vector<int64_t> &shapes: dims of shape. | |||||
/// @param [in] const NodePtr &data: data in Root/Case graph. | |||||
/// @param [in] GeShape &data_shape: dims of data node. | |||||
/// @param [in] size_t out_anchor_index: out anchor index of data node. | |||||
/// @return 0: SUCCESS / others: FAILED | |||||
/// | |||||
Status MultiBatchClonePass::SetShapeToData(const std::vector<int64_t> &shapes, const NodePtr &data, GeShape &data_shape, | |||||
size_t out_anchor_index) { | |||||
GELOGD("Start set shape to %zu out of %s.", out_anchor_index, data->GetName().c_str()); | |||||
if (multibatch::CalcShape(shapes, data_shape) != SUCCESS) { | |||||
GELOGE(INTERNAL_ERROR, "Failed to calculate the batched shape for data node %s, the shapes may not match", | |||||
data->GetName().c_str()); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
if (NodeUtils::UpdateOutputShape(*data, out_anchor_index, data_shape) != GRAPH_SUCCESS) { | |||||
GELOGE(INTERNAL_ERROR, "Failed to update output shape for data %s", data->GetName().c_str()); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
if (!IsGetNextType(data)) { | |||||
if (NodeUtils::UpdateInputShape(*data, kDataInIndex, data_shape) != GRAPH_SUCCESS) { | |||||
GELOGE(INTERNAL_ERROR, "Failed to update input shape for data %s", data->GetName().c_str()); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
} else { | |||||
if (getnext_sink_dynamic_dims_) { | |||||
// need to update shape of Shape_node when getnext_sink_dynamic | |||||
GE_CHK_STATUS_RET(UpdateShapeOfShapeNode(data, out_anchor_index), "Failed to update shape of shape node"); | |||||
} | |||||
} | |||||
GELOGI("Update the data %s input/output shape to the max %s", data->GetName().c_str(), | |||||
formats::ShapeToString(data_shape).c_str()); | |||||
return SUCCESS; | |||||
} | |||||
Status MultiBatchClonePass::UpdateShapeOfShapeNode(const NodePtr &node, size_t out_anchor_index) { | |||||
GELOGD("Start update output shape of shape node insert by adapter, which is the %zu out of %s.", out_anchor_index, | |||||
node->GetName().c_str()); | |||||
auto data_shape = NodeUtils::GetOutputDesc(*node, out_anchor_index).GetShape(); | |||||
size_t shape_index = out_anchor_index + (node->GetAllOutDataAnchors().size() / kDivisionConst); | |||||
GeTensorDesc output_desc = node->GetOpDesc()->GetOutputDesc(shape_index); | |||||
std::vector<int64_t> output_dims = {static_cast<int64_t>(data_shape.GetDims().size())}; | |||||
GeShape output_shape(output_dims); | |||||
output_desc.SetShape(output_shape); | |||||
if (node->GetOpDesc()->UpdateOutputDesc(shape_index, output_desc) != SUCCESS) { | |||||
GELOGE(FAILED, "Update output desc fail."); | |||||
return FAILED; | |||||
} | |||||
return SUCCESS; | |||||
} | } | ||||
/// | /// | ||||
/// @ingroup ge | /// @ingroup ge | ||||
/// @brief Update Data node in Subgraph. | /// @brief Update Data node in Subgraph. | ||||
/// @param [in] const NodePtr &data: data in Subgraph. | /// @param [in] const NodePtr &data: data in Subgraph. | ||||
/// @param [in] size_t index: The batch index. | |||||
/// @param [in] size_t batch_index: The batch index. | |||||
/// @return 0: SUCCESS / others: FAILED | /// @return 0: SUCCESS / others: FAILED | ||||
/// | /// | ||||
Status MultiBatchClonePass::UpdateSubgraphData(const NodePtr &data, size_t index) { | |||||
Status MultiBatchClonePass::UpdateSubgraphData(const NodePtr &data, size_t batch_index) { | |||||
int node_index = -1; | int node_index = -1; | ||||
if (!AttrUtils::GetInt(data->GetOpDesc(), ATTR_NAME_INDEX, node_index)) { | if (!AttrUtils::GetInt(data->GetOpDesc(), ATTR_NAME_INDEX, node_index)) { | ||||
GELOGE(FAILED, "Failed to get index from data[%s]", data->GetName().c_str()); | GELOGE(FAILED, "Failed to get index from data[%s]", data->GetName().c_str()); | ||||
@@ -545,6 +896,8 @@ Status MultiBatchClonePass::UpdateSubgraphData(const NodePtr &data, size_t index | |||||
auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape(); | auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape(); | ||||
const auto &dims = data_shape.GetDims(); | const auto &dims = data_shape.GetDims(); | ||||
GELOGD("Start update shape of %s , batch index is %zu, dims is %s.", data->GetName().c_str(), batch_index, | |||||
formats::JoinToString(dims).c_str()); | |||||
if (std::all_of(dims.begin(), dims.end(), [](int64_t val) { return val >= 0; })) { | if (std::all_of(dims.begin(), dims.end(), [](int64_t val) { return val >= 0; })) { | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -559,35 +912,77 @@ Status MultiBatchClonePass::UpdateSubgraphData(const NodePtr &data, size_t index | |||||
} | } | ||||
auto parent_name = data_name.substr(0, pos); | auto parent_name = data_name.substr(0, pos); | ||||
return SetShapeToData(data_to_dynamic_info_.at(parent_name).at(index), data, data_shape); | |||||
return SetShapeToData(data_to_dynamic_info_.at(parent_name).at(batch_index), data, data_shape, kDataOutIndex); | |||||
} | } | ||||
/// | |||||
/// @ingroup ge | |||||
/// @brief Set max shape to Data node in root graph. | |||||
/// @param [in] const std::vector<int64_t> &shapes: dims of shape. | |||||
/// @param [in] const NodePtr &data: data in Root/Case graph. | |||||
/// @param [in] GeShape &data_shape: dims of data node. | |||||
/// @return 0: SUCCESS / others: FAILED | |||||
/// | |||||
Status MultiBatchClonePass::SetShapeToData(const vector<int64_t> &shapes, const NodePtr &data, GeShape &data_shape) { | |||||
// must not be error, the calc result has been checked in function InsertSwitchNForData | |||||
if (multibatch::CalcShape(shapes, data_shape) != SUCCESS) { | |||||
return INTERNAL_ERROR; | |||||
Status MultiBatchClonePass::CreateOriGraph(const ComputeGraphPtr &graph) { | |||||
if (data_count_from_getnext_ == 0) { | |||||
GELOGD("No need to change original graph without getnext node."); | |||||
return SUCCESS; | |||||
} | } | ||||
if (NodeUtils::UpdateInputShape(*data, kDataInIndex, data_shape) != GRAPH_SUCCESS) { | |||||
GELOGE(INTERNAL_ERROR, "Failed to update input shape for data %s", data->GetName().c_str()); | |||||
return INTERNAL_ERROR; | |||||
GELOGD("Start change original graph: %s when exit getnext node.", graph->GetName().c_str()); | |||||
size_t data_index = all_data_nodes_.size() - kNumOfGetnextNode; | |||||
for (const auto &node : graph->GetDirectNode()) { | |||||
if (IsGetNextType(node)) { | |||||
for (size_t out_index = 0; out_index < data_count_from_getnext_; ++out_index, ++data_index) { | |||||
auto out_data_anchor = node->GetOutDataAnchor(out_index); | |||||
GE_IF_BOOL_EXEC(out_data_anchor == nullptr, continue); | |||||
NodePtr data_node = CreateDataNode(graph, out_data_anchor, data_index); | |||||
GE_IF_BOOL_EXEC(data_node == nullptr, GELOGE(INTERNAL_ERROR, "Create %zu data node failed.", | |||||
out_data_anchor->GetIdx()); return INTERNAL_ERROR); | |||||
for (auto &in_anchor : out_data_anchor->GetPeerInDataAnchors()) { | |||||
GE_IF_BOOL_EXEC(in_anchor == nullptr, continue); | |||||
NodePtr dst_node = in_anchor->GetOwnerNode(); | |||||
if (GraphUtils::RemoveEdge(out_data_anchor, in_anchor) != GRAPH_SUCCESS) { | |||||
GELOGE(INTERNAL_ERROR, "Failed to remove edge between %s to %s", node->GetName().c_str(), | |||||
dst_node->GetName().c_str()); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
if (GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), dst_node->GetInDataAnchor(in_anchor->GetIdx())) != | |||||
GRAPH_SUCCESS) { | |||||
GELOGE(INTERNAL_ERROR, "Failed to add edge between %s to %s", data_node->GetName().c_str(), | |||||
dst_node->GetName().c_str()); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
} | |||||
} | |||||
if (graph->RemoveNode(node) != GRAPH_SUCCESS) { | |||||
GELOGE(GRAPH_FAILED, "Remove node %s failed!", node->GetName().c_str()); | |||||
return GRAPH_FAILED; | |||||
} | |||||
break; | |||||
} | |||||
} | } | ||||
return SUCCESS; | |||||
} | |||||
if (NodeUtils::UpdateOutputShape(*data, kDataOutIndex, data_shape) != GRAPH_SUCCESS) { | |||||
GELOGE(INTERNAL_ERROR, "Failed to update output shape for data %s", data->GetName().c_str()); | |||||
return INTERNAL_ERROR; | |||||
NodePtr MultiBatchClonePass::CreateDataNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor, | |||||
size_t data_index) { | |||||
size_t out_anchor_index = out_data_anchor->GetIdx(); | |||||
std::string node_name = out_data_anchor->GetOwnerNode()->GetName() + "_" + std::to_string(out_anchor_index); | |||||
OpDescPtr op_desc = MakeShared<OpDesc>(node_name, DATA); | |||||
if (op_desc == nullptr) { | |||||
GELOGE(OUT_OF_MEMORY, "Create data node failed."); | |||||
return nullptr; | |||||
} | } | ||||
(void)AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, data_index); | |||||
GELOGI("Update %s input/output shape to %s", data->GetName().c_str(), formats::ShapeToString(data_shape).c_str()); | |||||
return SUCCESS; | |||||
OpDescPtr getnext_op_desc = out_data_anchor->GetOwnerNode()->GetOpDesc(); | |||||
if (getnext_op_desc == nullptr) { | |||||
GELOGE(OUT_OF_MEMORY, "Op desc of %s is nullptr.", out_data_anchor->GetOwnerNode()->GetName().c_str()); | |||||
return nullptr; | |||||
} | |||||
if (op_desc->AddInputDesc(getnext_op_desc->GetOutputDesc(out_anchor_index)) != GRAPH_SUCCESS) { | |||||
GELOGE(INTERNAL_ERROR, "Add %s input desc failed.", op_desc->GetName().c_str()); | |||||
return nullptr; | |||||
} | |||||
if (op_desc->AddOutputDesc(getnext_op_desc->GetOutputDesc(out_anchor_index)) != GRAPH_SUCCESS) { | |||||
GELOGE(INTERNAL_ERROR, "Add %s output desc failed.", op_desc->GetName().c_str()); | |||||
return nullptr; | |||||
} | |||||
NodePtr data_node = graph->AddNode(op_desc); | |||||
GELOGD("Success create %s node.", data_node->GetName().c_str()); | |||||
return data_node; | |||||
} | } | ||||
/// | /// | ||||
@@ -598,17 +993,14 @@ Status MultiBatchClonePass::SetShapeToData(const vector<int64_t> &shapes, const | |||||
/// @return 0: SUCCESS / others: FAILED | /// @return 0: SUCCESS / others: FAILED | ||||
/// | /// | ||||
Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const ComputeGraphPtr &branch) { | Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const ComputeGraphPtr &branch) { | ||||
GELOGD("Start create subgraphs for %s.", graph->GetName().c_str()); | |||||
const auto &op_desc = case_node_->GetOpDesc(); | const auto &op_desc = case_node_->GetOpDesc(); | ||||
for (size_t i = 0; i < batch_shapes_.size(); ++i) { | for (size_t i = 0; i < batch_shapes_.size(); ++i) { | ||||
std::vector<NodePtr> input_nodes; | std::vector<NodePtr> input_nodes; | ||||
std::vector<NodePtr> output_nodes; | std::vector<NodePtr> output_nodes; | ||||
const std::string postfix = kMultiBatchNodePostfix + std::to_string(i); | const std::string postfix = kMultiBatchNodePostfix + std::to_string(i); | ||||
ComputeGraphPtr subgraph = (i == 0) ? branch : GraphUtils::CloneGraph(branch, postfix, input_nodes, output_nodes); | ComputeGraphPtr subgraph = (i == 0) ? branch : GraphUtils::CloneGraph(branch, postfix, input_nodes, output_nodes); | ||||
if (subgraph == nullptr) { | |||||
GELOGE(FAILED, "Create multi-batch case node failed"); | |||||
return FAILED; | |||||
} | |||||
GE_IF_BOOL_EXEC(subgraph == nullptr, GELOGE(FAILED, "Create multi-batch case node failed"); return FAILED); | |||||
subgraph->SetName("Batch_" + std::to_string(i)); | subgraph->SetName("Batch_" + std::to_string(i)); | ||||
subgraph->SetParentNode(case_node_); | subgraph->SetParentNode(case_node_); | ||||
subgraph->SetParentGraph(graph); | subgraph->SetParentGraph(graph); | ||||
@@ -621,6 +1013,7 @@ Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const | |||||
op_desc->AddSubgraphName(key_name); | op_desc->AddSubgraphName(key_name); | ||||
op_desc->SetSubgraphInstanceName(i, subgraph->GetName()); | op_desc->SetSubgraphInstanceName(i, subgraph->GetName()); | ||||
GELOGD("The %s has %zu input, %zu output.", subgraph->GetName().c_str(), input_nodes.size(), output_nodes.size()); | |||||
for (const auto &data : input_nodes) { | for (const auto &data : input_nodes) { | ||||
GE_CHK_STATUS_RET(UpdateSubgraphData(data, i), "Update %s failed", subgraph->GetName().c_str()); | GE_CHK_STATUS_RET(UpdateSubgraphData(data, i), "Update %s failed", subgraph->GetName().c_str()); | ||||
} | } | ||||
@@ -666,6 +1059,7 @@ Status MultiBatchClonePass::UpdateSubgraphOutput(const NodePtr &output_node) { | |||||
/// @return 0: SUCCESS / others: FAILED | /// @return 0: SUCCESS / others: FAILED | ||||
/// | /// | ||||
Status MultiBatchClonePass::PruneDirectOutput(const ComputeGraphPtr &graph) { | Status MultiBatchClonePass::PruneDirectOutput(const ComputeGraphPtr &graph) { | ||||
GELOGD("Start prune direct output."); | |||||
const auto &func_desc = case_node_->GetOpDesc(); | const auto &func_desc = case_node_->GetOpDesc(); | ||||
uint32_t unused_num = 0; | uint32_t unused_num = 0; | ||||
uint32_t output_num = func_desc->GetOutputsSize(); | uint32_t output_num = func_desc->GetOutputsSize(); | ||||
@@ -710,6 +1104,7 @@ Status MultiBatchClonePass::PruneDirectOutput(const ComputeGraphPtr &graph) { | |||||
/// | /// | ||||
Status MultiBatchClonePass::UpdateOutputTensor(uint32_t parent_index, uint32_t unused_num) { | Status MultiBatchClonePass::UpdateOutputTensor(uint32_t parent_index, uint32_t unused_num) { | ||||
if (unused_num == 0) { | if (unused_num == 0) { | ||||
GELOGD("No need to update output tensor."); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -36,6 +36,7 @@ class MultiBatchClonePass : public GraphPass { | |||||
/// @return 0: SUCCESS / others: FAILED | /// @return 0: SUCCESS / others: FAILED | ||||
/// | /// | ||||
Status CollectIoNodes(const ComputeGraphPtr &graph); | Status CollectIoNodes(const ComputeGraphPtr &graph); | ||||
Status InitParamsOfGetNext(const NodePtr &node); | |||||
/// | /// | ||||
/// @ingroup ge | /// @ingroup ge | ||||
@@ -49,10 +50,12 @@ class MultiBatchClonePass : public GraphPass { | |||||
/// @ingroup ge | /// @ingroup ge | ||||
/// @brief Create index data node for root graph. | /// @brief Create index data node for root graph. | ||||
/// @param [in] const ComputeGraphPtr &graph: Root/Case graph. | /// @param [in] const ComputeGraphPtr &graph: Root/Case graph. | ||||
/// @param [in] NodePtr node: index data node. | |||||
/// @param [in] NodePtr shape_node: index data node, DATA or GETDYNAMICDIMS type. | |||||
/// @return 0: SUCCESS / others: FAILED | /// @return 0: SUCCESS / others: FAILED | ||||
/// | /// | ||||
Status CreateIndexDataNode(const ComputeGraphPtr &graph, NodePtr &node); | |||||
Status CreateIndexDataNode(const ComputeGraphPtr &graph, NodePtr &shape_node); | |||||
Status CreateGetDynamicDimsNode(const ComputeGraphPtr &graph, NodePtr &shape_node); | |||||
/// | /// | ||||
/// @ingroup ge | /// @ingroup ge | ||||
@@ -70,6 +73,9 @@ class MultiBatchClonePass : public GraphPass { | |||||
/// @return 0: SUCCESS / others: FAILED | /// @return 0: SUCCESS / others: FAILED | ||||
/// | /// | ||||
Status CreateIndexNode(const ComputeGraphPtr &graph); | Status CreateIndexNode(const ComputeGraphPtr &graph); | ||||
Status AddAttrForGetDynamicDims(const NodePtr &shape_node); | |||||
Status LinkGetNextToGetDynamicDims(const NodePtr &getnext_node, const NodePtr &shape_node); | |||||
Status LinkGetDynamicDimsToNetOutput(const NodePtr &output_node); | |||||
/// | /// | ||||
/// @ingroup ge | /// @ingroup ge | ||||
@@ -78,39 +84,54 @@ class MultiBatchClonePass : public GraphPass { | |||||
/// @return 0: SUCCESS / others: FAILED | /// @return 0: SUCCESS / others: FAILED | ||||
/// | /// | ||||
Status CreateInputNode(const ComputeGraphPtr &graph); | Status CreateInputNode(const ComputeGraphPtr &graph); | ||||
Status LinkEdgeForGetNext(const NodePtr &getnext_node, size_t &case_input_index); | |||||
/// | /// | ||||
/// @ingroup ge | /// @ingroup ge | ||||
/// @brief Create Const node for root graph. | |||||
/// @param [in] const ComputeGraphPtr &graph: Root/Case graph. | |||||
/// @brief Set max shape to Data node in root graph. | |||||
/// @param [in] const NodePtr &data: data in Root/Case graph. | |||||
/// @return 0: SUCCESS / others: FAILED | /// @return 0: SUCCESS / others: FAILED | ||||
/// | /// | ||||
Status CreateConstNode(const ComputeGraphPtr &graph); | |||||
Status SetMaxShape(const NodePtr &data); | |||||
Status SetMaxShapeToData(const NodePtr &node, size_t out_anchor_index); | |||||
/// | |||||
/// @ingroup ge | |||||
/// @brief Set max shape to Data/GetNext node in root graph. | |||||
/// @param [in] const std::vector<int64_t> &shapes: dims of shape. | |||||
/// @param [in] const NodePtr &data: data in Root/Case graph. | |||||
/// @param [in] GeShape &data_shape: dims of data node. | |||||
/// @param [in] size_t out_anchor_index: out anchor index of data node. | |||||
/// @return 0: SUCCESS / others: FAILED | |||||
/// | |||||
Status SetShapeToData(const std::vector<int64_t> &shapes, const NodePtr &data, GeShape &data_shape, | |||||
size_t out_anchor_index); | |||||
Status UpdateShapeOfShapeNode(const NodePtr &node, size_t out_anchor_index); | |||||
/// | /// | ||||
/// @ingroup ge | /// @ingroup ge | ||||
/// @brief Create output node for root graph. | |||||
/// @brief Create Const node for root graph. | |||||
/// @param [in] const ComputeGraphPtr &graph: Root/Case graph. | /// @param [in] const ComputeGraphPtr &graph: Root/Case graph. | ||||
/// @return 0: SUCCESS / others: FAILED | /// @return 0: SUCCESS / others: FAILED | ||||
/// | /// | ||||
Status CreateOutputNode(const ComputeGraphPtr &graph); | |||||
Status CreateConstNode(const ComputeGraphPtr &graph); | |||||
void ChangeConstToData(); | |||||
/// | /// | ||||
/// @ingroup ge | /// @ingroup ge | ||||
/// @brief Set max shape to Data node in root graph. | |||||
/// @param [in] const NodePtr &data: data in Root/Case graph. | |||||
/// @brief Create output node for root graph. | |||||
/// @param [in] const ComputeGraphPtr &graph: Root/Case graph. | |||||
/// @return 0: SUCCESS / others: FAILED | /// @return 0: SUCCESS / others: FAILED | ||||
/// | /// | ||||
Status SetMaxShapeToData(const NodePtr &data); | |||||
Status CreateOutputNode(const ComputeGraphPtr &graph); | |||||
/// | /// | ||||
/// @ingroup ge | /// @ingroup ge | ||||
/// @brief Update Data node in Subgraph. | /// @brief Update Data node in Subgraph. | ||||
/// @param [in] const NodePtr &data: data in Subgraph. | /// @param [in] const NodePtr &data: data in Subgraph. | ||||
/// @param [in] size_t index: The batch index. | |||||
/// @param [in] size_t batch_index: The batch index. | |||||
/// @return 0: SUCCESS / others: FAILED | /// @return 0: SUCCESS / others: FAILED | ||||
/// | /// | ||||
Status UpdateSubgraphData(const NodePtr &data, size_t index); | |||||
Status UpdateSubgraphData(const NodePtr &data, size_t batch_index); | |||||
/// | /// | ||||
/// @ingroup ge | /// @ingroup ge | ||||
@@ -122,13 +143,12 @@ class MultiBatchClonePass : public GraphPass { | |||||
/// | /// | ||||
/// @ingroup ge | /// @ingroup ge | ||||
/// @brief Set max shape to Data node in root graph. | |||||
/// @param [in] const std::vector<int64_t> &shapes: dims of shape. | |||||
/// @param [in] const NodePtr &data: data in Root/Case graph. | |||||
/// @param [in] GeShape &data_shape: dims of data node. | |||||
/// @brief Create nodes for root graph. | |||||
/// @param [in] const ComputeGraphPtr &graph: Original graph. | |||||
/// @return 0: SUCCESS / others: FAILED | /// @return 0: SUCCESS / others: FAILED | ||||
/// | /// | ||||
Status SetShapeToData(const std::vector<int64_t> &shapes, const NodePtr &data, GeShape &data_shape); | |||||
Status CreateOriGraph(const ComputeGraphPtr &graph); | |||||
NodePtr CreateDataNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor, size_t data_index); | |||||
/// | /// | ||||
/// @ingroup ge | /// @ingroup ge | ||||
@@ -168,6 +188,10 @@ class MultiBatchClonePass : public GraphPass { | |||||
std::map<string, vector<vector<int64_t>>> data_to_dynamic_info_; | std::map<string, vector<vector<int64_t>>> data_to_dynamic_info_; | ||||
NodePtr case_node_; | NodePtr case_node_; | ||||
size_t data_count_from_getnext_ = 0; | |||||
bool getnext_sink_dynamic_dims_ = false; | |||||
NodePtr shape_node_; | |||||
std::set<NodePtr> out_control_nodes_; | |||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
#endif // GE_GRAPH_PASSES_MULTI_BATCH_CLONE_PASS_H_ | #endif // GE_GRAPH_PASSES_MULTI_BATCH_CLONE_PASS_H_ |
@@ -204,6 +204,10 @@ Status UnusedArgsCleanPass::RemoveInputTensor(const map<ComputeGraphPtr, map<uin | |||||
GE_CHK_GRAPH_STATUS_RET(GraphUtils::RemoveEdge(out_anchor, old_anchor), "Remove edge failed"); | GE_CHK_GRAPH_STATUS_RET(GraphUtils::RemoveEdge(out_anchor, old_anchor), "Remove edge failed"); | ||||
GELOGI("Remove edge: %s %s", out_node->GetName().c_str(), func_node->GetName().c_str()); | GELOGI("Remove edge: %s %s", out_node->GetName().c_str(), func_node->GetName().c_str()); | ||||
if (out_node->GetInDataNodes().size() == 0 && out_node->GetOutAllNodes().size() == 0) { | |||||
GE_CHK_GRAPH_STATUS_RET(out_node->GetOwnerComputeGraph()->RemoveNode(out_node), "Remove node failed: %s", | |||||
out_node->GetName().c_str()); | |||||
} | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
} // namespace ge | } // namespace ge |
@@ -1692,13 +1692,11 @@ Status MultiBatchGraphCopyer::LinkToNodeOutBranch(const NodePtr &node) { | |||||
} | } | ||||
Status ProcessMultiBatch(ComputeGraphPtr &graph) { | Status ProcessMultiBatch(ComputeGraphPtr &graph) { | ||||
if (GetLocalOmgContext().dynamic_node_type.empty()) { | |||||
const char *multi_batch_with_switchn = std::getenv("MULTI_BATCH_WITH_SWITCHN"); | |||||
if (multi_batch_with_switchn == nullptr) { | |||||
PassManager pass_manager; | |||||
GE_CHK_STATUS_RET(pass_manager.AddPass("MultiBatchClonePass", new (std::nothrow) MultiBatchClonePass)); | |||||
return pass_manager.Run(graph); | |||||
} | |||||
const char *multi_batch_with_switchn = std::getenv("MULTI_BATCH_WITH_SWITCHN"); | |||||
if (multi_batch_with_switchn == nullptr) { | |||||
PassManager pass_manager; | |||||
GE_CHK_STATUS_RET(pass_manager.AddPass("MultiBatchClonePass", new (std::nothrow) MultiBatchClonePass)); | |||||
return pass_manager.Run(graph); | |||||
} | } | ||||
if (!GetLocalOmgContext().need_multi_batch) { | if (!GetLocalOmgContext().need_multi_batch) { | ||||
GELOGI("No need to process_multi for no_train graph."); | GELOGI("No need to process_multi for no_train graph."); | ||||
@@ -99,9 +99,8 @@ Status DistinguishGetNextAndData(ComputeGraphPtr &graph, vector<NodePtr> &data_n | |||||
} | } | ||||
GELOGI("Data count is %zu, getnext nosink count is %zu, getnext sink count is %zu.", data_nodes.size(), | GELOGI("Data count is %zu, getnext nosink count is %zu, getnext sink count is %zu.", data_nodes.size(), | ||||
getnext_nosink_nodes.size(), getnext_sink_nodes.size()); | getnext_nosink_nodes.size(), getnext_sink_nodes.size()); | ||||
GE_IF_BOOL_EXEC(!graph->SetExtAttr(kExtAttrDataNodes, data_nodes), GELOGW("Set data nodes attr failed.");) | |||||
GE_IF_BOOL_EXEC(!graph->SetExtAttr(kExtAttrGetNextNoSink, getnext_nosink_nodes), | |||||
GELOGW("Set getnext nosink nodes attr failed.");) | |||||
GetLocalOmgContext().data_nodes = data_nodes; | |||||
GetLocalOmgContext().getnext_nosink_nodes = getnext_nosink_nodes; | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -26,6 +26,7 @@ | |||||
#include <vector> | #include <vector> | ||||
#include "framework/common/fmk_error_codes.h" | #include "framework/common/fmk_error_codes.h" | ||||
#include "register/register_fmk_types.h" | #include "register/register_fmk_types.h" | ||||
#include "graph/node.h" | |||||
using domi::DOMI_TENSOR_ND; | using domi::DOMI_TENSOR_ND; | ||||
using domi::DOMI_TENSOR_RESERVED; | using domi::DOMI_TENSOR_RESERVED; | ||||
@@ -120,6 +121,8 @@ struct OmgContext { | |||||
std::vector<std::vector<int64_t>> user_real_input_dims; | std::vector<std::vector<int64_t>> user_real_input_dims; | ||||
std::vector<int64_t> cur_dynamic_dims; | std::vector<int64_t> cur_dynamic_dims; | ||||
bool need_multi_batch = false; | bool need_multi_batch = false; | ||||
std::vector<NodePtr> data_nodes; | |||||
std::vector<NodePtr> getnext_nosink_nodes; | |||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
@@ -1 +1 @@ | |||||
Subproject commit 44bcbb5ea25ada1a5393aa4c7f554d40b6859b18 | |||||
Subproject commit fe37bc343ea52c76d35e9e9ec83cea0151bfa900 |
@@ -1 +1 @@ | |||||
Subproject commit 5b93b050dd7ca5b77c3001a790031d877fa10956 | |||||
Subproject commit 336cd3107253d3fe41cfb9fec2db62b5f3d8a33b |
@@ -627,6 +627,7 @@ set(PASS_TEST_FILES | |||||
"graph/passes/net_output_pass_unittest.cc" | "graph/passes/net_output_pass_unittest.cc" | ||||
"graph/passes/no_use_reshape_remove_pass_unittest.cc" | "graph/passes/no_use_reshape_remove_pass_unittest.cc" | ||||
"graph/passes/infershape_pass_unittest.cc" | "graph/passes/infershape_pass_unittest.cc" | ||||
"graph/passes/multi_batch_clone_pass_unittest.cc" | |||||
) | ) | ||||
set(KERNEL_TEST_FILES | set(KERNEL_TEST_FILES | ||||
@@ -32,6 +32,18 @@ class UtestDavinciModel : public testing::Test { | |||||
void SetUp() {} | void SetUp() {} | ||||
void TearDown() {} | void TearDown() {} | ||||
public: | |||||
NodePtr MakeNode(const ComputeGraphPtr &graph, uint32_t in_num, uint32_t out_num, string name, string type) { | |||||
GeTensorDesc test_desc(GeShape(), FORMAT_NCHW, DT_FLOAT); | |||||
auto op_desc = std::make_shared<OpDesc>(name, type); | |||||
for (auto i = 0; i < in_num; ++i) { | |||||
op_desc->AddInputDesc(test_desc); | |||||
} | |||||
for (auto i = 0; i < out_num; ++i) { | |||||
op_desc->AddOutputDesc(test_desc); | |||||
} | |||||
return graph->AddNode(op_desc); | |||||
} | |||||
}; | }; | ||||
TEST_F(UtestDavinciModel, init_success) { | TEST_F(UtestDavinciModel, init_success) { | ||||
@@ -324,5 +336,94 @@ TEST_F(UtestDavinciModel, SyncVarData_test) { | |||||
EXPECT_NE(model.SyncVarData(), SUCCESS); | EXPECT_NE(model.SyncVarData(), SUCCESS); | ||||
} | } | ||||
TEST_F(UtestDavinciModel, InitRealSizeAndShapeInfo_succ1) { | |||||
DavinciModel model(0, nullptr); | |||||
model.ge_model_ = make_shared<GeModel>(); | |||||
ComputeGraphPtr graph = make_shared<ComputeGraph>("default"); | |||||
GeTensorDesc tensor(GeShape(), FORMAT_NCHW, DT_FLOAT); | |||||
OpDescPtr op_output = CreateOpDesc("output_ascend_mbatch_batch_1", NETOUTPUT); | |||||
op_output->AddInputDesc(tensor); | |||||
op_output->SetInputOffset({1024}); | |||||
NodePtr node_output = graph->AddNode(op_output); | |||||
EXPECT_EQ(model.InitRealSizeAndShapeInfo(graph, node_output), SUCCESS); | |||||
} | |||||
// Online dynamic inference with a Case node feeding NetOutput:
// InitRealSizeAndShapeInfo succeeds both before and after the output node
// carries ATTR_NAME_DYNAMIC_OUTPUT_DIMS.
TEST_F(UtestDavinciModel, InitRealSizeAndShapeInfo_succ2) {
  DavinciModel model(0, nullptr);
  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test_graph");

  // Data with a static 4D shape.
  OpDescPtr data_desc = CreateOpDesc("data1", DATA);
  GeTensorDesc shape_desc(GeShape({4,3,224,224}), FORMAT_NCHW, DT_FLOAT);
  data_desc->AddInputDesc(shape_desc);
  data_desc->AddOutputDesc(shape_desc);
  NodePtr data_node = graph->AddNode(data_desc);

  // Case node inserted by the multi-batch pass (scalar in/out).
  OpDescPtr case_desc = CreateOpDesc("case1", CASE);
  GeTensorDesc scalar_desc(GeShape(), FORMAT_NCHW, DT_FLOAT);
  case_desc->AddInputDesc(scalar_desc);
  case_desc->AddOutputDesc(scalar_desc);
  NodePtr case_node = graph->AddNode(case_desc);

  // NetOutput sourced from the Case node.
  OpDescPtr output_desc = CreateOpDesc("output1", NETOUTPUT);
  output_desc->AddInputDesc(scalar_desc);
  output_desc->SetSrcName( { "case1" } );
  output_desc->SetSrcIndex( { 0 } );
  NodePtr output_node = graph->AddNode(output_desc);

  GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), case_node->GetInDataAnchor(0));
  GraphUtils::AddEdge(case_node->GetOutDataAnchor(0), output_node->GetInDataAnchor(0));

  (void)AttrUtils::SetStr(output_node->GetOpDesc(), ATTR_ALL_GEARS_INFO, "1;2;4;8");
  (void)AttrUtils::SetBool(case_desc, ATTR_INSERT_BY_MBATCH, true);
  model.is_getnext_sink_dynamic_ = false;
  model.is_online_infer_dynamic_ = true;

  // First run: GetGearAndRealOutShapeInfo without ATTR_NAME_DYNAMIC_OUTPUT_DIMS.
  EXPECT_EQ(model.InitRealSizeAndShapeInfo(graph, output_node), SUCCESS);

  // Second run: with dynamic output dims attached.
  vector<string> dynamic_output_dims = {"0,0,1,1,0,2,2,0,4,3,0,8"};
  (void)AttrUtils::SetListStr(output_node->GetOpDesc(), ATTR_NAME_DYNAMIC_OUTPUT_DIMS, dynamic_output_dims);
  EXPECT_EQ(model.InitRealSizeAndShapeInfo(graph, output_node), SUCCESS);
}
// GetNext-sink dynamic case: a GetDynamicDims node feeds the NetOutput's
// second input. InitRealSizeAndShapeInfo succeeds, including after the
// runtime memory base/size are populated.
TEST_F(UtestDavinciModel, InitRealSizeAndShapeInfo_succ3) {
  DavinciModel model(0, nullptr);
  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test_graph");

  OpDescPtr data_desc = CreateOpDesc("data1", DATA);
  GeTensorDesc data_tensor(GeShape({4,3,224,224}), FORMAT_NCHW, DT_FLOAT);
  data_desc->AddInputDesc(data_tensor);
  data_desc->AddOutputDesc(data_tensor);
  NodePtr data_node = graph->AddNode(data_desc);

  // GetDynamicDims node: scalar input, (4,3) output holding the gear dims.
  OpDescPtr dynamic_dims_desc = CreateOpDesc("ascend_mbatch_get_dynamic_dims_node", GETDYNAMICDIMS);
  GeTensorDesc in_tensor(GeShape(), FORMAT_NCHW, DT_FLOAT);
  GeTensorDesc out_tensor(GeShape({4,3}), FORMAT_NCHW, DT_FLOAT);
  dynamic_dims_desc->AddInputDesc(in_tensor);
  dynamic_dims_desc->AddOutputDesc(out_tensor);
  NodePtr dynamic_dims_node = graph->AddNode(dynamic_dims_desc);

  OpDescPtr output_desc = CreateOpDesc("output1", NETOUTPUT);
  GeTensorDesc scalar_tensor(GeShape(), FORMAT_NCHW, DT_FLOAT);
  output_desc->AddInputDesc(scalar_tensor);
  output_desc->SetSrcName( { "data1", "ascend_mbatch_get_dynamic_dims_node" } );
  output_desc->SetSrcIndex( { 0, 1 } );
  NodePtr output_node = graph->AddNode(output_desc);

  GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), output_node->GetInDataAnchor(0));
  GraphUtils::AddEdge(dynamic_dims_node->GetOutDataAnchor(0), output_node->GetInDataAnchor(1));

  (void)AttrUtils::SetStr(output_node->GetOpDesc(), ATTR_ALL_GEARS_INFO, "1,3;;4,3;,3");
  model.is_getnext_sink_dynamic_ = true;
  model.is_online_infer_dynamic_ = false;

  EXPECT_EQ(model.InitRealSizeAndShapeInfo(graph, output_node), SUCCESS);

  // With a (fake) runtime memory base the call still succeeds.
  model.runtime_param_.mem_base = reinterpret_cast<uint8_t *>(0x08000000);
  model.runtime_param_.mem_size = 4;
  EXPECT_EQ(model.InitRealSizeAndShapeInfo(graph, output_node), SUCCESS);
}
} // namespace ge | } // namespace ge |
@@ -0,0 +1,247 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "graph/passes/multi_batch_clone_pass.h" | |||||
#include <gtest/gtest.h> | |||||
#include <set> | |||||
#include <string> | |||||
#include "inc/pass_manager.h" | |||||
#include "graph/utils/tensor_utils.h" | |||||
#include "graph/common/local_context.h" | |||||
#include "graph/passes/multi_batch_pass.h" | |||||
#include "graph/preprocess/multi_batch_copy_graph.h" | |||||
#include "graph/preprocess/insert_op/util_insert_aipp_op.h" | |||||
#include "framework/omg/omg_inner_types.h" | |||||
#include "register/op_registry.h" | |||||
namespace ge{ | |||||
class UtestMultiBatchClonePass : public testing::Test { | |||||
protected: | |||||
void SetUp() { | |||||
SetLocalOmgContext(domi::GetContext()); | |||||
GetLocalOmgContext().dynamic_image_size.clear(); | |||||
GetLocalOmgContext().dynamic_batch_size.clear(); | |||||
} | |||||
void TearDown() { | |||||
GetLocalOmgContext().dynamic_image_size.clear(); | |||||
GetLocalOmgContext().dynamic_batch_size.clear(); | |||||
GetLocalOmgContext().dynamic_node_type.clear(); | |||||
} | |||||
public: | |||||
NodePtr MakeNode(const ComputeGraphPtr &graph, uint32_t in_num, uint32_t out_num, string name, string type) { | |||||
GeTensorDesc test_desc(GeShape(), FORMAT_NCHW, DT_FLOAT); | |||||
auto op_desc = std::make_shared<OpDesc>(name, type); | |||||
for (auto i = 0; i < in_num; ++i) { | |||||
op_desc->AddInputDesc(test_desc); | |||||
} | |||||
for (auto i = 0; i < out_num; ++i) { | |||||
op_desc->AddOutputDesc(test_desc); | |||||
} | |||||
return graph->AddNode(op_desc); | |||||
} | |||||
NodePtr MakeConstNode(const ComputeGraphPtr &graph) { | |||||
static uint32_t index = 0; | |||||
GeTensorDesc test_desc(GeShape(), FORMAT_NCHW, DT_FLOAT); | |||||
auto op_desc = std::make_shared<OpDesc>("dynamic_const_" + std::to_string(index++), "Const"); | |||||
op_desc->AddOutputDesc(test_desc); | |||||
return graph->AddNode(op_desc); | |||||
} | |||||
void make_original_graph(const ComputeGraphPtr &graph) { | |||||
auto conv2d_node = MakeNode(graph, 3, 1, "conv1", "Conv2D"); | |||||
{ | |||||
auto data1 = MakeNode(graph, 1, 1, "data", "Data"); | |||||
GeTensorDesc tensor_desc(GeShape({-1,3,224,224}), FORMAT_NCHW, DT_FLOAT); | |||||
data1->GetOpDesc()->UpdateInputDesc(0, tensor_desc); | |||||
data1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); | |||||
AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 0); | |||||
GetLocalOmgContext().user_input_dims = {std::make_pair(data1->GetOpDesc()->GetName(), vector<int64_t>{-1,3,224,224})}; | |||||
GraphUtils::AddEdge(data1->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(0)); | |||||
auto const1 = MakeConstNode(graph); | |||||
GraphUtils::AddEdge(const1->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(1)); | |||||
auto const2 = MakeConstNode(graph); | |||||
GraphUtils::AddEdge(const2->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(2)); | |||||
} | |||||
auto bn_conv1 = MakeNode(graph, 4, 1, "bn_conv1", "BNInference"); | |||||
{ | |||||
GraphUtils::AddEdge(conv2d_node->GetOutDataAnchor(0), bn_conv1->GetInDataAnchor(0)); | |||||
auto const1 = MakeConstNode(graph); | |||||
GraphUtils::AddEdge(const1->GetOutDataAnchor(0), bn_conv1->GetInDataAnchor(1)); | |||||
auto const2 = MakeConstNode(graph); | |||||
GraphUtils::AddEdge(const2->GetOutDataAnchor(0), bn_conv1->GetInDataAnchor(2)); | |||||
auto const3= MakeConstNode(graph); | |||||
GraphUtils::AddEdge(const3->GetOutDataAnchor(0), bn_conv1->GetInDataAnchor(3)); | |||||
} | |||||
auto scale_conv1 = MakeNode(graph, 4, 1, "scale1", "Scale"); | |||||
{ | |||||
GraphUtils::AddEdge(bn_conv1->GetOutDataAnchor(0), scale_conv1->GetInDataAnchor(0)); | |||||
auto const1 = MakeConstNode(graph); | |||||
GraphUtils::AddEdge(const1->GetOutDataAnchor(0), scale_conv1->GetInDataAnchor(1)); | |||||
auto const2 = MakeConstNode(graph); | |||||
GraphUtils::AddEdge(const2->GetOutDataAnchor(0), scale_conv1->GetInDataAnchor(2)); | |||||
} | |||||
auto output_node = MakeNode(graph, 1, 0, "output1", "NetOutput"); | |||||
GraphUtils::AddEdge(scale_conv1->GetOutDataAnchor(0), output_node->GetInDataAnchor(0)); | |||||
} | |||||
void GraphWithJustData(const ComputeGraphPtr &graph) { | |||||
auto conv2d_node = MakeNode(graph, 3, 1, "conv1", "Conv2D"); | |||||
{ | |||||
auto data1 = MakeNode(graph, 1, 1, "data", "Data"); | |||||
GeTensorDesc tensor_desc(GeShape({-1,3,224,224}), FORMAT_NCHW, DT_FLOAT); | |||||
data1->GetOpDesc()->UpdateInputDesc(0, tensor_desc); | |||||
data1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); | |||||
AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 0); | |||||
GetLocalOmgContext().user_input_dims = {std::make_pair(data1->GetOpDesc()->GetName(), vector<int64_t>{-1,3,224,224})}; | |||||
GraphUtils::AddEdge(data1->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(0)); | |||||
auto const1 = MakeConstNode(graph); | |||||
GraphUtils::AddEdge(const1->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(1)); | |||||
auto const2 = MakeConstNode(graph); | |||||
GraphUtils::AddEdge(const2->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(2)); | |||||
} | |||||
auto output_node = MakeNode(graph, 1, 0, "output1", "NetOutput"); | |||||
GraphUtils::AddEdge(conv2d_node->GetOutDataAnchor(0), output_node->GetInDataAnchor(0)); | |||||
} | |||||
void GraphWithGetNextNosink(const ComputeGraphPtr &graph) { | |||||
auto conv2d_node = MakeNode(graph, 3, 1, "conv1", "Conv2D"); | |||||
{ | |||||
auto data1 = MakeNode(graph, 1, 1, "IteratorGetNext_data", "Data"); | |||||
GeTensorDesc tensor_desc(GeShape({-1,3,224,224}), FORMAT_NCHW, DT_FLOAT); | |||||
data1->GetOpDesc()->UpdateInputDesc(0, tensor_desc); | |||||
data1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); | |||||
AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 0); | |||||
GetLocalOmgContext().user_input_dims = {std::make_pair(data1->GetOpDesc()->GetName(), vector<int64_t>{-1,3,224,224})}; | |||||
GraphUtils::AddEdge(data1->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(0)); | |||||
auto const1 = MakeConstNode(graph); | |||||
GraphUtils::AddEdge(const1->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(1)); | |||||
auto const2 = MakeConstNode(graph); | |||||
GraphUtils::AddEdge(const2->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(2)); | |||||
} | |||||
auto output_node = MakeNode(graph, 1, 0, "output1", "NetOutput"); | |||||
GraphUtils::AddEdge(conv2d_node->GetOutDataAnchor(0), output_node->GetInDataAnchor(0)); | |||||
} | |||||
// getnext has one data and has one out of shape | |||||
void GraphWithGetNextSink(const ComputeGraphPtr &graph) { | |||||
auto conv2d_node = MakeNode(graph, 3, 1, "conv1", "Conv2D"); | |||||
{ | |||||
auto data1 = MakeNode(graph, 1, 2, "data", "IteratorV2"); | |||||
GeTensorDesc tensor_desc(GeShape({-1,3,224,224}), FORMAT_NCHW, DT_FLOAT); | |||||
GeTensorDesc shape_desc(GeShape({4,3,224,224}), FORMAT_NCHW, DT_FLOAT); | |||||
data1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); | |||||
data1->GetOpDesc()->UpdateOutputDesc(1, shape_desc); | |||||
AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 0); | |||||
GetLocalOmgContext().user_input_dims = {std::make_pair(data1->GetOpDesc()->GetName(), vector<int64_t>{-1,3,224,224})}; | |||||
GraphUtils::AddEdge(data1->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(0)); | |||||
auto identity = MakeNode(graph, 1, 0, "identity", "Identity"); | |||||
GraphUtils::AddEdge(data1->GetOutDataAnchor(1), identity->GetInDataAnchor(0)); | |||||
auto const1 = MakeConstNode(graph); | |||||
GraphUtils::AddEdge(const1->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(1)); | |||||
auto const2 = MakeConstNode(graph); | |||||
GraphUtils::AddEdge(const2->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(2)); | |||||
} | |||||
auto output_node = MakeNode(graph, 1, 0, "output1", "NetOutput"); | |||||
GraphUtils::AddEdge(conv2d_node->GetOutDataAnchor(0), output_node->GetInDataAnchor(0)); | |||||
} | |||||
}; | |||||
// graph is nullptr | |||||
// A null graph must be rejected by the pass manager with PARAM_INVALID.
TEST_F(UtestMultiBatchClonePass, graph_nullptr) {
  ComputeGraphPtr null_graph;  // deliberately left unset
  PassManager manager;
  manager.AddPass("MultiBatchClonePass", new (std::nothrow) MultiBatchClonePass);
  EXPECT_EQ(manager.Run(null_graph), PARAM_INVALID);
}
// graph with subgraph | |||||
// The pass succeeds on a root graph, and still succeeds after that graph is
// attached as the subgraph of an If node in a parent graph.
TEST_F(UtestMultiBatchClonePass, graph_with_subgraph) {
  PassManager manager;
  manager.AddPass("MultiBatchClonePass", new (std::nothrow) MultiBatchClonePass);

  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test_graph");
  make_original_graph(graph);
  EXPECT_EQ(manager.Run(graph), SUCCESS);

  ComputeGraphPtr parent_graph = std::make_shared<ComputeGraph>("test_owner");
  NodePtr if_node = MakeNode(parent_graph, 3, 1, "test_if", "If");
  graph->SetParentNode(if_node);
  graph->SetParentGraph(parent_graph);
  EXPECT_EQ(manager.Run(graph), SUCCESS);
}
// graph that does not need multi-batch processing (need_multi_batch == false)
// When need_multi_batch is false the pass is a no-op and returns SUCCESS.
TEST_F(UtestMultiBatchClonePass, uncompute_graph) {
  MultiBatchClonePass clone_pass;
  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test_graph");
  make_original_graph(graph);
  GetLocalOmgContext().need_multi_batch = false;
  EXPECT_EQ(clone_pass.Run(graph), SUCCESS);
}
//compute_graph with data from DATA | |||||
// Dynamic input fed by a plain Data node: the pass succeeds with and without
// dynamic dims configured, and records the single data node in the context.
TEST_F(UtestMultiBatchClonePass, compute_graph_with_data) {
  MultiBatchClonePass clone_pass;
  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test_graph");
  GraphWithJustData(graph);

  GetLocalOmgContext().need_multi_batch = true;
  EXPECT_EQ(clone_pass.Run(graph), SUCCESS);  // no dynamic dims configured yet

  GetLocalOmgContext().dynamic_node_type = DATA;
  GetLocalOmgContext().dynamic_dims = "1;2;4;8";
  EXPECT_EQ(clone_pass.Run(graph), SUCCESS);
  EXPECT_EQ(GetLocalOmgContext().data_nodes.size(), 1);
}
//compute_graph with data from GetNext_nosink | |||||
// Dynamic input fed by a GetNext no-sink Data node: the pass succeeds and
// records exactly one getnext-nosink node in the context.
TEST_F(UtestMultiBatchClonePass, compute_graph_with_getnext_nosink) {
  MultiBatchClonePass clone_pass;
  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test_graph");
  GraphWithGetNextNosink(graph);

  GetLocalOmgContext().need_multi_batch = true;
  GetLocalOmgContext().dynamic_node_type = GETNEXT;
  GetLocalOmgContext().dynamic_dims = "1;2;4;8";
  EXPECT_EQ(clone_pass.Run(graph), SUCCESS);
  EXPECT_EQ(GetLocalOmgContext().getnext_nosink_nodes.size(), 1);
}
// compute_graph with data from GetNext_sink
// Dynamic input fed by a GetNext sink node (IteratorV2): the pass succeeds
// and no getnext-nosink node is recorded in the context.
TEST_F(UtestMultiBatchClonePass, compute_graph_with_getnext_sink) {
  MultiBatchClonePass clone_pass;
  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test_graph");
  GraphWithGetNextSink(graph);

  GetLocalOmgContext().need_multi_batch = true;
  GetLocalOmgContext().dynamic_node_type = GETNEXT;
  GetLocalOmgContext().dynamic_dims = "1;2;4;8";
  EXPECT_EQ(clone_pass.Run(graph), SUCCESS);
  EXPECT_EQ(GetLocalOmgContext().getnext_nosink_nodes.size(), 0);
}
} |