| @@ -154,7 +154,6 @@ set(TRAIN_SRC_LIST | |||||
| "graph/passes/compile_nodes_pass.cc" | "graph/passes/compile_nodes_pass.cc" | ||||
| "graph/passes/constant_folding_pass.cc" | "graph/passes/constant_folding_pass.cc" | ||||
| "graph/passes/constant_fuse_same_pass.cc" | "graph/passes/constant_fuse_same_pass.cc" | ||||
| "graph/passes/control_trigger_pass.cc" | "graph/passes/control_trigger_pass.cc" | ||||
| "graph/passes/dimension_adjust_pass.cc" | "graph/passes/dimension_adjust_pass.cc" | ||||
| "graph/passes/dimension_compute_pass.cc" | "graph/passes/dimension_compute_pass.cc" | ||||
| @@ -202,6 +201,7 @@ set(TRAIN_SRC_LIST | |||||
| "host_kernels/sub_kernel.cc" | "host_kernels/sub_kernel.cc" | ||||
| "host_kernels/transdata_kernel.cc" | "host_kernels/transdata_kernel.cc" | ||||
| "host_kernels/unpack_kernel.cc" | "host_kernels/unpack_kernel.cc" | ||||
| "host_kernels/reformat_kernel.cc" | |||||
| "graph/passes/folding_pass.cc" | "graph/passes/folding_pass.cc" | ||||
| "graph/passes/get_original_format_pass.cc" | "graph/passes/get_original_format_pass.cc" | ||||
| "graph/passes/guarantee_const_pass.cc" | "graph/passes/guarantee_const_pass.cc" | ||||
| @@ -488,6 +488,7 @@ set(INFER_SRC_LIST | |||||
| "host_kernels/slice_d_kernel.cc" | "host_kernels/slice_d_kernel.cc" | ||||
| "host_kernels/dynamic_stitch_kernel.cc" | "host_kernels/dynamic_stitch_kernel.cc" | ||||
| "host_kernels/identity_kernel.cc" | "host_kernels/identity_kernel.cc" | ||||
| "host_kernels/reformat_kernel.cc" | |||||
| "graph/passes/stop_gradient_pass.cc" | "graph/passes/stop_gradient_pass.cc" | ||||
| "graph/passes/prevent_gradient_pass.cc" | "graph/passes/prevent_gradient_pass.cc" | ||||
| "graph/passes/identity_pass.cc" | "graph/passes/identity_pass.cc" | ||||
| @@ -139,7 +139,8 @@ int MemoryDumper::OpenFile(const char *filename) { | |||||
| GE_IF_BOOL_EXEC( | GE_IF_BOOL_EXEC( | ||||
| -1 != path_split_pos, string prefix_path = std::string(filename).substr(0, path_split_pos); | -1 != path_split_pos, string prefix_path = std::string(filename).substr(0, path_split_pos); | ||||
| string last_path = std::string(filename).substr(path_split_pos, strlen(filename) - 1); | string last_path = std::string(filename).substr(path_split_pos, strlen(filename) - 1); | ||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(prefix_path.length() >= MMPA_MAX_PATH, return kInvalidFd, "Prefix path is too long!"); | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(prefix_path.length() >= MMPA_MAX_PATH, | |||||
| return kInvalidFd, "Prefix path is too long!"); | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(mmRealPath(prefix_path.c_str(), tmp_path, MMPA_MAX_PATH) != EN_OK, return kInvalidFd, | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(mmRealPath(prefix_path.c_str(), tmp_path, MMPA_MAX_PATH) != EN_OK, return kInvalidFd, | ||||
| "Dir %s does not exit.", prefix_path.c_str()); | "Dir %s does not exit.", prefix_path.c_str()); | ||||
| real_path = std::string(tmp_path) + last_path;) | real_path = std::string(tmp_path) + last_path;) | ||||
| @@ -123,7 +123,10 @@ Status PluginManager::LoadSo(const string &path, const vector<string> &func_chec | |||||
| if (handle == nullptr) { | if (handle == nullptr) { | ||||
| const char *error = mmDlerror(); | const char *error = mmDlerror(); | ||||
| GE_IF_BOOL_EXEC(error == nullptr, error = ""); | GE_IF_BOOL_EXEC(error == nullptr, error = ""); | ||||
| GELOGE(GE_PLGMGR_PATH_INVALID, "Failed to dlopen %s!", error); | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E19012", {"function", "reason"}, | |||||
| {"mmDlopen", "shared library path is " + FmtToStr(file_path_dlopen) + ". Errormessage" + FmtToStr(error)}); | |||||
| GELOGE(GE_PLGMGR_PATH_INVALID, "Failed to dlopen the shared library path[%s]. Errormessage[%s]!", | |||||
| file_path_dlopen.c_str(), error); | |||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -132,6 +135,9 @@ Status PluginManager::LoadSo(const string &path, const vector<string> &func_chec | |||||
| for (const auto &func_name : func_check_list) { | for (const auto &func_name : func_check_list) { | ||||
| auto real_fn = (void (*)())mmDlsym(handle, const_cast<char *>(func_name.c_str())); | auto real_fn = (void (*)())mmDlsym(handle, const_cast<char *>(func_name.c_str())); | ||||
| if (real_fn == nullptr) { | if (real_fn == nullptr) { | ||||
| ErrorManager::GetInstance().ATCReportErrMessage("E19012", {"function", "reason"}, | |||||
| {"mmDlsym", FmtToStr(func_name) + " is skipped since function" + | |||||
| FmtToStr(func_name) + " is not existed!"}); | |||||
| GELOGE(GE_PLGMGR_PATH_INVALID, "%s is skipped since function %s is not existed!", func_name.c_str(), | GELOGE(GE_PLGMGR_PATH_INVALID, "%s is skipped since function %s is not existed!", func_name.c_str(), | ||||
| func_name.c_str()); | func_name.c_str()); | ||||
| is_valid = false; | is_valid = false; | ||||
| @@ -189,7 +189,8 @@ Status ModelHelper::SaveModelHeader(std::shared_ptr<OmFileSaveHelper> &om_file_s | |||||
| err = memcpy_s(model_header.platform_version, PLATFORM_VERSION_LEN, platform_version.c_str(), | err = memcpy_s(model_header.platform_version, PLATFORM_VERSION_LEN, platform_version.c_str(), | ||||
| platform_version.size() + 1); | platform_version.size() + 1); | ||||
| if (err != EOK) { | if (err != EOK) { | ||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "ModelHelper SaveModel failed while allocating memory for platform_version."); | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, | |||||
| "ModelHelper SaveModel failed while allocating memory for platform_version."); | |||||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | return ACL_ERROR_GE_MEMORY_ALLOCATION; | ||||
| } | } | ||||
| string version = reinterpret_cast<char *>(model_header.platform_version); | string version = reinterpret_cast<char *>(model_header.platform_version); | ||||
| @@ -180,7 +180,8 @@ Status OmFileLoadHelper::LoadModelPartitionTable(uint8_t *model_data, const uint | |||||
| context_.partition_datas_.push_back(partition); | context_.partition_datas_.push_back(partition); | ||||
| if (partition.size > model_data_size || mem_offset > model_data_size - partition.size) { | if (partition.size > model_data_size || mem_offset > model_data_size - partition.size) { | ||||
| GELOGE(ACL_ERROR_GE_EXEC_MODEL_DATA_SIZE_INVALID, "The partition size %zu is greater than the model data size %u.", | |||||
| GELOGE(ACL_ERROR_GE_EXEC_MODEL_DATA_SIZE_INVALID, | |||||
| "The partition size %zu is greater than the model data size %u.", | |||||
| partition.size + mem_offset, model_data_size); | partition.size + mem_offset, model_data_size); | ||||
| return ACL_ERROR_GE_EXEC_MODEL_DATA_SIZE_INVALID; | return ACL_ERROR_GE_EXEC_MODEL_DATA_SIZE_INVALID; | ||||
| } | } | ||||
| @@ -639,7 +639,8 @@ Status GeExecutor::UnloadModel(uint32_t model_id) { | |||||
| return ACL_ERROR_GE_INTERNAL_ERROR; | return ACL_ERROR_GE_INTERNAL_ERROR; | ||||
| } | } | ||||
| std::shared_ptr<hybrid::HybridDavinciModel> hybrid_davinci_model = ModelManager::GetInstance()->GetHybridModel(model_id); | |||||
| std::shared_ptr<hybrid::HybridDavinciModel> hybrid_davinci_model = | |||||
| ModelManager::GetInstance()->GetHybridModel(model_id); | |||||
| if (hybrid_davinci_model != nullptr) { | if (hybrid_davinci_model != nullptr) { | ||||
| uint64_t session_id = hybrid_davinci_model->GetSessionId(); | uint64_t session_id = hybrid_davinci_model->GetSessionId(); | ||||
| VarManagerPool::Instance().RemoveVarManager(session_id); | VarManagerPool::Instance().RemoveVarManager(session_id); | ||||
| @@ -164,6 +164,7 @@ OMG_HOST_SRC_FILES := \ | |||||
| host_kernels/slice_d_kernel.cc \ | host_kernels/slice_d_kernel.cc \ | ||||
| host_kernels/dynamic_stitch_kernel.cc \ | host_kernels/dynamic_stitch_kernel.cc \ | ||||
| host_kernels/identity_kernel.cc \ | host_kernels/identity_kernel.cc \ | ||||
| host_kernels/reformat_kernel.cc \ | |||||
| graph/passes/stop_gradient_pass.cc \ | graph/passes/stop_gradient_pass.cc \ | ||||
| graph/passes/prevent_gradient_pass.cc \ | graph/passes/prevent_gradient_pass.cc \ | ||||
| graph/passes/identity_pass.cc \ | graph/passes/identity_pass.cc \ | ||||
| @@ -189,7 +190,6 @@ OMG_HOST_SRC_FILES := \ | |||||
| graph/passes/control_trigger_pass.cc \ | graph/passes/control_trigger_pass.cc \ | ||||
| graph/passes/cond_pass.cc \ | graph/passes/cond_pass.cc \ | ||||
| graph/passes/cond_remove_pass.cc \ | graph/passes/cond_remove_pass.cc \ | ||||
| graph/passes/const_pass.cc \ | |||||
| graph/passes/for_pass.cc \ | graph/passes/for_pass.cc \ | ||||
| graph/passes/enter_pass.cc \ | graph/passes/enter_pass.cc \ | ||||
| graph/passes/assign_pass.cc \ | graph/passes/assign_pass.cc \ | ||||
| @@ -123,7 +123,6 @@ LIBGE_LOCAL_SRC_FILES := \ | |||||
| graph/passes/compile_nodes_pass.cc \ | graph/passes/compile_nodes_pass.cc \ | ||||
| graph/passes/constant_folding_pass.cc \ | graph/passes/constant_folding_pass.cc \ | ||||
| graph/passes/constant_fuse_same_pass.cc \ | graph/passes/constant_fuse_same_pass.cc \ | ||||
| graph/passes/const_pass.cc \ | |||||
| graph/passes/control_trigger_pass.cc \ | graph/passes/control_trigger_pass.cc \ | ||||
| graph/passes/dimension_adjust_pass.cc \ | graph/passes/dimension_adjust_pass.cc \ | ||||
| graph/passes/dimension_compute_pass.cc \ | graph/passes/dimension_compute_pass.cc \ | ||||
| @@ -171,6 +170,7 @@ LIBGE_LOCAL_SRC_FILES := \ | |||||
| host_kernels/sub_kernel.cc \ | host_kernels/sub_kernel.cc \ | ||||
| host_kernels/transdata_kernel.cc \ | host_kernels/transdata_kernel.cc \ | ||||
| host_kernels/unpack_kernel.cc \ | host_kernels/unpack_kernel.cc \ | ||||
| host_kernels/reformat_kernel.cc \ | |||||
| graph/passes/folding_pass.cc \ | graph/passes/folding_pass.cc \ | ||||
| graph/passes/get_original_format_pass.cc \ | graph/passes/get_original_format_pass.cc \ | ||||
| graph/passes/guarantee_const_pass.cc \ | graph/passes/guarantee_const_pass.cc \ | ||||
| @@ -349,7 +349,8 @@ static Status GenerateTaskForConstant(const std::shared_ptr<ComputeGraph> &graph | |||||
| GELOGD("Insert MemcpyAsync node between %s and %s.", in_node->GetName().c_str(), node->GetName().c_str()); | GELOGD("Insert MemcpyAsync node between %s and %s.", in_node->GetName().c_str(), node->GetName().c_str()); | ||||
| std::string name = node->GetName() + "_input_" + std::to_string(in_data_anchor->GetIdx()) + "_Memcpy"; | std::string name = node->GetName() + "_input_" + std::to_string(in_data_anchor->GetIdx()) + "_Memcpy"; | ||||
| if (InsertMemcpyNode(graph, peer_out_anchor, {in_data_anchor}, name) != SUCCESS) { | if (InsertMemcpyNode(graph, peer_out_anchor, {in_data_anchor}, name) != SUCCESS) { | ||||
| GELOGE(FAILED, "Insert memcpy between %s and %s failed.", in_node->GetName().c_str(), node->GetName().c_str()); | |||||
| GELOGE(FAILED, "Insert memcpy between %s and %s failed.", | |||||
| in_node->GetName().c_str(), node->GetName().c_str()); | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| } | } | ||||
| @@ -21,8 +21,8 @@ | |||||
| namespace { | namespace { | ||||
| const uint32_t kRangeCeilInterval = 2; | const uint32_t kRangeCeilInterval = 2; | ||||
| const uint32_t kLogBase = 2; | const uint32_t kLogBase = 2; | ||||
| const int64_t kLargeBlockSize = 8 * 1024 * 1024; | |||||
| const int64_t kLargeBlockRangeSize = 10; | |||||
| const int64_t kLargeBlockSize = 8 * 1024 * 1024; // 8M | |||||
| const int64_t kLargeBlockRangeSize = 2; | |||||
| } // namespace | } // namespace | ||||
| namespace ge { | namespace ge { | ||||
| @@ -73,15 +73,17 @@ Status BinaryBlockMemAssigner::GetMemoryRanges(vector<int64_t> &range_ceils) { | |||||
| GELOGE(FAILED, "dividend is 0!"); | GELOGE(FAILED, "dividend is 0!"); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| // Memory sizes are 512-byte aligned, so there is no need for ranges smaller than 512 | |||||
| int64_t min_memory_size = (all_memory_size.back() > MEM_ALIGN_SIZE) ? MEM_ALIGN_SIZE : all_memory_size.front(); | |||||
| auto range_number = static_cast<size_t>( | auto range_number = static_cast<size_t>( | ||||
| ceil(log(all_memory_size.back() / static_cast<double>(all_memory_size.front())) / log(kLogBase))); | |||||
| ceil(log(all_memory_size.back() / static_cast<double>(min_memory_size)) / log(kLogBase))); | |||||
| range_number = (range_number == 0) ? 1 : range_number; | range_number = (range_number == 0) ? 1 : range_number; | ||||
| GELOGD("Range number: %zu", range_number); | GELOGD("Range number: %zu", range_number); | ||||
| vector<vector<int64_t>> ranges(range_number); | vector<vector<int64_t>> ranges(range_number); | ||||
| GE_CHK_BOOL_EXEC((range_number != 0), return PARAM_INVALID, "range_number can't be 0."); | GE_CHK_BOOL_EXEC((range_number != 0), return PARAM_INVALID, "range_number can't be 0."); | ||||
| size_t range_number_limit = all_memory_size.size() / range_number; | size_t range_number_limit = all_memory_size.size() / range_number; | ||||
| int64_t range_ceil = all_memory_size[0]; | |||||
| int64_t range_ceil = min_memory_size; | |||||
| for (size_t i = 1; i <= range_number; i++) { | for (size_t i = 1; i <= range_number; i++) { | ||||
| GE_IF_BOOL_EXEC(TypeUtils::CheckUint64MulOverflow(static_cast<uint64_t>(range_ceil), kRangeCeilInterval), | GE_IF_BOOL_EXEC(TypeUtils::CheckUint64MulOverflow(static_cast<uint64_t>(range_ceil), kRangeCeilInterval), | ||||
| GELOGE(FAILED, "Multiply result is out of range."); | GELOGE(FAILED, "Multiply result is out of range."); | ||||
| @@ -114,7 +116,7 @@ Status BinaryBlockMemAssigner::GetMemoryRanges(vector<int64_t> &range_ceils) { | |||||
| range_ceils.push_back(range.back()); | range_ceils.push_back(range.back()); | ||||
| } | } | ||||
| } | } | ||||
| GELOGD("Range ceils: %s", ToString(range_ceils).c_str()); | |||||
| GELOGI("Range ceils: %s", ToString(range_ceils).c_str()); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
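The hunk above clamps the smallest range to MEM_ALIGN_SIZE and then doubles a ceiling per range. A rough standalone sketch of that sizing logic (not part of this patch; it assumes MEM_ALIGN_SIZE is 512 and skips the per-range population and overflow checks):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Hypothetical helper mirroring the doubling-range idea in GetMemoryRanges.
std::vector<int64_t> SketchRangeCeils(std::vector<int64_t> sizes) {
  const int64_t kMemAlignSize = 512;     // assumed value of MEM_ALIGN_SIZE
  const int64_t kRangeCeilInterval = 2;  // each range ceiling doubles the previous one
  std::sort(sizes.begin(), sizes.end());
  if (sizes.empty() || sizes.front() <= 0) {
    return {};
  }
  // Memory sizes are 512-byte aligned, so no range needs to start below 512.
  int64_t min_size = (sizes.back() > kMemAlignSize) ? kMemAlignSize : sizes.front();
  auto range_number = static_cast<size_t>(
      std::ceil(std::log(sizes.back() / static_cast<double>(min_size)) / std::log(2.0)));
  range_number = (range_number == 0) ? 1 : range_number;
  std::vector<int64_t> range_ceils;
  int64_t range_ceil = min_size;
  for (size_t i = 0; i < range_number; ++i) {
    range_ceil *= kRangeCeilInterval;
    range_ceils.push_back(range_ceil);
  }
  return range_ceils;
}

For sizes {512, 4096, 100000}, min_size is 512, range_number is ceil(log2(100000 / 512)) = 8, and the ceilings are 1024, 2048, ..., 131072.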
| @@ -65,6 +65,98 @@ void AlignMemOffset(size_t &mem_align_size) { | |||||
| mem_align_size = (mem_align_size + MEM_ALIGN_SIZE - 1) / MEM_ALIGN_SIZE * MEM_ALIGN_SIZE; | mem_align_size = (mem_align_size + MEM_ALIGN_SIZE - 1) / MEM_ALIGN_SIZE * MEM_ALIGN_SIZE; | ||||
| } | } | ||||
| static bool CompareLifeTime(const NodeTypeIndex &left, const NodeTypeIndex &right) { | |||||
| auto left_node_op_desc = left.node->GetOpDesc(); | |||||
| auto right_node_op_desc = right.node->GetOpDesc(); | |||||
| if ((left_node_op_desc != nullptr) && (right_node_op_desc != nullptr) | |||||
| && (left_node_op_desc->GetId() < right_node_op_desc->GetId())) { | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| void GetLifeList(const MemoryBlock &block, std::vector<NodeTypeIndex> &life_list, bool child) { | |||||
| for (auto &node : block.NodeTypeIndexList()) { | |||||
| life_list.emplace_back(node); | |||||
| } | |||||
| if (child) { | |||||
| for (auto child_block : block.ChildBlockList()) { | |||||
| if (child_block == nullptr) { | |||||
| continue; | |||||
| } | |||||
| if (block.stream_id_ != child_block->stream_id_ || !block.same_stream_ || !child_block->same_stream_) { | |||||
| life_list.clear(); | |||||
| return; | |||||
| } | |||||
| GetLifeList(*child_block, life_list, child); | |||||
| } | |||||
| } | |||||
| } | |||||
| bool CrossLifeTime(const NodeTypeIndex &left, const NodeTypeIndex &right) { | |||||
| if ((left.node == nullptr) || (right.node == nullptr)) { | |||||
| return true; | |||||
| } | |||||
| auto left_node_op_desc = left.node->GetOpDesc(); | |||||
| auto right_node_op_desc = right.node->GetOpDesc(); | |||||
| if ((left_node_op_desc != nullptr) && (right_node_op_desc != nullptr)) { | |||||
| if (left_node_op_desc->GetId() < right_node_op_desc->GetId()) { | |||||
| if (left.life_time_end >= static_cast<size_t>(right_node_op_desc->GetId())) { | |||||
| return true; | |||||
| } | |||||
| } else if (left_node_op_desc->GetId() == right_node_op_desc->GetId()) { | |||||
| return true; | |||||
| } else { | |||||
| if (right.life_time_end >= static_cast<size_t>(left_node_op_desc->GetId())) { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| /// | |||||
| /// When a child block's life time does not cross the parent block's, they can be reused (same stream only). | |||||
| /// |-----------------------------parent block---------------------| | |||||
| /// |------child block1--------------||------child block2------| | |||||
| /// |--child block1-1-| | |||||
| /// | |||||
| bool CanIntervalLifeReuse(MemoryBlock &parent_block, MemoryBlock &child_block) { | |||||
| // Judge by interval life time; only blocks on the same stream can be judged this way | |||||
| if (parent_block.stream_id_ != child_block.stream_id_ || !parent_block.same_stream_ || !child_block.same_stream_ | |||||
| || parent_block.NodeTypeIndexList().empty() || child_block.NodeTypeIndexList().empty()) { | |||||
| return false; | |||||
| } | |||||
| // quick judge by front and back node | |||||
| if (CrossLifeTime(parent_block.NodeTypeIndexList().front(), child_block.NodeTypeIndexList().front())) { | |||||
| return false; | |||||
| } | |||||
| if (CrossLifeTime(parent_block.NodeTypeIndexList().back(), child_block.NodeTypeIndexList().back())) { | |||||
| return false; | |||||
| } | |||||
| std::vector<NodeTypeIndex> life_list; | |||||
| GetLifeList(parent_block, life_list, false); | |||||
| GetLifeList(child_block, life_list, true); | |||||
| if (life_list.empty()) { | |||||
| return false; | |||||
| } | |||||
| std::sort(life_list.begin(), life_list.end(), CompareLifeTime); | |||||
| size_t pre_life_end = 0; | |||||
| for (auto &node : life_list) { | |||||
| auto node_op_desc = node.node->GetOpDesc(); | |||||
| if (node_op_desc != nullptr && pre_life_end >= static_cast<size_t>(node_op_desc->GetId())) { | |||||
| // life time cross | |||||
| return false; | |||||
| } | |||||
| pre_life_end = node.life_time_end; | |||||
| } | |||||
| GELOGI("Block size[%zu, %zu] life time are not cross.", parent_block.Size(), child_block.Size()); | |||||
| return true; | |||||
| } | |||||
| void MemoryBlock::SetHeadOffset(size_t offset) { | void MemoryBlock::SetHeadOffset(size_t offset) { | ||||
| head_offset_ = offset; | head_offset_ = offset; | ||||
| size_t child_offset = head_offset_; | size_t child_offset = head_offset_; | ||||
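CanIntervalLifeReuse above merges the parent's and child's node life times, sorts them by node id, and rejects reuse if any two intervals cross. A minimal sketch of that overlap test (hypothetical standalone types, not the real NodeTypeIndex/MemoryBlock structures):

#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a tensor's life time, expressed in topological node ids.
struct LifeInterval {
  size_t begin;
  size_t end;
};

// Blocks on the same stream may share memory only if no two life intervals overlap.
bool IntervalsCanShare(std::vector<LifeInterval> lives) {
  std::sort(lives.begin(), lives.end(),
            [](const LifeInterval &l, const LifeInterval &r) { return l.begin < r.begin; });
  if (lives.empty()) {
    return false;
  }
  size_t max_end_so_far = lives.front().end;
  for (size_t i = 1; i < lives.size(); ++i) {
    if (lives[i].begin <= max_end_so_far) {
      return false;  // life times cross, reuse would corrupt data
    }
    max_end_so_far = std::max(max_end_so_far, lives[i].end);
  }
  return true;
}

For example, intervals {0, 3} and {5, 8} can share the block, while {0, 6} and {5, 8} cannot.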
| @@ -125,20 +217,12 @@ size_t MemoryBlock::AlignSize() const { | |||||
| return align_block_size; | return align_block_size; | ||||
| } | } | ||||
| bool MemoryBlock::IsSameLabel(std::string &first_batch_label) { | |||||
| if (node_type_index_list_.empty()) { | |||||
| bool MemoryBlock::IsSameBatchLabel() { | |||||
| // only same batch label can reuse | |||||
| if (batch_label_.empty() || node_type_index_list_.empty()) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| auto node_op_desc = node_type_index_list_[0].node->GetOpDesc(); | |||||
| if (node_op_desc == nullptr) { | |||||
| return false; | |||||
| } | |||||
| // not all op has ATTR_NAME_BATCH_LABEL, no need check return value, only check out parameter | |||||
| (void)ge::AttrUtils::GetStr(node_op_desc, ATTR_NAME_BATCH_LABEL, first_batch_label); | |||||
| if (first_batch_label.empty()) { | |||||
| return false; | |||||
| } | |||||
| bool all_same_label = true; | bool all_same_label = true; | ||||
| for (size_t index = 1; index < node_type_index_list_.size(); ++index) { | for (size_t index = 1; index < node_type_index_list_.size(); ++index) { | ||||
| if (node_type_index_list_[index].node == nullptr) { | if (node_type_index_list_[index].node == nullptr) { | ||||
| @@ -147,8 +231,9 @@ bool MemoryBlock::IsSameLabel(std::string &first_batch_label) { | |||||
| std::string batch_label; | std::string batch_label; | ||||
| auto index_op_desc = node_type_index_list_[index].node->GetOpDesc(); | auto index_op_desc = node_type_index_list_[index].node->GetOpDesc(); | ||||
| GE_IF_BOOL_EXEC(index_op_desc == nullptr, continue); | GE_IF_BOOL_EXEC(index_op_desc == nullptr, continue); | ||||
| // Not every op has ATTR_NAME_BATCH_LABEL; no need to check the return value, only the output parameter | |||||
| (void)ge::AttrUtils::GetStr(index_op_desc, ATTR_NAME_BATCH_LABEL, batch_label); | (void)ge::AttrUtils::GetStr(index_op_desc, ATTR_NAME_BATCH_LABEL, batch_label); | ||||
| if (first_batch_label != batch_label) { | |||||
| if (batch_label_ != batch_label) { | |||||
| all_same_label = false; | all_same_label = false; | ||||
| break; | break; | ||||
| } | } | ||||
| @@ -197,7 +282,7 @@ void MemoryBlock::AddContinuousLifeReuseBlock(MemoryBlock *block, DependStreamLi | |||||
| } | } | ||||
| void MemoryBlock::AddLifeReuseBlock(MemoryBlock *block, DependStreamLife &total_node_depend_stream_life) { | void MemoryBlock::AddLifeReuseBlock(MemoryBlock *block, DependStreamLife &total_node_depend_stream_life) { | ||||
| if (CanNotLifeReuse(this) || CanNotLifeReuse(block)) { | |||||
| if (CanNotLifeReuse(this) || CanNotLifeReuse(block) || (batch_label_ != block->batch_label_)) { | |||||
| return; | return; | ||||
| } | } | ||||
| if (block->continuous_block_) { | if (block->continuous_block_) { | ||||
| @@ -207,16 +292,27 @@ void MemoryBlock::AddLifeReuseBlock(MemoryBlock *block, DependStreamLife &total_ | |||||
| MemoryBlock *parent = nullptr; | MemoryBlock *parent = nullptr; | ||||
| MemoryBlock *child = nullptr; | MemoryBlock *child = nullptr; | ||||
| // merge small block to large block | // merge small block to large block | ||||
| if (block->GetDependLifeBegin(stream_id_, total_node_depend_stream_life) > GetLifeEnd()) { | |||||
| if ((child_offset_ + block->AlignSize()) <= AlignSize()) { | |||||
| parent = this; | |||||
| child = block; | |||||
| } else if ((block->child_offset_ + AlignSize()) <= block->AlignSize()) { | |||||
| parent = block; | |||||
| child = this; | |||||
| // No-align sizes: 802816 + 802816 = 1605632, can reuse | |||||
| // After 32 align: 802848 + 802848 > 1605664, can't reuse | |||||
| // After 512 align: 803328 + 803328 > 1606144, can't reuse | |||||
| // So allow one extra MEM_ALIGN_SIZE: 803328 + 803328 = 1606144 + 512, and reuse is possible | |||||
| if ((child_offset_ + block->AlignSize()) <= (AlignSize() + MEM_ALIGN_SIZE)) { | |||||
| parent = this; | |||||
| child = block; | |||||
| } else if ((block->child_offset_ + AlignSize()) <= (block->AlignSize() + MEM_ALIGN_SIZE)) { | |||||
| parent = block; | |||||
| child = this; | |||||
| } | |||||
| if ((parent != nullptr) && (child != nullptr)) { | |||||
| // Different streams must use stream dependency to judge the life cycle | |||||
| // For the same stream, if the block has child blocks, all their life times can be judged in CanIntervalLifeReuse | |||||
| bool can_block_life_reuse = (child->child_blocks_.empty() | |||||
| && (block->GetDependLifeBegin(stream_id_, total_node_depend_stream_life) > GetLifeEnd())); | |||||
| if (!can_block_life_reuse && !CanIntervalLifeReuse(*parent, *child)) { | |||||
| return; | |||||
| } | } | ||||
| } | |||||
| if ((parent != nullptr) && (child != nullptr) && child->child_blocks_.empty()) { | |||||
| parent->child_blocks_.emplace_back(child); | parent->child_blocks_.emplace_back(child); | ||||
| parent->child_offset_ += child->AlignSize(); | parent->child_offset_ += child->AlignSize(); | ||||
| child->deleted_block_ = true; | child->deleted_block_ = true; | ||||
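The MEM_ALIGN_SIZE slack added above compensates for per-block alignment. A small worked sketch of the comment's numbers (assuming, as those numbers suggest, that a block size is the raw size plus a 32-byte pad rounded up to 512 bytes; this sizing rule is inferred, not taken from the patch):

#include <cstddef>
#include <iostream>

// Assumed sizing rule reconstructed from the 802816 / 803328 / 1606144 example above.
size_t AssumedBlockSize(size_t raw_size) {
  const size_t kPad = 32;
  const size_t kMemAlignSize = 512;
  return (raw_size + kPad + kMemAlignSize - 1) / kMemAlignSize * kMemAlignSize;
}

int main() {
  size_t child = AssumedBlockSize(802816);       // 803328
  size_t parent = AssumedBlockSize(2 * 802816);  // 1606144
  // Two raw children fit exactly in the raw parent, but after alignment
  // 2 * 803328 = 1606656 exceeds 1606144; allowing one extra MEM_ALIGN_SIZE
  // (1606144 + 512 = 1606656) is what keeps this reuse legal.
  std::cout << 2 * child << " <= " << parent + 512 << std::endl;
  return 0;
}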
| @@ -261,6 +357,7 @@ size_t MemoryBlock::GetDependLifeBegin(int64_t stream_id, DependStreamLife &tota | |||||
| void AddDependLife(const ge::NodePtr &org_node, const ge::NodePtr &node, int64_t stream_id, | void AddDependLife(const ge::NodePtr &org_node, const ge::NodePtr &node, int64_t stream_id, | ||||
| std::map<int64_t, size_t> &depend_stream_life, DependStreamLife &total_node_depend_stream_life) { | std::map<int64_t, size_t> &depend_stream_life, DependStreamLife &total_node_depend_stream_life) { | ||||
| GE_CHECK_NOTNULL_EXEC(node, return); | GE_CHECK_NOTNULL_EXEC(node, return); | ||||
| GE_CHECK_NOTNULL_EXEC(org_node, return); | |||||
| auto node_desc = node->GetOpDesc(); | auto node_desc = node->GetOpDesc(); | ||||
| GE_CHECK_NOTNULL_EXEC(node_desc, return); | GE_CHECK_NOTNULL_EXEC(node_desc, return); | ||||
| auto node_id = node_desc->GetId(); | auto node_id = node_desc->GetId(); | ||||
| @@ -415,12 +512,60 @@ BlockMemAssigner::~BlockMemAssigner() { | |||||
| } | } | ||||
| } | } | ||||
| void GetMaxBatchAllMemorySize(std::map<std::string, vector<int64_t>> &batch_all_memory_size, | |||||
| std::map<std::string, int64_t> batch_total_size, vector<int64_t> &all_memory_size, | |||||
| std::string &max_batch_label) { | |||||
| // Use the max batch's memory sizes to build the reuse ranges | |||||
| int64_t max_batch_size = 0; | |||||
| for (const auto &it : batch_total_size) { | |||||
| GELOGI("Batch[%s] total memory size[%ld]", it.first.c_str(), it.second); | |||||
| // no batch label | |||||
| if (it.first.empty()) { | |||||
| continue; | |||||
| } | |||||
| if (it.second > max_batch_size) { | |||||
| max_batch_size = it.second; | |||||
| max_batch_label = it.first; | |||||
| } | |||||
| } | |||||
| GELOGI("Max batch[%s] total memory size[%ld]", max_batch_label.c_str(), max_batch_size); | |||||
| for (const auto &it : batch_all_memory_size) { | |||||
| if (it.first.empty() || (it.first == max_batch_label)) { | |||||
| all_memory_size.insert(all_memory_size.end(), it.second.begin(), it.second.end()); | |||||
| } | |||||
| } | |||||
| // all_memory_size can't be empty | |||||
| if (all_memory_size.empty()) { | |||||
| all_memory_size.emplace_back(MEM_ALIGN_SIZE); | |||||
| } | |||||
| sort(all_memory_size.begin(), all_memory_size.end()); | |||||
| GELOGD("All memory size: %s", ToString(all_memory_size).c_str()); | |||||
| for (auto iter = all_memory_size.begin(); iter != all_memory_size.end();) { | |||||
| if (*iter == 0) { | |||||
| iter = all_memory_size.erase(iter); | |||||
| } else { | |||||
| ++iter; | |||||
| } | |||||
| } | |||||
| } | |||||
| void BlockMemAssigner::GetOutAndWorkSpaceMem(vector<int64_t> &all_memory_size) { | void BlockMemAssigner::GetOutAndWorkSpaceMem(vector<int64_t> &all_memory_size) { | ||||
| vector<int64_t> temp; | vector<int64_t> temp; | ||||
| std::map<std::string, vector<int64_t>> batch_all_memory_size; | |||||
| std::map<std::string, int64_t> batch_total_size; | |||||
| for (const NodePtr &n : compute_graph_->GetAllNodes()) { | for (const NodePtr &n : compute_graph_->GetAllNodes()) { | ||||
| auto node_op_desc = n->GetOpDesc(); | auto node_op_desc = n->GetOpDesc(); | ||||
| GE_IF_BOOL_EXEC(node_op_desc == nullptr, continue); | GE_IF_BOOL_EXEC(node_op_desc == nullptr, continue); | ||||
| if (CheckIsZeroMemNodeType(node_op_desc->GetType())) { | |||||
| continue; | |||||
| } | |||||
| std::string batch_label; | |||||
| (void)ge::AttrUtils::GetStr(node_op_desc, ATTR_NAME_BATCH_LABEL, batch_label); | |||||
| if (node_op_desc->GetType() == ATOMICADDRCLEAN) { | if (node_op_desc->GetType() == ATOMICADDRCLEAN) { | ||||
| atomic_addr_clean_id_ = node_op_desc->GetId(); | atomic_addr_clean_id_ = node_op_desc->GetId(); | ||||
| } | } | ||||
| @@ -434,9 +579,14 @@ void BlockMemAssigner::GetOutAndWorkSpaceMem(vector<int64_t> &all_memory_size) { | |||||
| if (!reuse_input) { | if (!reuse_input) { | ||||
| int64_t size = 0; | int64_t size = 0; | ||||
| GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(output_desc, size) != SUCCESS, GELOGI("Get size failed")); | GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(output_desc, size) != SUCCESS, GELOGI("Get size failed")); | ||||
| if (anchor_to_symbol_.empty()) { | |||||
| all_memory_size.emplace_back(size); | |||||
| batch_all_memory_size[batch_label].emplace_back(size); | |||||
| if (batch_total_size.find(batch_label) == batch_total_size.end()) { | |||||
| batch_total_size[batch_label] = size; | |||||
| } else { | } else { | ||||
| batch_total_size[batch_label] += size; | |||||
| } | |||||
| if (!anchor_to_symbol_.empty()) { | |||||
| auto iter1 = anchor_to_symbol_.find(NodeIndexIO(n, out_anchor->GetIdx(), kOut).ToString()); | auto iter1 = anchor_to_symbol_.find(NodeIndexIO(n, out_anchor->GetIdx(), kOut).ToString()); | ||||
| if (iter1 == anchor_to_symbol_.end()) { | if (iter1 == anchor_to_symbol_.end()) { | ||||
| continue; | continue; | ||||
| @@ -452,23 +602,11 @@ void BlockMemAssigner::GetOutAndWorkSpaceMem(vector<int64_t> &all_memory_size) { | |||||
| } | } | ||||
| } | } | ||||
| temp.clear(); | temp.clear(); | ||||
| GetNodeWorkSpaceSize(n, temp); | |||||
| all_memory_size.insert(all_memory_size.end(), temp.begin(), temp.end()); | |||||
| } | |||||
| for (const auto &pair : symbol_size_) { | |||||
| all_memory_size.emplace_back(pair.second); | |||||
| } | |||||
| sort(all_memory_size.begin(), all_memory_size.end()); | |||||
| GELOGD("All memory size: %s", ToString(all_memory_size).c_str()); | |||||
| for (auto iter = all_memory_size.begin(); iter != all_memory_size.end();) { | |||||
| if (*iter == 0) { | |||||
| iter = all_memory_size.erase(iter); | |||||
| } else { | |||||
| ++iter; | |||||
| } | |||||
| GetNodeWorkSpaceSize(n, temp, batch_total_size[batch_label]); | |||||
| batch_all_memory_size[batch_label].insert(batch_all_memory_size[batch_label].end(), temp.begin(), temp.end()); | |||||
| } | } | ||||
| GELOGI("The last atomic_addr_clean node id: %ld", atomic_addr_clean_id_); | |||||
| GetMaxBatchAllMemorySize(batch_all_memory_size, batch_total_size, all_memory_size, max_batch_label_); | |||||
| InitReuseFlag(); | InitReuseFlag(); | ||||
| PrintSymbolMap(); | PrintSymbolMap(); | ||||
| } | } | ||||
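GetMaxBatchAllMemorySize above keeps only the sizes of unlabeled nodes plus the batch with the largest total, since the other batches will later reuse that batch's blocks. A simplified sketch of the selection step (illustrative names and signature, not the patch's exact interface):

#include <cstdint>
#include <map>
#include <string>
#include <vector>

// Returns the sizes that should feed the reuse ranges: unlabeled nodes plus the
// nodes of whichever batch label has the largest total memory footprint.
std::vector<int64_t> CollectSizesForRanges(
    const std::map<std::string, std::vector<int64_t>> &batch_sizes,
    const std::map<std::string, int64_t> &batch_totals) {
  std::string max_label;
  int64_t max_total = 0;
  for (const auto &it : batch_totals) {
    if (!it.first.empty() && it.second > max_total) {
      max_total = it.second;
      max_label = it.first;
    }
  }
  std::vector<int64_t> all_sizes;
  for (const auto &it : batch_sizes) {
    if (it.first.empty() || it.first == max_label) {
      all_sizes.insert(all_sizes.end(), it.second.begin(), it.second.end());
    }
  }
  return all_sizes;  // the caller still sorts and drops zero-sized entries
}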
| @@ -529,16 +667,6 @@ bool CanReuseBySize(const map<string, uint64_t> &reusable_block_counts, const Me | |||||
| bool can_reuse = false; | bool can_reuse = false; | ||||
| if (reusable_block.Size() == block_size) { | if (reusable_block.Size() == block_size) { | ||||
| can_reuse = true; | can_reuse = true; | ||||
| } else { | |||||
| string key = std::to_string(reusable_block.Size()); | |||||
| key += "_" + std::to_string(reusable_block.stream_id_); | |||||
| key += "_" + std::to_string(reusable_block.memory_type_); | |||||
| auto it = reusable_block_counts.find(key); | |||||
| GE_IF_BOOL_EXEC((it != reusable_block_counts.end() && (it->second > kReuseMaxCount)) && | |||||
| (reusable_block.Size() > block_size), | |||||
| can_reuse = true; | |||||
| GELOGD("Less size mem reuse, reuse block size:%zu, current block size:%zu", | |||||
| reusable_block.Size(), block_size);); | |||||
| } | } | ||||
| return can_reuse; | return can_reuse; | ||||
| } | } | ||||
| @@ -860,17 +988,26 @@ MemoryBlock *BlockMemAssigner::ApplyMemory(size_t block_size, size_t real_size, | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(n == nullptr, return nullptr, "Input parameter n is null."); | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(n == nullptr, return nullptr, "Input parameter n is null."); | ||||
| auto node_op_desc = n->GetOpDesc(); | auto node_op_desc = n->GetOpDesc(); | ||||
| GE_IF_BOOL_EXEC(node_op_desc == nullptr, return nullptr); | GE_IF_BOOL_EXEC(node_op_desc == nullptr, return nullptr); | ||||
| std::string batch_label; | |||||
| (void)ge::AttrUtils::GetStr(node_op_desc, ATTR_NAME_BATCH_LABEL, batch_label); | |||||
| if (batch_label.empty() || (batch_label == max_batch_label_)) { | |||||
| size_t align_size = real_size; | |||||
| AlignMemOffset(align_size); | |||||
| theory_memory_size_ += align_size; | |||||
| if (theory_memory_size_ > theory_min_memory_size_) { | |||||
| theory_min_memory_size_ = theory_memory_size_; | |||||
| } | |||||
| } | |||||
| bool is_reuse_memory = false; | bool is_reuse_memory = false; | ||||
| string ge_disable_reuse_mem_env = "0"; | |||||
| (void)ge::GetContext().GetOption(OPTION_EXEC_DISABLE_REUSED_MEMORY, ge_disable_reuse_mem_env); | |||||
| if (ge_disable_reuse_mem_env != "1") { | |||||
| if (ge_disable_reuse_mem_env_ != "1") { | |||||
| bool reuse_mem_flag = (mem_type == kOutput) ? IsPreReuse(n, out_index) : | bool reuse_mem_flag = (mem_type == kOutput) ? IsPreReuse(n, out_index) : | ||||
| !((workspace_reuse_flag.size() > out_index) && !workspace_reuse_flag[out_index]); | !((workspace_reuse_flag.size() > out_index) && !workspace_reuse_flag[out_index]); | ||||
| is_reuse_memory = !node_op_desc->HasAttr(kL2FusionDynamicConvergeOp) && | is_reuse_memory = !node_op_desc->HasAttr(kL2FusionDynamicConvergeOp) && | ||||
| !node_op_desc->HasAttr(kOpNoReuseMem) && reuse_mem_flag && is_op_reuse_mem; | !node_op_desc->HasAttr(kOpNoReuseMem) && reuse_mem_flag && is_op_reuse_mem; | ||||
| auto stream_id = node_op_desc->GetStreamId(); | |||||
| if (is_reuse_memory && !continuous && !reusable_blocks_[memory_type].empty()) { | |||||
| bool do_reuse = is_reuse_memory && !continuous && !reusable_blocks_[memory_type].empty(); | |||||
| if (do_reuse) { | |||||
| auto stream_id = node_op_desc->GetStreamId(); | |||||
| for (auto it = reusable_blocks_[memory_type][stream_id].rbegin(); | for (auto it = reusable_blocks_[memory_type][stream_id].rbegin(); | ||||
| it != reusable_blocks_[memory_type][stream_id].rend(); ++it) { | it != reusable_blocks_[memory_type][stream_id].rend(); ++it) { | ||||
| MemoryBlock *reusable_block = *it; | MemoryBlock *reusable_block = *it; | ||||
| @@ -879,15 +1016,7 @@ MemoryBlock *BlockMemAssigner::ApplyMemory(size_t block_size, size_t real_size, | |||||
| GELOGI("Unreusable block."); | GELOGI("Unreusable block."); | ||||
| continue; | continue; | ||||
| } | } | ||||
| std::string batch_label; | |||||
| if (reusable_block->IsSameLabel(batch_label)) { | |||||
| std::string op_label; | |||||
| (void)ge::AttrUtils::GetStr(node_op_desc, ATTR_NAME_BATCH_LABEL, op_label); | |||||
| if (batch_label != op_label) { | |||||
| GELOGI("label diff, op name %s", node_op_desc->GetName().c_str()); | |||||
| continue; | |||||
| } | |||||
| } | |||||
| GE_IF_BOOL_EXEC(reusable_block->batch_label_ != batch_label, continue); | |||||
| // A node can reuse blocks of the same stream and preorder streams | // A node can reuse blocks of the same stream and preorder streams | ||||
| if (CanReuseBySize(reusable_block_counts_, *reusable_block, block_size, real_size, continuous)) { | if (CanReuseBySize(reusable_block_counts_, *reusable_block, block_size, real_size, continuous)) { | ||||
| @@ -914,10 +1043,11 @@ MemoryBlock *BlockMemAssigner::ApplyMemory(size_t block_size, size_t real_size, | |||||
| // Data and netoutput need zero copy block | // Data and netoutput need zero copy block | ||||
| block->is_zero_copy_ = IsZeroCopyBlock(n, continuous); | block->is_zero_copy_ = IsZeroCopyBlock(n, continuous); | ||||
| block->Init(real_size, mem_type, n, out_index, no_align_size); | |||||
| block->Init(real_size, mem_type, n, out_index, no_align_size, node_op_desc->GetStreamId()); | |||||
| block->stream_id_ = node_op_desc->GetStreamId(); | block->stream_id_ = node_op_desc->GetStreamId(); | ||||
| block->ref_count_++; | block->ref_count_++; | ||||
| block->continuous_block_ = continuous; | block->continuous_block_ = continuous; | ||||
| block->batch_label_ = batch_label; | |||||
| if (mem_type == kOutput) { | if (mem_type == kOutput) { | ||||
| auto iter = anchor_to_symbol_.find(NodeIndexIO(n, out_index, kOut).ToString()); | auto iter = anchor_to_symbol_.find(NodeIndexIO(n, out_index, kOut).ToString()); | ||||
| if (iter != anchor_to_symbol_.end()) { | if (iter != anchor_to_symbol_.end()) { | ||||
| @@ -945,6 +1075,11 @@ MemoryBlock *BlockMemAssigner::ApplyContinuousMemory(const NodePtr &n, const vec | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| if (CheckIsZeroMemNodeType(n->GetType())) { | |||||
| zero_memory_list_.emplace_back(n, kOutput, index); | |||||
| continue; | |||||
| } | |||||
| int64_t size = 0; | int64_t size = 0; | ||||
| if (ge::TensorUtils::GetSize(*output_op_desc, size) != SUCCESS) { | if (ge::TensorUtils::GetSize(*output_op_desc, size) != SUCCESS) { | ||||
| GELOGI("Get size failed"); | GELOGI("Get size failed"); | ||||
| @@ -957,9 +1092,7 @@ MemoryBlock *BlockMemAssigner::ApplyContinuousMemory(const NodePtr &n, const vec | |||||
| // only apply total size in first block | // only apply total size in first block | ||||
| if (index != 0) { | if (index != 0) { | ||||
| zero_memory_list_.emplace_back(n, kOutput, index); | zero_memory_list_.emplace_back(n, kOutput, index); | ||||
| } | |||||
| if (index == 0) { | |||||
| } else { | |||||
| NodeIndexIO node_index_io(n, index, kOut); | NodeIndexIO node_index_io(n, index, kOut); | ||||
| auto iter = anchor_to_symbol_.find(node_index_io.ToString()); | auto iter = anchor_to_symbol_.find(node_index_io.ToString()); | ||||
| if (iter != anchor_to_symbol_.end()) { | if (iter != anchor_to_symbol_.end()) { | ||||
| @@ -972,6 +1105,10 @@ MemoryBlock *BlockMemAssigner::ApplyContinuousMemory(const NodePtr &n, const vec | |||||
| } | } | ||||
| } | } | ||||
| if (total_size == 0) { | |||||
| return nullptr; | |||||
| } | |||||
| auto block_size = GetBlockSize(total_size, ranges); | auto block_size = GetBlockSize(total_size, ranges); | ||||
| GELOGI("Node[%s] continuous out memory size[%ld] block size[%zu]", node_op_desc->GetName().c_str(), | GELOGI("Node[%s] continuous out memory size[%ld] block size[%zu]", node_op_desc->GetName().c_str(), | ||||
| total_size, block_size); | total_size, block_size); | ||||
| @@ -1119,15 +1256,28 @@ bool IsKnownSubgraphData(const NodePtr &node) { | |||||
| return node->GetOpDesc()->HasAttr(ATTR_NAME_PARENT_NODE_INDEX); | return node->GetOpDesc()->HasAttr(ATTR_NAME_PARENT_NODE_INDEX); | ||||
| } | } | ||||
| void BlockMemAssigner::ReleaseMemory(MemoryBlock *to_release, vector<MemoryBlock *> &reusable_memory) { | |||||
| void BlockMemAssigner::ReleaseMemory(MemoryBlock *to_release, vector<MemoryBlock *> &reusable_memory, | |||||
| bool same_stream) { | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(to_release == nullptr, return, "Input parameter to_release is null."); | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(to_release == nullptr, return, "Input parameter to_release is null."); | ||||
| GE_CHK_TRUE_EXEC_INFO(to_release->ref_count_ <= 0, return, "Release memory"); | GE_CHK_TRUE_EXEC_INFO(to_release->ref_count_ <= 0, return, "Release memory"); | ||||
| GE_CHK_TRUE_EXEC_INFO(!to_release->reuse_mem_, return, "doesn't reuse memory"); | GE_CHK_TRUE_EXEC_INFO(!to_release->reuse_mem_, return, "doesn't reuse memory"); | ||||
| --to_release->ref_count_; | --to_release->ref_count_; | ||||
| if (!same_stream) { | |||||
| to_release->same_stream_ = false; | |||||
| } | |||||
| if (to_release->ref_count_ == 0) { | if (to_release->ref_count_ == 0) { | ||||
| to_release->SetLifeTimeEnd(life_time_); | |||||
| reusable_memory.emplace_back(to_release); | |||||
| AddReusableBlockCount(*to_release, reusable_block_counts_); | |||||
| if (to_release->reuse_mem_ && !to_release->RealSizeList().empty()) { | |||||
| if (to_release->batch_label_.empty() || (to_release->batch_label_ == max_batch_label_)) { | |||||
| size_t align_size = to_release->RealSizeList().back(); | |||||
| AlignMemOffset(align_size); | |||||
| theory_memory_size_ -= align_size; | |||||
| } | |||||
| } | |||||
| if (to_release->same_stream_) { | |||||
| to_release->SetLifeTimeEnd(life_time_); | |||||
| reusable_memory.emplace_back(to_release); | |||||
| AddReusableBlockCount(*to_release, reusable_block_counts_); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -1167,10 +1317,9 @@ void BlockMemAssigner::ReleaseInputNodeOutMemory(const unordered_map<string, vec | |||||
| node_type_indexs.back().node->GetName().c_str()); | node_type_indexs.back().node->GetName().c_str()); | ||||
| if ((node_type_indexs.back().node == in_anchor->GetPeerOutAnchor()->GetOwnerNode()) && | if ((node_type_indexs.back().node == in_anchor->GetPeerOutAnchor()->GetOwnerNode()) && | ||||
| (node_type_indexs.back().index == static_cast<uint32_t>(in_anchor->GetPeerOutAnchor()->GetIdx())) && | |||||
| (node->GetOpDesc()->GetStreamId() == block->stream_id_)) { | |||||
| ReleaseMemory(block, reusable_memory); | |||||
| if (block->ref_count_ == 0) { | |||||
| (node_type_indexs.back().index == static_cast<uint32_t>(in_anchor->GetPeerOutAnchor()->GetIdx()))) { | |||||
| ReleaseMemory(block, reusable_memory, (node->GetOpDesc()->GetStreamId() == block->stream_id_)); | |||||
| if (block->ref_count_ == 0 && block->same_stream_) { | |||||
| SetLastUsedInputMemAttr(node, in_anchor->GetIdx()); | SetLastUsedInputMemAttr(node, in_anchor->GetIdx()); | ||||
| } | } | ||||
| } | } | ||||
| @@ -1267,7 +1416,8 @@ Status BlockMemAssigner::AssignOutputMemoryWithReuse(const NodePtr &node, vector | |||||
| bool no_need_assign_memory = ((size == 0) || CheckIsZeroMemNodeType(node->GetType())); | bool no_need_assign_memory = ((size == 0) || CheckIsZeroMemNodeType(node->GetType())); | ||||
| if (!no_need_assign_memory) { | if (!no_need_assign_memory) { | ||||
| out_node_set_continuous_input = | out_node_set_continuous_input = | ||||
| IsOutNodeSetContinuousInput(node, i, peer_name, peer_input_index, no_need_assign_memory, reset_zero_copy_flag); | |||||
| IsOutNodeSetContinuousInput(node, i, peer_name, peer_input_index, | |||||
| no_need_assign_memory, reset_zero_copy_flag); | |||||
| GE_IF_BOOL_EXEC(!no_need_assign_memory, | GE_IF_BOOL_EXEC(!no_need_assign_memory, | ||||
| no_need_assign_memory = IsAtomicOutputMemory(node, i, is_atomic, out_node_set_continuous_input);); | no_need_assign_memory = IsAtomicOutputMemory(node, i, is_atomic, out_node_set_continuous_input);); | ||||
| } | } | ||||
| @@ -1328,7 +1478,8 @@ void BlockMemAssigner::AssignMemoryWithReuse(vector<int64_t> &ranges) { | |||||
| iter->second[stream_id].clear(); | iter->second[stream_id].clear(); | ||||
| } | } | ||||
| vector<int64_t> temp; | vector<int64_t> temp; | ||||
| GetNodeWorkSpaceSize(n, temp); | |||||
| int64_t total_size = 0; | |||||
| GetNodeWorkSpaceSize(n, temp, total_size); | |||||
| vector<int64_t> workspace_bytes; | vector<int64_t> workspace_bytes; | ||||
| vector<int64_t> tvm_workspace_memory_type; | vector<int64_t> tvm_workspace_memory_type; | ||||
| bool has_tvm_workspace_mem_type_attr = | bool has_tvm_workspace_mem_type_attr = | ||||
| @@ -1349,7 +1500,7 @@ void BlockMemAssigner::AssignMemoryWithReuse(vector<int64_t> &ranges) { | |||||
| bool workspace_skip_flag = false; | bool workspace_skip_flag = false; | ||||
| if (has_tvm_workspace_mem_type_attr && tvm_workspace_memory_type[i] == RT_MEMORY_L1) { | if (has_tvm_workspace_mem_type_attr && tvm_workspace_memory_type[i] == RT_MEMORY_L1) { | ||||
| GELOGI( | GELOGI( | ||||
| "fusion: node[%s]workspace index[%zu] is not hbm type, add to zero_memory_list, workspace memory type [%ld]", | |||||
| "fusion:node[%s]workspace index[%zu] is not hbm type, add to zero_memory_list, workspace memory type [%ld]", | |||||
| node_op_desc->GetName().c_str(), i, tvm_workspace_memory_type[i]); | node_op_desc->GetName().c_str(), i, tvm_workspace_memory_type[i]); | ||||
| workspace_skip_flag = true; | workspace_skip_flag = true; | ||||
| } | } | ||||
| @@ -1380,9 +1531,7 @@ void BlockMemAssigner::AssignMemoryWithReuse(vector<int64_t> &ranges) { | |||||
| (void)mem_block; // Fix warning | (void)mem_block; // Fix warning | ||||
| } | } | ||||
| bool merge_dynamic_batch = false; | |||||
| GE_IF_BOOL_EXEC(!(ge_disable_reuse_mem_env_ == "1"), merge_dynamic_batch = MergeDynamicBatchBlocks()); | |||||
| GE_IF_BOOL_EXEC((!(ge_disable_reuse_mem_env_ == "1") && !merge_dynamic_batch), ReuseBlocksByLifeTime(ranges.size())); | |||||
| GE_IF_BOOL_EXEC(!(ge_disable_reuse_mem_env_ == "1"), ReuseBlocksByLifeTime(ranges.size())); | |||||
| AssignContinuousBlocks(); | AssignContinuousBlocks(); | ||||
| ResizeMemoryBlocks(); | ResizeMemoryBlocks(); | ||||
| @@ -1402,92 +1551,19 @@ void BlockMemAssigner::CheckWorkspaceReuse(const vector<bool> &workspace_reuse_f | |||||
| } | } | ||||
| } | } | ||||
| void BlockMemAssigner::GetNodeWorkSpaceSize(const NodePtr &node, vector<int64_t> &workspace_memory) { | |||||
| void BlockMemAssigner::GetNodeWorkSpaceSize(const NodePtr &node, vector<int64_t> &workspace_memory, | |||||
| int64_t &total_size) { | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(node->GetOpDesc() == nullptr, return, "Op desc is null."); | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(node->GetOpDesc() == nullptr, return, "Op desc is null."); | ||||
| vector<int64_t> workspace_byte_nums = node->GetOpDesc()->GetWorkspaceBytes(); | vector<int64_t> workspace_byte_nums = node->GetOpDesc()->GetWorkspaceBytes(); | ||||
| GELOGD("node[%s] size:%zu", node->GetOpDesc()->GetName().c_str(), workspace_byte_nums.size()); | GELOGD("node[%s] size:%zu", node->GetOpDesc()->GetName().c_str(), workspace_byte_nums.size()); | ||||
| for (int64_t byte_size : workspace_byte_nums) { | for (int64_t byte_size : workspace_byte_nums) { | ||||
| workspace_memory.emplace_back(byte_size); | workspace_memory.emplace_back(byte_size); | ||||
| total_size += byte_size; | |||||
| GELOGD("push back size:%ld", byte_size); | GELOGD("push back size:%ld", byte_size); | ||||
| } | } | ||||
| } | } | ||||
| // descending order | |||||
| static bool CompareBlockMaxSize(MemoryBlock *left, MemoryBlock *right) { | |||||
| if (left == nullptr || right == nullptr) { | |||||
| return false; | |||||
| } | |||||
| auto left_max_size = std::max_element(left->RealSizeList().begin(), left->RealSizeList().end()); | |||||
| if (left_max_size != left->RealSizeList().end()) { | |||||
| auto right_max_size = std::max_element(right->RealSizeList().begin(), right->RealSizeList().end()); | |||||
| if (right_max_size == right->RealSizeList().end() || (*left_max_size > *right_max_size)) { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| void MergeBlocks(std::vector<MemoryBlock *> &dest, std::vector<MemoryBlock *> &src) { | |||||
| for (size_t i = 0; i < dest.size(); ++i) { | |||||
| if (i >= src.size()) { | |||||
| return; | |||||
| } | |||||
| if (dest[i] != nullptr && src[i] != nullptr) { | |||||
| if (!dest[i]->reuse_mem_ || !src[i]->reuse_mem_) { | |||||
| GELOGD("Diff batch's workspace can't be reused, i: %zu, dest[i]: %s, stream: %ld, src[i]: %s, stream: %ld.", | |||||
| i, dest[i]->String().c_str(), dest[i]->stream_id_, src[i]->String().c_str(), src[i]->stream_id_); | |||||
| continue; | |||||
| } | |||||
| for (auto &symbol : src[i]->SymbolList()) { | |||||
| dest[i]->AddSymbol(symbol); | |||||
| } | |||||
| for (size_t j = 0; j < src[i]->NodeTypeIndexList().size(); ++j) { | |||||
| dest[i]->AddNodeTypeIndex(src[i]->NodeTypeIndexList()[j], | |||||
| src[i]->RealSizeList()[j], | |||||
| src[i]->NoAlignSizeList()[j]); | |||||
| src[i]->deleted_block_ = true; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| bool BlockMemAssigner::MergeDynamicBatchBlocks() { | |||||
| bool merged = false; | |||||
| std::map<std::string, std::vector<MemoryBlock *>> dynamic_batch_blocks; | |||||
| for (auto block : memory_blocks_) { | |||||
| if (block == nullptr) { | |||||
| continue; | |||||
| } | |||||
| std::string batch_label; | |||||
| if (block->IsSameLabel(batch_label)) { | |||||
| dynamic_batch_blocks[batch_label].emplace_back(block); | |||||
| } | |||||
| } | |||||
| auto it = dynamic_batch_blocks.begin(); | |||||
| auto it_max = it; | |||||
| // find max block counts | |||||
| for (; it != dynamic_batch_blocks.end(); ++it) { | |||||
| if (it->second.size() > it_max->second.size()) { | |||||
| it_max = it; | |||||
| } | |||||
| std::sort(it->second.begin(), it->second.end(), CompareBlockMaxSize); | |||||
| } | |||||
| if (it_max != dynamic_batch_blocks.end()) { | |||||
| GELOGD("MergeDynamicBatch %s block counts %zu", it_max->first.c_str(), it_max->second.size()); | |||||
| } | |||||
| for (it = dynamic_batch_blocks.begin(); it != dynamic_batch_blocks.end(); ++it) { | |||||
| if (it != it_max) { | |||||
| GELOGD("MergeDynamicBatch from %s to %s", it->first.c_str(), it_max->first.c_str()); | |||||
| MergeBlocks(it_max->second, it->second); | |||||
| merged = true; | |||||
| } | |||||
| } | |||||
| return merged; | |||||
| } | |||||
| // asending order | // asending order | ||||
| static bool CompareBlockIndex(MemoryBlock *left, MemoryBlock *right) { | static bool CompareBlockIndex(MemoryBlock *left, MemoryBlock *right) { | ||||
| if (left == nullptr || right == nullptr) { | if (left == nullptr || right == nullptr) { | ||||
| @@ -1597,38 +1673,93 @@ void BlockMemAssigner::ReuseBlocksByLifeTime(size_t range_size) { | |||||
| } | } | ||||
| } | } | ||||
| void AddBlockMemOffset(size_t &mem_offset, size_t &p2p_mem_offset, MemoryBlock &block) { | |||||
| if (block.memory_type_ == RT_MEMORY_HBM) { | |||||
| if (block.first_continuous_block_) { | |||||
| mem_offset += MEM_ALIGN_SIZE; | |||||
| } | |||||
| block.Resize(); | |||||
| block.SetHeadOffset(mem_offset); | |||||
| mem_offset += block.Size(); | |||||
| block.SetTailOffset(mem_offset - 1); | |||||
| } else if (block.memory_type_ == RT_MEMORY_P2P_DDR) { | |||||
| if (block.first_continuous_block_) { | |||||
| p2p_mem_offset += MEM_ALIGN_SIZE; | |||||
| } | |||||
| block.Resize(); | |||||
| block.SetHeadOffset(p2p_mem_offset); | |||||
| p2p_mem_offset += block.Size(); | |||||
| block.SetTailOffset(p2p_mem_offset - 1); | |||||
| } | |||||
| } | |||||
| bool DynamicBatchBlockReuse(MemoryBlock &block) { | |||||
| return (block.IsSameBatchLabel() && block.reuse_mem_); | |||||
| } | |||||
| /// | /// | ||||
| /// @ingroup domi_omg | /// @ingroup domi_omg | ||||
| /// @brief traverse memory size, resize, calculate offset | |||||
| /// @brief get the max batch's memory size; other batches reuse this block memory | |||||
| /// @param [in&out] memory_blocks_ memory block, after calculating offset | /// @param [in&out] memory_blocks_ memory block, after calculating offset | ||||
| /// |-dynamic batch block batch1| | |||||
| /// |-dynamic batch block batch2----| | |||||
| /// |-dynamic batch block batch3--| | |||||
| /// | /// | ||||
| void BlockMemAssigner::ResizeMemoryBlocks() { | |||||
| for (auto &memory_block : memory_blocks_) { | |||||
| if (memory_block == nullptr || memory_block->deleted_block_ || memory_block->is_zero_copy_) { | |||||
| void BlockMemAssigner::ResizeDynamicBatchBlocks() { | |||||
| std::map<std::string, std::vector<MemoryBlock *>> dynamic_batch_blocks; | |||||
| for (auto block : memory_blocks_) { | |||||
| if (block == nullptr) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| if (memory_block->memory_type_ == RT_MEMORY_HBM) { | |||||
| if (memory_block->first_continuous_block_) { | |||||
| mem_offset_ += MEM_ALIGN_SIZE; | |||||
| } | |||||
| // When memory is not reusable, it can't be reused by a different branch | |||||
| if (DynamicBatchBlockReuse(*block)) { | |||||
| dynamic_batch_blocks[block->batch_label_].emplace_back(block); | |||||
| } | |||||
| } | |||||
| memory_block->Resize(); | |||||
| memory_block->SetHeadOffset(mem_offset_); | |||||
| mem_offset_ += memory_block->Size(); | |||||
| memory_block->SetTailOffset(mem_offset_ - 1); | |||||
| } else if (memory_block->memory_type_ == RT_MEMORY_P2P_DDR) { | |||||
| if (memory_block->first_continuous_block_) { | |||||
| p2p_mem_offset_ += MEM_ALIGN_SIZE; | |||||
| size_t max_mem_offset = mem_offset_; | |||||
| size_t max_p2p_mem_offset = p2p_mem_offset_; | |||||
| for (auto &batch_blocks : dynamic_batch_blocks) { | |||||
| size_t mem_offset = mem_offset_; | |||||
| size_t p2p_mem_offset = p2p_mem_offset_; | |||||
| for (auto block : batch_blocks.second) { | |||||
| if (block == nullptr || block->deleted_block_ || block->is_zero_copy_) { | |||||
| continue; | |||||
| } | } | ||||
| AddBlockMemOffset(mem_offset, p2p_mem_offset, *block); | |||||
| } | |||||
| if (mem_offset > max_mem_offset) { | |||||
| max_mem_offset = mem_offset; | |||||
| } | |||||
| if (p2p_mem_offset > max_p2p_mem_offset) { | |||||
| max_p2p_mem_offset = p2p_mem_offset; | |||||
| } | |||||
| GELOGI("Batch[%s] offset[%zu] p2p_offset[%zu]", batch_blocks.first.c_str(), mem_offset, p2p_mem_offset); | |||||
| } | |||||
| mem_offset_ = max_mem_offset; | |||||
| p2p_mem_offset_ = max_p2p_mem_offset; | |||||
| } | |||||
| memory_block->Resize(); | |||||
| memory_block->SetHeadOffset(p2p_mem_offset_); | |||||
| p2p_mem_offset_ += memory_block->Size(); | |||||
| memory_block->SetTailOffset(p2p_mem_offset_ - 1); | |||||
| /// | |||||
| /// @ingroup domi_omg | |||||
| /// @brief traverse memory size, resize, calculate offset | |||||
| /// @param [in&out] memory_blocks_ memory block, after calculating offset | |||||
| /// |-not dynamic batch block-||-dynamic batch block batch1| |-zero copy block-| | |||||
| /// |-not dynamic batch block-||-dynamic batch block batch2----||-zero copy block-| | |||||
| /// |-not dynamic batch block-||-dynamic batch block batch3--| |-zero copy block-| | |||||
| /// | |||||
| void BlockMemAssigner::ResizeMemoryBlocks() { | |||||
| for (auto &memory_block : memory_blocks_) { | |||||
| if (memory_block == nullptr || memory_block->deleted_block_ || memory_block->is_zero_copy_ | |||||
| || DynamicBatchBlockReuse(*memory_block)) { | |||||
| continue; | |||||
| } | } | ||||
| AddBlockMemOffset(mem_offset_, p2p_mem_offset_, *memory_block); | |||||
| } | } | ||||
| GELOGD("mem_offset_ exclude zero_copy_memory is %zu, p2p_mem_offset_ exclude zero_copy_memory is %zu.", | |||||
| mem_offset_, p2p_mem_offset_); | |||||
| ResizeDynamicBatchBlocks(); | |||||
| GELOGI("mem_offset_ exclude zero_copy_memory is %zu, p2p_mem_offset_ exclude zero_copy_memory is %zu," | |||||
| "theory_min_memory_size %zu", mem_offset_, p2p_mem_offset_, theory_min_memory_size_); | |||||
| } | } | ||||
| /// | /// | ||||
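ResizeDynamicBatchBlocks above lays every batch's reusable blocks out from the same base offset and keeps the maximum end offset, which is how the smaller batches end up reusing the largest batch's memory. A compact sketch of that layout rule (hypothetical sizes and names):

#include <algorithm>
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Each batch restarts at the shared base offset; the reserved region is the max
// end offset over all batches, so batch1/batch2/batch3 overlap as in the diagram.
size_t LayoutDynamicBatches(size_t base_offset,
                            const std::map<std::string, std::vector<size_t>> &batch_block_sizes) {
  size_t max_offset = base_offset;
  for (const auto &batch : batch_block_sizes) {
    size_t offset = base_offset;
    for (size_t block_size : batch.second) {
      offset += block_size;
    }
    max_offset = std::max(max_offset, offset);
  }
  return max_offset;
}

With per-batch totals of 100, 150, and 120 (arbitrary units), the region reserved after the non-dynamic blocks is 150.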
| @@ -1641,7 +1772,7 @@ void BlockMemAssigner::ResizeMemoryBlocks() { | |||||
| /// @return Status result | /// @return Status result | ||||
| /// | /// | ||||
| void SetOffsetSize(const NodeTypeIndex &node_type, const MemoryBlock *block, | void SetOffsetSize(const NodeTypeIndex &node_type, const MemoryBlock *block, | ||||
| size_t real_size, size_t no_align_size, bool child_block) { | |||||
| size_t real_size, size_t no_align_size, int32_t child_block_level) { | |||||
| ge::OpDescPtr op_desc = node_type.node->GetOpDesc(); | ge::OpDescPtr op_desc = node_type.node->GetOpDesc(); | ||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(op_desc == nullptr, return, "op_desc is null."); | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(op_desc == nullptr, return, "op_desc is null."); | ||||
| string graph_name = node_type.node->GetOwnerComputeGraph()->GetName(); | string graph_name = node_type.node->GetOwnerComputeGraph()->GetName(); | ||||
| @@ -1689,14 +1820,15 @@ void SetOffsetSize(const NodeTypeIndex &node_type, const MemoryBlock *block, | |||||
| } | } | ||||
| op_desc->SetWorkspace(workspace_list); | op_desc->SetWorkspace(workspace_list); | ||||
| } | } | ||||
| GELOGI("[IMAS]Set %s name[%s] %s[%u] offset to [%ld] streamid[%ld] size[%zu] realsize[%zu]" | |||||
| " noalignsize[%zu] life time begin[%zu] life time end[%zu] child[%d:%d:%d:%d] isref[%d].", graph_name.c_str(), | |||||
| GELOGI("[IMAS]Set %s name[%s] %s[%u] offset to [%ld] streamid[%ld] size[%zu] realsize[%zu] noalignsize[%zu] " | |||||
| "life time begin[%zu] life time end[%zu] child[%d:%d:%d:%d:%d] isref[%d] batch[%s]", graph_name.c_str(), | |||||
| op_desc->GetName().c_str(), node_type.GetMemType().c_str(), node_type.index, offset, op_desc->GetStreamId(), | op_desc->GetName().c_str(), node_type.GetMemType().c_str(), node_type.index, offset, op_desc->GetStreamId(), | ||||
| block->Size(), real_size, no_align_size, op_desc->GetId(), end, child_block, block->reuse_mem_, | |||||
| block->continuous_block_, block->deleted_block_, node_type.ref_input); | |||||
| block->Size(), real_size, no_align_size, op_desc->GetId(), end, child_block_level, block->reuse_mem_, | |||||
| block->continuous_block_, block->is_zero_copy_, block->same_stream_, node_type.ref_input, | |||||
| block->batch_label_.c_str()); | |||||
| } | } | ||||
| void SetBlockOpMemOffset(MemoryBlock *block, bool child_block) { | |||||
| void SetBlockOpMemOffset(MemoryBlock *block, int32_t child_block_level) { | |||||
| if (block == nullptr) { | if (block == nullptr) { | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -1709,9 +1841,14 @@ void SetBlockOpMemOffset(MemoryBlock *block, bool child_block) { | |||||
| real_size = block->RealSizeList()[index]; | real_size = block->RealSizeList()[index]; | ||||
| no_align_size = block->NoAlignSizeList()[index]; | no_align_size = block->NoAlignSizeList()[index]; | ||||
| } | } | ||||
| SetOffsetSize(node_type_index, block, real_size, no_align_size, child_block); | |||||
| SetOffsetSize(node_type_index, block, real_size, no_align_size, child_block_level); | |||||
| index++; | index++; | ||||
| } | } | ||||
| child_block_level++; | |||||
| for (MemoryBlock *child_block : block->ChildBlockList()) { | |||||
| SetBlockOpMemOffset(child_block, child_block_level); | |||||
| } | |||||
| } | } | ||||
| void BlockMemAssigner::SetOpMemOffset(bool is_zero_copy) { | void BlockMemAssigner::SetOpMemOffset(bool is_zero_copy) { | ||||
| @@ -1724,16 +1861,13 @@ void BlockMemAssigner::SetOpMemOffset(bool is_zero_copy) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| SetBlockOpMemOffset(memory_block, false); | |||||
| for (MemoryBlock *child_block : memory_block->ChildBlockList()) { | |||||
| SetBlockOpMemOffset(child_block, true); | |||||
| } | |||||
| SetBlockOpMemOffset(memory_block, 0); | |||||
| } | } | ||||
| if (!is_zero_copy) { | if (!is_zero_copy) { | ||||
| for (const NodeTypeIndex &node_type_index : zero_memory_list_) { | for (const NodeTypeIndex &node_type_index : zero_memory_list_) { | ||||
| MemoryBlock block(0, 0); | MemoryBlock block(0, 0); | ||||
| SetOffsetSize(node_type_index, &block, 0, 0, false); | |||||
| SetOffsetSize(node_type_index, &block, 0, 0, 0); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -65,6 +65,7 @@ class MemoryBlock { | |||||
| stream_id_(stream_id), | stream_id_(stream_id), | ||||
| deleted_block_(false), | deleted_block_(false), | ||||
| reuse_mem_(reuse_mem), | reuse_mem_(reuse_mem), | ||||
| same_stream_(true), | |||||
| input_index_(0), | input_index_(0), | ||||
| continuous_block_(false), | continuous_block_(false), | ||||
| first_continuous_block_(false), | first_continuous_block_(false), | ||||
| @@ -85,10 +86,14 @@ class MemoryBlock { | |||||
| symbol_list_.clear(); | symbol_list_.clear(); | ||||
| } | } | ||||
| void Init(size_t real_size, OpMemoryType type, const ge::NodePtr &node, uint32_t out_index, size_t no_align_size) { | |||||
| void Init(size_t real_size, OpMemoryType type, const ge::NodePtr &node, uint32_t out_index, size_t no_align_size, | |||||
| int64_t stream_id) { | |||||
| real_size_list_.emplace_back(real_size); | real_size_list_.emplace_back(real_size); | ||||
| no_align_size_list_.emplace_back(no_align_size); | no_align_size_list_.emplace_back(no_align_size); | ||||
| node_type_index_list_.emplace_back(node, type, out_index, false); | node_type_index_list_.emplace_back(node, type, out_index, false); | ||||
| if (stream_id != stream_id_) { | |||||
| same_stream_ = false; | |||||
| } | |||||
| } | } | ||||
| size_t Size() const { return block_size_; } | size_t Size() const { return block_size_; } | ||||
| @@ -106,6 +111,12 @@ class MemoryBlock { | |||||
| node_type_index_list_.emplace_back(node_type_index); | node_type_index_list_.emplace_back(node_type_index); | ||||
| real_size_list_.emplace_back(real_size); | real_size_list_.emplace_back(real_size); | ||||
| no_align_size_list_.emplace_back(no_align_size); | no_align_size_list_.emplace_back(no_align_size); | ||||
| if ((node_type_index.node != nullptr) && (node_type_index.node->GetOpDesc() != nullptr)) { | |||||
| auto stream_id = node_type_index.node->GetOpDesc()->GetStreamId(); | |||||
| if (stream_id != stream_id_) { | |||||
| same_stream_ = false; | |||||
| } | |||||
| } | |||||
| } | } | ||||
| void AddSymbol(const std::string &symbol) { | void AddSymbol(const std::string &symbol) { | ||||
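Both `Init` and `AddNodeTypeIndex` now compare the incoming node's stream against the block's own `stream_id_` and clear `same_stream_` on the first mismatch; the flag then feeds the new `same_stream` argument of `ReleaseMemory` and the extended `child[...]` fields of the [IMAS] log. A small self-contained sketch of that bookkeeping (assumed member names, not the real class):

```cpp
#include <cstdint>
#include <vector>

// Minimal sketch of the same_stream_ bookkeeping added above.
class BlockSketch {
 public:
  explicit BlockSketch(int64_t stream_id) : stream_id_(stream_id) {}

  // Called whenever a node is placed into the block; mirrors Init / AddNodeTypeIndex.
  void AddNodeOnStream(int64_t node_stream_id) {
    node_streams_.push_back(node_stream_id);
    if (node_stream_id != stream_id_) {
      same_stream_ = false;  // one off-stream user is enough to clear the flag
    }
  }

  bool SameStream() const { return same_stream_; }

 private:
  int64_t stream_id_;
  bool same_stream_ = true;
  std::vector<int64_t> node_streams_;
};
```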
| @@ -122,7 +133,7 @@ class MemoryBlock { | |||||
| std::string String(); | std::string String(); | ||||
| bool IsSameLabel(std::string &first_batch_label); | |||||
| bool IsSameBatchLabel(); | |||||
| void AddContinuousLifeReuseBlock(MemoryBlock *block, DependStreamLife &total_node_depend_stream_life); | void AddContinuousLifeReuseBlock(MemoryBlock *block, DependStreamLife &total_node_depend_stream_life); | ||||
| @@ -142,6 +153,7 @@ class MemoryBlock { | |||||
| int64_t stream_id_; | int64_t stream_id_; | ||||
| bool deleted_block_; | bool deleted_block_; | ||||
| bool reuse_mem_; | bool reuse_mem_; | ||||
| bool same_stream_; | |||||
| uint32_t input_index_; | uint32_t input_index_; | ||||
| bool continuous_block_; | bool continuous_block_; | ||||
| bool first_continuous_block_; | bool first_continuous_block_; | ||||
| @@ -149,6 +161,7 @@ class MemoryBlock { | |||||
| bool is_zero_copy_; | bool is_zero_copy_; | ||||
| std::map<int64_t, size_t> depend_stream_life_; | std::map<int64_t, size_t> depend_stream_life_; | ||||
| int64_t memory_type_; | int64_t memory_type_; | ||||
| std::string batch_label_; | |||||
| private: | private: | ||||
| size_t block_size_; | size_t block_size_; | ||||
| std::vector<size_t> real_size_list_; | std::vector<size_t> real_size_list_; | ||||
| @@ -209,7 +222,7 @@ class BlockMemAssigner : public MemAssigner { | |||||
| void GetOutAndWorkSpaceMem(std::vector<int64_t> &all_memory_size); | void GetOutAndWorkSpaceMem(std::vector<int64_t> &all_memory_size); | ||||
| void GetNodeWorkSpaceSize(const ge::NodePtr &node, std::vector<int64_t> &workspace_memory); | |||||
| void GetNodeWorkSpaceSize(const ge::NodePtr &node, std::vector<int64_t> &workspace_memory, int64_t &total_size); | |||||
| /// | /// | ||||
| /// @ingroup GE | /// @ingroup GE | ||||
| @@ -353,7 +366,7 @@ class BlockMemAssigner : public MemAssigner { | |||||
| /// @return void | /// @return void | ||||
| /// @author | /// @author | ||||
| /// | /// | ||||
| void ReleaseMemory(MemoryBlock *to_release, vector<MemoryBlock *> &reusable_memory); | |||||
| void ReleaseMemory(MemoryBlock *to_release, vector<MemoryBlock *> &reusable_memory, bool same_stream = true); | |||||
| /// | /// | ||||
| /// @ingroup GE | /// @ingroup GE | ||||
| @@ -379,11 +392,11 @@ class BlockMemAssigner : public MemAssigner { | |||||
| /// | /// | ||||
| /// @ingroup GE | /// @ingroup GE | ||||
| /// @brief Merge memory blocks between different batchs | |||||
| /// @brief Resize memory blocks for each batch | |||||
| /// @return merge or not | /// @return merge or not | ||||
| /// @author | /// @author | ||||
| /// | /// | ||||
| bool MergeDynamicBatchBlocks(); | |||||
| void ResizeDynamicBatchBlocks(); | |||||
| void AssignContinuousBlocks(); | void AssignContinuousBlocks(); | ||||
| @@ -436,6 +449,17 @@ class BlockMemAssigner : public MemAssigner { | |||||
| int64_t atomic_addr_clean_id_ = 0; | int64_t atomic_addr_clean_id_ = 0; | ||||
| size_t theory_min_memory_size_ = 0; | |||||
| size_t theory_memory_size_ = 0; | |||||
| std::string max_batch_label_; | |||||
| /// | |||||
| /// @          [stream1][nodeid] | |||||
| /// @ [nodeid] [stream2][nodeid] | |||||
| /// @          [stream2][nodeid] | |||||
| /// | |||||
| DependStreamLife total_node_depend_stream_life_; | DependStreamLife total_node_depend_stream_life_; | ||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -419,7 +419,8 @@ Status GraphMemoryAssigner::AssignContinuousInputMemory(const ge::NodePtr &node, | |||||
| GE_IF_BOOL_EXEC(is_peer_output_continuous && (peer_output_size != 1), | GE_IF_BOOL_EXEC(is_peer_output_continuous && (peer_output_size != 1), | ||||
| std::string error = "Current op" + FmtToStr(node->GetOpDesc()->GetName()) + | std::string error = "Current op" + FmtToStr(node->GetOpDesc()->GetName()) + | ||||
| " requires continuous input, while the previous op" + FmtToStr(peer_op_desc->GetName()) + | " requires continuous input, while the previous op" + FmtToStr(peer_op_desc->GetName()) + | ||||
| " requires continuous output. There may be conflict between the two. This node is not supported now."; | |||||
| " requires continuous output. There may be conflict between the two." + | |||||
| "This node is not supported now."; | |||||
| GE_ERRORLOG_AND_ERRORMSG(FAILED, error.c_str()); | GE_ERRORLOG_AND_ERRORMSG(FAILED, error.c_str()); | ||||
| return PARAM_INVALID;); | return PARAM_INVALID;); | ||||
| @@ -429,7 +430,8 @@ Status GraphMemoryAssigner::AssignContinuousInputMemory(const ge::NodePtr &node, | |||||
| GE_IF_BOOL_EXEC(is_peer_reference, | GE_IF_BOOL_EXEC(is_peer_reference, | ||||
| std::string error = "Current op" + FmtToStr(node->GetOpDesc()->GetName()) + | std::string error = "Current op" + FmtToStr(node->GetOpDesc()->GetName()) + | ||||
| " requires continuous input, while the previous op" + FmtToStr(peer_op_desc->GetName()) + | " requires continuous input, while the previous op" + FmtToStr(peer_op_desc->GetName()) + | ||||
| " requires continuous output. There may be conflict between the two. This node is not supported now."; | |||||
| " requires continuous output. There may be conflict between the two." + | |||||
| "This node is not supported now."; | |||||
| GE_ERRORLOG_AND_ERRORMSG(FAILED, error.c_str()); | GE_ERRORLOG_AND_ERRORMSG(FAILED, error.c_str()); | ||||
| return PARAM_INVALID;); | return PARAM_INVALID;); | ||||
| @@ -1646,9 +1648,9 @@ ge::Status GraphMemoryAssigner::SetAtomicCleanAttr(const NodePtr &node, const ve | |||||
| } | } | ||||
| string atomic_mem_size_str = ss.str(); | string atomic_mem_size_str = ss.str(); | ||||
| GELOGI("[IMAS]SetAtomicCleanAttr : Set graph[%s] atomic_node[%s] output offset [%s] size[%s] streamid[%ld]", | |||||
| GELOGI("[IMAS]SetAtomicCleanAttr : Set %s atomic_node name[%s] output[0] offset to [%s] streamid[%ld] size[%s]", | |||||
| node->GetOwnerComputeGraph()->GetName().c_str(), node_op_desc->GetName().c_str(), | node->GetOwnerComputeGraph()->GetName().c_str(), node_op_desc->GetName().c_str(), | ||||
| atomic_mem_start_str.c_str(), atomic_mem_size_str.c_str(), node->GetOpDesc()->GetStreamId()); | |||||
| atomic_mem_start_str.c_str(), node->GetOpDesc()->GetStreamId(), atomic_mem_size_str.c_str()); | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -224,7 +224,6 @@ Status ModelBuilder::AdjustConstWeightSize(const ge::NodePtr &node, size_t &mem_ | |||||
| GeTensorDesc &tensor_desc = weight->MutableTensorDesc(); | GeTensorDesc &tensor_desc = weight->MutableTensorDesc(); | ||||
| size_t output_size = weight->GetData().size(); | size_t output_size = weight->GetData().size(); | ||||
| TensorUtils::SetDataOffset(tensor_desc, mem_offset); | TensorUtils::SetDataOffset(tensor_desc, mem_offset); | ||||
| GELOGD("Node: %s, weight size: %zu.", node->GetName().c_str(), output_size); | |||||
| mem_offset += output_size; | mem_offset += output_size; | ||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -49,7 +49,8 @@ inline bool HasContinuousStreamLabel(const ge::OpDescPtr &op_desc, std::string & | |||||
| } | } | ||||
| bool IsHcclOp(const string &op_type) { | bool IsHcclOp(const string &op_type) { | ||||
| const set<string> hccl_op_types({ge::HCOMBROADCAST, ge::HCOMALLGATHER, ge::HCOMALLREDUCE, ge::HCOMREDUCESCATTER, ge::HCOMREDUCE}); | |||||
| const set<string> hccl_op_types({ge::HCOMBROADCAST, ge::HCOMALLGATHER, | |||||
| ge::HCOMALLREDUCE, ge::HCOMREDUCESCATTER, ge::HCOMREDUCE}); | |||||
| return hccl_op_types.find(op_type) != hccl_op_types.end(); | return hccl_op_types.find(op_type) != hccl_op_types.end(); | ||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| @@ -283,7 +283,8 @@ Status GraphLoader::ExecuteModel(uint32_t model_id, rtStream_t stream, bool asyn | |||||
| std::vector<GeTensorDesc> &output_desc) { | std::vector<GeTensorDesc> &output_desc) { | ||||
| auto model_manager = ModelManager::GetInstance(); | auto model_manager = ModelManager::GetInstance(); | ||||
| GE_CHECK_NOTNULL(model_manager); | GE_CHECK_NOTNULL(model_manager); | ||||
| Status ret = model_manager->ExecuteModel(model_id, stream, async_mode, input_data, input_desc, output_data, output_desc); | |||||
| Status ret = model_manager->ExecuteModel(model_id, stream, async_mode, | |||||
| input_data, input_desc, output_data, output_desc); | |||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(ret, "Execute model failed, model_id:%u.", model_id); | GELOGE(ret, "Execute model failed, model_id:%u.", model_id); | ||||
| return ret; | return ret; | ||||
| @@ -83,7 +83,7 @@ const uint32_t kAddrLen = sizeof(void *); | |||||
| const int kDecimal = 10; | const int kDecimal = 10; | ||||
| const int kBytes = 8; | const int kBytes = 8; | ||||
| const uint32_t kDataMemAlignSizeCompare = 64; | const uint32_t kDataMemAlignSizeCompare = 64; | ||||
| const uint32_t kDumpL1FusionOpMByteSize = 2 * 1024 * 1024; | |||||
| const uint32_t kDumpL1FusionOpMByteSize = 2 * 1024 * 1024; // 2M | |||||
| const uint32_t kDumpFlagOfL1Fusion = 0; | const uint32_t kDumpFlagOfL1Fusion = 0; | ||||
| const char *const kDefaultBatchLable = "Batch_default"; | const char *const kDefaultBatchLable = "Batch_default"; | ||||
| const char *const kGetDynamicDimsName = "ascend_mbatch_get_dynamic_dims_node"; | const char *const kGetDynamicDimsName = "ascend_mbatch_get_dynamic_dims_node"; | ||||
| @@ -330,8 +330,8 @@ Status DavinciModel::InitFeatureMapAndP2PMem(void *dev_ptr, size_t mem_size) { | |||||
| GELOGE(GE_EXEC_ALLOC_FEATURE_MAP_MEM_FAILED, "Alloc feature map memory failed. size: %zu", data_size); | GELOGE(GE_EXEC_ALLOC_FEATURE_MAP_MEM_FAILED, "Alloc feature map memory failed. size: %zu", data_size); | ||||
| return GE_EXEC_ALLOC_FEATURE_MAP_MEM_FAILED; | return GE_EXEC_ALLOC_FEATURE_MAP_MEM_FAILED; | ||||
| } | } | ||||
| GEEVENT("[IMAS]InitFeatureMapAndP2PMem graph_%u MallocMemory type[F] memaddr[%p] mem_size[%zu]", runtime_param_.graph_id, | |||||
| mem_base_, data_size); | |||||
| GEEVENT("[IMAS]InitFeatureMapAndP2PMem graph_%u MallocMemory type[F] memaddr[%p] mem_size[%zu]", | |||||
| runtime_param_.graph_id, mem_base_, data_size); | |||||
| if (!is_inner_weight_base_) { | if (!is_inner_weight_base_) { | ||||
| weights_mem_base_ = mem_base_; | weights_mem_base_ = mem_base_; | ||||
| @@ -1543,7 +1543,8 @@ Status DavinciModel::LoadWithQueue() { | |||||
| } | } | ||||
| if (output_queue_ids_.size() != new_output_data_info_.size()) { | if (output_queue_ids_.size() != new_output_data_info_.size()) { | ||||
| GELOGE(ACL_ERROR_GE_EXEC_MODEL_QUEUE_ID_INVALID, "Output queue ids not match model: output_queue=%zu output_data=%zu", | |||||
| GELOGE(ACL_ERROR_GE_EXEC_MODEL_QUEUE_ID_INVALID, | |||||
| "Output queue ids not match model: output_queue=%zu output_data=%zu", | |||||
| output_queue_ids_.size(), new_output_data_info_.size()); | output_queue_ids_.size(), new_output_data_info_.size()); | ||||
| return ACL_ERROR_GE_EXEC_MODEL_QUEUE_ID_INVALID; | return ACL_ERROR_GE_EXEC_MODEL_QUEUE_ID_INVALID; | ||||
| } | } | ||||
| @@ -2202,7 +2203,7 @@ Status DavinciModel::CopyInputData(const InputData &input_data, bool device_data | |||||
| void *mem_addr = data.second.GetBasicAddr(); | void *mem_addr = data.second.GetBasicAddr(); | ||||
| void *data_buf_addr = reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(data_buf.data)); | void *data_buf_addr = reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(data_buf.data)); | ||||
| uint64_t data_buf_length = data_buf.length; | uint64_t data_buf_length = data_buf.length; | ||||
| GELOGI("[IMAS]CopyPlainData memcpy graph_%u type[F] input[%u] dst[%p] src[%p] mem_size[%lu] datasize[%lu]", | |||||
| GELOGI("CopyPlainData memcpy graph_%u type[F] input[%u] dst[%p] src[%p] mem_size[%lu] datasize[%lu]", | |||||
| runtime_param_.graph_id, data.first, mem_addr, data_buf_addr, data_size, data_buf_length); | runtime_param_.graph_id, data.first, mem_addr, data_buf_addr, data_size, data_buf_length); | ||||
| GE_CHK_RT_RET(rtMemcpy(mem_addr, data_size, data_buf_addr, data_buf_length, kind)); | GE_CHK_RT_RET(rtMemcpy(mem_addr, data_size, data_buf_addr, data_buf_length, kind)); | ||||
| } | } | ||||
| @@ -3391,14 +3392,14 @@ bool DavinciModel::CheckInputAndModelSize(const int64_t &input_size, const int64 | |||||
| /// | /// | ||||
| Status DavinciModel::CopyModelData(const InputData &input_data, OutputData &output_data, bool is_dynamic) { | Status DavinciModel::CopyModelData(const InputData &input_data, OutputData &output_data, bool is_dynamic) { | ||||
| if (UpdateIoTaskArgs(new_input_data_info_, true, input_data.blobs, is_dynamic, input_data.batch_label) != SUCCESS) { | if (UpdateIoTaskArgs(new_input_data_info_, true, input_data.blobs, is_dynamic, input_data.batch_label) != SUCCESS) { | ||||
| GELOGE(PARAM_INVALID, "[ZCPY] Update input data to model failed."); | |||||
| return PARAM_INVALID; | |||||
| GELOGE(ACL_ERROR_GE_PARAM_INVALID, "[ZCPY] Update input data to model failed."); | |||||
| return ACL_ERROR_GE_PARAM_INVALID; | |||||
| } | } | ||||
| if (UpdateIoTaskArgs(new_output_data_info_, false, output_data.blobs, is_dynamic, input_data.batch_label) != | if (UpdateIoTaskArgs(new_output_data_info_, false, output_data.blobs, is_dynamic, input_data.batch_label) != | ||||
| SUCCESS) { | SUCCESS) { | ||||
| GELOGE(PARAM_INVALID, "[ZCPY] Update output data to model failed."); | |||||
| return PARAM_INVALID; | |||||
| GELOGE(ACL_ERROR_GE_PARAM_INVALID, "[ZCPY] Update output data to model failed."); | |||||
| return ACL_ERROR_GE_PARAM_INVALID; | |||||
| } | } | ||||
| for (ZeroCopyTask &task : zero_copy_tasks_) { | for (ZeroCopyTask &task : zero_copy_tasks_) { | ||||
| @@ -3861,7 +3862,8 @@ Status DavinciModel::NnExecute(rtStream_t stream, bool async_mode, const InputDa | |||||
| if (!is_async_mode_) { | if (!is_async_mode_) { | ||||
| GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingModelExecuteOn(), SetProfileTime(MODEL_AFTER_PROC_START)); | GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingModelExecuteOn(), SetProfileTime(MODEL_AFTER_PROC_START)); | ||||
| ret = CopyOutputData(input_data.index, output_data, RT_MEMCPY_DEVICE_TO_DEVICE); | ret = CopyOutputData(input_data.index, output_data, RT_MEMCPY_DEVICE_TO_DEVICE); | ||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, return ret, "Copy Output data to user failed."); | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, return ACL_ERROR_GE_INTERNAL_ERROR, | |||||
| "Copy Output data to user failed."); | |||||
| GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingModelExecuteOn(), SetProfileTime(MODEL_AFTER_PROC_END)); | GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingModelExecuteOn(), SetProfileTime(MODEL_AFTER_PROC_END)); | ||||
| } | } | ||||
| @@ -4061,7 +4063,7 @@ void DavinciModel::SetDataDumperArgs(const ComputeGraphPtr &compute_graph) { | |||||
| data_dumper_.SetDeviceId(device_id); | data_dumper_.SetDeviceId(device_id); | ||||
| // set loop count addr | // set loop count addr | ||||
| auto get_var_addr = [](const OpDescPtr &op, const RuntimeParam &runtime_param) -> void * { | |||||
| auto get_var_addr = [](const OpDescPtr &op, const RuntimeParam &runtime_param) -> void *{ | |||||
| if (op != nullptr) { | if (op != nullptr) { | ||||
| auto v_output_size = ModelUtils::GetOutputSize(op); | auto v_output_size = ModelUtils::GetOutputSize(op); | ||||
| auto v_output_addr = ModelUtils::GetOutputDataAddrs(runtime_param, op); | auto v_output_addr = ModelUtils::GetOutputDataAddrs(runtime_param, op); | ||||
| @@ -1254,7 +1254,8 @@ Status ModelManager::ExecuteModel(uint32_t model_id, rtStream_t stream, bool asy | |||||
| } | } | ||||
| std::shared_ptr<DavinciModel> davinci_model = GetModel(model_id); | std::shared_ptr<DavinciModel> davinci_model = GetModel(model_id); | ||||
| GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, PARAM_INVALID, "Invalid model id %u.", model_id); | |||||
| GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, ACL_ERROR_GE_EXEC_MODEL_ID_INVALID, | |||||
| "Invalid model id %u, check weather model has been loaded or not.", model_id); | |||||
| if (davinci_model->NeedDestroyAicpuKernel()) { | if (davinci_model->NeedDestroyAicpuKernel()) { | ||||
| GELOGI("Start to destroy specified aicpu kernel."); | GELOGI("Start to destroy specified aicpu kernel."); | ||||
| @@ -61,7 +61,7 @@ vector<int64_t> ModelUtils::GetInputSize(ConstOpDescPtr op_desc) { | |||||
| GELOGI("Get size from TensorDesc failed, op : %s, input index : %zu", op_desc->GetName().c_str(), i); | GELOGI("Get size from TensorDesc failed, op : %s, input index : %zu", op_desc->GetName().c_str(), i); | ||||
| continue); | continue); | ||||
| GELOGI("[IMAS]GetInputSize op: %s, index: %zu, size:%ld", op_desc->GetName().c_str(), i, tensor_size); | |||||
| GELOGI("GetInputSize op: %s, index: %zu, size:%ld", op_desc->GetName().c_str(), i, tensor_size); | |||||
| v_input_size.push_back(tensor_size); | v_input_size.push_back(tensor_size); | ||||
| } | } | ||||
| @@ -96,7 +96,7 @@ vector<int64_t> ModelUtils::GetOutputSize(ConstOpDescPtr op_desc) { | |||||
| GELOGI("Get size from TensorDesc failed, op : %s, output index : %zu", op_desc->GetName().c_str(), i); | GELOGI("Get size from TensorDesc failed, op : %s, output index : %zu", op_desc->GetName().c_str(), i); | ||||
| continue); | continue); | ||||
| GELOGI("[IMAS]GetOutputSize op: %s, index: %zu, size:%ld", op_desc->GetName().c_str(), i, tensor_size); | |||||
| GELOGI("GetOutputSize op: %s, index: %zu, size:%ld", op_desc->GetName().c_str(), i, tensor_size); | |||||
| v_output_size.push_back(tensor_size); | v_output_size.push_back(tensor_size); | ||||
| } | } | ||||
| @@ -281,7 +281,8 @@ Status HcclTaskInfo::SetAddrs(const std::shared_ptr<OpDesc> &op_desc, | |||||
| kernel_hccl_infos[i].inputDataAddr = input_data_addr; | kernel_hccl_infos[i].inputDataAddr = input_data_addr; | ||||
| if (hccl_type == HCOMALLGATHER || hccl_type == HCOMRECEIVE || hccl_type == HVDCALLBACKALLGATHER) { | if (hccl_type == HCOMALLGATHER || hccl_type == HCOMRECEIVE || hccl_type == HVDCALLBACKALLGATHER) { | ||||
| kernel_hccl_infos[i].outputDataAddr = output_data_addr; | kernel_hccl_infos[i].outputDataAddr = output_data_addr; | ||||
| } else if (hccl_type == HCOMALLREDUCE || hccl_type == HCOMREDUCESCATTER || hccl_type == HVDCALLBACKALLREDUCE || hccl_type == HCOMREDUCE) { | |||||
| } else if (hccl_type == HCOMALLREDUCE || | |||||
| hccl_type == HCOMREDUCESCATTER || hccl_type == HVDCALLBACKALLREDUCE || hccl_type == HCOMREDUCE) { | |||||
| GE_CHK_STATUS_RET(HcomOmeUtil::GetHcclOperationType(op_desc, op_type), | GE_CHK_STATUS_RET(HcomOmeUtil::GetHcclOperationType(op_desc, op_type), | ||||
| "davinci_model: GetHcomOperationType fail!"); | "davinci_model: GetHcomOperationType fail!"); | ||||
| kernel_hccl_infos[i].outputDataAddr = output_data_addr; | kernel_hccl_infos[i].outputDataAddr = output_data_addr; | ||||
| @@ -1172,8 +1172,8 @@ Status KernelTaskInfo::CceUpdateKernelArgs(const domi::KernelContext &context, u | |||||
| } | } | ||||
| ccStatus_t cc_ret; | ccStatus_t cc_ret; | ||||
| std::string update_kernel_args = "ccUpdateKernelArgs"; | std::string update_kernel_args = "ccUpdateKernelArgs"; | ||||
| auto cceUpdateKernelArgs = (ccStatus_t(*)(ccOpContext &, uint64_t, uint64_t, uint64_t, void *, uint64_t, | |||||
| void *))mmDlsym(handle, const_cast<char *>(update_kernel_args.c_str())); | |||||
| auto cceUpdateKernelArgs = (ccStatus_t(*)(ccOpContext &, uint64_t, uint64_t, | |||||
| uint64_t, void *, uint64_t, void *))mmDlsym(handle, const_cast<char *>(update_kernel_args.c_str())); | |||||
| if (cceUpdateKernelArgs == nullptr) { | if (cceUpdateKernelArgs == nullptr) { | ||||
| GELOGE(FAILED, "Failed to invoke function ccUpdateKernelArgs"); | GELOGE(FAILED, "Failed to invoke function ccUpdateKernelArgs"); | ||||
| if (mmDlclose(handle) != 0) { | if (mmDlclose(handle) != 0) { | ||||
| @@ -56,7 +56,6 @@ | |||||
| #include "graph/passes/cond_remove_pass.h" | #include "graph/passes/cond_remove_pass.h" | ||||
| #include "graph/passes/constant_folding_pass.h" | #include "graph/passes/constant_folding_pass.h" | ||||
| #include "graph/passes/constant_fuse_same_pass.h" | #include "graph/passes/constant_fuse_same_pass.h" | ||||
| #include "graph/passes/const_pass.cc" | |||||
| #include "graph/passes/control_trigger_pass.h" | #include "graph/passes/control_trigger_pass.h" | ||||
| #include "graph/passes/ctrl_edge_transfer_pass.h" | #include "graph/passes/ctrl_edge_transfer_pass.h" | ||||
| #include "graph/passes/dimension_adjust_pass.h" | #include "graph/passes/dimension_adjust_pass.h" | ||||
| @@ -550,8 +549,13 @@ Status GraphManager::OptimizeSubGraphWithMultiThreads(ComputeGraphPtr compute_gr | |||||
| if (!op_compile_strategy.empty()) { | if (!op_compile_strategy.empty()) { | ||||
| (void) AttrUtils::SetStr(subgraph->GetSubGraph(), ATTR_NAME_OP_COMPILE_STRATEGY, op_compile_strategy); | (void) AttrUtils::SetStr(subgraph->GetSubGraph(), ATTR_NAME_OP_COMPILE_STRATEGY, op_compile_strategy); | ||||
| } | } | ||||
| std::future<Status> f = executor.commit(GraphManager::ProcessSubGraphWithMultiThreads, this, | |||||
| compute_graph->GetGraphID(), subgraph, compute_graph, session_id, GetThreadLocalContext()); | |||||
| std::future<Status> f = executor.commit(GraphManager::ProcessSubGraphWithMultiThreads, | |||||
| this, | |||||
| compute_graph->GetGraphID(), | |||||
| subgraph, | |||||
| compute_graph, | |||||
| session_id, | |||||
| GetThreadLocalContext()); | |||||
| if (!f.valid()) { | if (!f.valid()) { | ||||
| GELOGE(FAILED, "Future is invalid"); | GELOGE(FAILED, "Future is invalid"); | ||||
| return FAILED; | return FAILED; | ||||
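The reflowed `executor.commit(...)` call above still returns a `std::future<Status>` that is checked for validity before the result is awaited. As a standalone illustration of that pattern (using `std::async` here, since the ThreadPool `commit` signature is not part of this hunk):

```cpp
#include <future>
#include <iostream>

// Stand-in for the per-subgraph job submitted to the executor (0 ~ SUCCESS).
int ProcessSubGraph(int graph_id) { return 0; }

int main() {
  std::future<int> f = std::async(std::launch::async, ProcessSubGraph, 42);
  if (!f.valid()) {  // same guard as `if (!f.valid())` in the hunk above
    std::cerr << "Future is invalid" << std::endl;
    return 1;
  }
  std::cout << "status = " << f.get() << std::endl;  // block until the job finishes
  return 0;
}
```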
| @@ -2138,7 +2142,6 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) { | |||||
| TransposeTransDataPass transpose_transdata_pass; | TransposeTransDataPass transpose_transdata_pass; | ||||
| TransOpSymmetryEliminationPass symmetry_elimination_pass; | TransOpSymmetryEliminationPass symmetry_elimination_pass; | ||||
| DimensionComputePass dimension_compute_pass; | DimensionComputePass dimension_compute_pass; | ||||
| ConstPass const_pass; | |||||
| names_to_passes.emplace_back("EnterPass", &enter_pass); | names_to_passes.emplace_back("EnterPass", &enter_pass); | ||||
| names_to_passes.emplace_back("AddNPass", &addn_pass); | names_to_passes.emplace_back("AddNPass", &addn_pass); | ||||
| names_to_passes.emplace_back("SwitchDeadBranchElimination", &switch_dead_branch_elimination); | names_to_passes.emplace_back("SwitchDeadBranchElimination", &switch_dead_branch_elimination); | ||||
| @@ -2152,7 +2155,6 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) { | |||||
| names_to_passes.emplace_back("DimensionComputePass", &dimension_compute_pass); | names_to_passes.emplace_back("DimensionComputePass", &dimension_compute_pass); | ||||
| names_to_passes.emplace_back("ConstantFoldingPass", &constant_folding_pass); | names_to_passes.emplace_back("ConstantFoldingPass", &constant_folding_pass); | ||||
| names_to_passes.emplace_back("DimensionAdjustPass", &dimension_adjust_pass); | names_to_passes.emplace_back("DimensionAdjustPass", &dimension_adjust_pass); | ||||
| names_to_passes.emplace_back("ConstPass", &const_pass); | |||||
| GE_TIMESTAMP_START(names_to_passes); | GE_TIMESTAMP_START(names_to_passes); | ||||
| ret = GEPass(compute_graph).Run(names_to_passes); | ret = GEPass(compute_graph).Run(names_to_passes); | ||||
| GE_TIMESTAMP_END(names_to_passes, "GraphManager::OptimizeStage1_2"); | GE_TIMESTAMP_END(names_to_passes, "GraphManager::OptimizeStage1_2"); | ||||
| @@ -2193,8 +2195,6 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) { | |||||
| GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::VariableRefUselessControlOutDeletePass", | GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::VariableRefUselessControlOutDeletePass", | ||||
| new (std::nothrow) VariableRefUselessControlOutDeletePass)) | new (std::nothrow) VariableRefUselessControlOutDeletePass)) | ||||
| GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ReshapeRecoveryPass", new (std::nothrow) ReshapeRecoveryPass)) | GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ReshapeRecoveryPass", new (std::nothrow) ReshapeRecoveryPass)) | ||||
| GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::CommonSubexpressionEliminationPass", | |||||
| new (std::nothrow) CommonSubexpressionEliminationPass)); | |||||
| if (options_.train_graph_flag) { | if (options_.train_graph_flag) { | ||||
| // Priority: The GlobalStepInsertPass should work before graph partitioner. | // Priority: The GlobalStepInsertPass should work before graph partitioner. | ||||
| // Reason: Make sure that the var "global_step" can be partitioned to known sub graph and allocated memory | // Reason: Make sure that the var "global_step" can be partitioned to known sub graph and allocated memory | ||||
| @@ -2471,7 +2471,6 @@ Status GraphManager::ProcessSubGraphWithMultiThreads(GraphManager *graph_manager | |||||
| GetContext().SetSessionId(session_id); | GetContext().SetSessionId(session_id); | ||||
| GetThreadLocalContext() = ge_context; | GetThreadLocalContext() = ge_context; | ||||
| graph_manager->UpdateLocalOmgContext(root_graph_id); | graph_manager->UpdateLocalOmgContext(root_graph_id); | ||||
| ComputeGraphPtr compute_graph_tmp = sub_graph_info_ptr->GetSubGraph(); | ComputeGraphPtr compute_graph_tmp = sub_graph_info_ptr->GetSubGraph(); | ||||
| const std::string &engine_name = sub_graph_info_ptr->GetEngineName(); | const std::string &engine_name = sub_graph_info_ptr->GetEngineName(); | ||||
| GELOGD("ProcessSubGraphWithMultiThreads start, graph name is %s, engine_name is %s, thread id is %lu", | GELOGD("ProcessSubGraphWithMultiThreads start, graph name is %s, engine_name is %s, thread id is %lu", | ||||
| @@ -2479,6 +2478,10 @@ Status GraphManager::ProcessSubGraphWithMultiThreads(GraphManager *graph_manager | |||||
| pthread_self()); | pthread_self()); | ||||
| GE_DUMP(compute_graph_tmp, "OptimizeSubGraphBefore"); | GE_DUMP(compute_graph_tmp, "OptimizeSubGraphBefore"); | ||||
| GE_CHECK_NOTNULL(compute_graph_tmp); | GE_CHECK_NOTNULL(compute_graph_tmp); | ||||
| if (!AttrUtils::SetInt(*compute_graph_tmp, ATTR_NAME_ROOT_GRAPH_ID, root_graph_id)) { | |||||
| GELOGE(FAILED, "Failed to set attr ATTR_NAME_ROOT_GRAPH_ID for subgraph, graph_id: %u.", root_graph_id); | |||||
| return FAILED; | |||||
| } | |||||
| compute_graph_tmp->SetSessionID(session_id); | compute_graph_tmp->SetSessionID(session_id); | ||||
| Status ret = graph_manager->GetCompilerStages(root_graph_id).optimizer.OptimizeSubGraph(compute_graph_tmp, | Status ret = graph_manager->GetCompilerStages(root_graph_id).optimizer.OptimizeSubGraph(compute_graph_tmp, | ||||
| compute_graph, | compute_graph, | ||||
| @@ -263,7 +263,8 @@ Status HcomOmeUtil::GetHcclRootId(const ge::ConstOpDescPtr &op_desc, int64_t &ro | |||||
| Status HcomOmeUtil::GetAllRootId(const ge::ConstOpDescPtr &op_desc, | Status HcomOmeUtil::GetAllRootId(const ge::ConstOpDescPtr &op_desc, | ||||
| std::vector<GETaskKernelHcclInfo> &kernel_hccl_infos) { | std::vector<GETaskKernelHcclInfo> &kernel_hccl_infos) { | ||||
| GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
| if (op_desc->GetType() == HCOMBROADCAST || op_desc->GetType() == HVDCALLBACKBROADCAST || op_desc->GetType() == HCOMREDUCE) { | |||||
| if (op_desc->GetType() == HCOMBROADCAST || | |||||
| op_desc->GetType() == HVDCALLBACKBROADCAST || op_desc->GetType() == HCOMREDUCE) { | |||||
| GELOGI("GetAllRootId Node[%s] opType[%s] get hccl rootId.", op_desc->GetName().c_str(), op_desc->GetType().c_str()); | GELOGI("GetAllRootId Node[%s] opType[%s] get hccl rootId.", op_desc->GetName().c_str(), op_desc->GetType().c_str()); | ||||
| int64_t root_id = 0; | int64_t root_id = 0; | ||||
| Status dmrt = GetHcclRootId(op_desc, root_id); | Status dmrt = GetHcclRootId(op_desc, root_id); | ||||
| @@ -74,10 +74,87 @@ Status AtomicAddrCleanPass::Run(ComputeGraphPtr graph) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| // just hccl may mark atomic from ops kernel now, and hccl's atomic is for all inputs | |||||
| bool AtomicAddrCleanPass::CheckAtomicFromOpsKernel(const NodePtr &node) { | |||||
| // 1.Check if isAtomic attrs exist for HCOM | |||||
| std::shared_ptr<GELib> instance_ptr = GELib::GetInstance(); | |||||
| if ((instance_ptr == nullptr) || (!instance_ptr->InitFlag())) { | |||||
| GELOGW("GELib not initialized, atomic from ops kernel judge false, node_name: %s", node->GetName().c_str()); | |||||
| return false; | |||||
| } | |||||
| OpsKernelManager &ops_kernel_manager = instance_ptr->OpsKernelManagerObj(); | |||||
| vector<OpInfo> op_info_vec = ops_kernel_manager.GetOpsKernelInfo(node->GetType()); | |||||
| for (const auto &op_info : op_info_vec) { | |||||
| if (op_info.isAtomic) { | |||||
| // check peer input is DATA | |||||
| for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||||
| if (in_data_anchor->GetPeerOutAnchor() != nullptr && | |||||
| in_data_anchor->GetPeerOutAnchor()->GetOwnerNode() != nullptr) { | |||||
| auto peer_in_node = in_data_anchor->GetPeerOutAnchor()->GetOwnerNode(); | |||||
| if (peer_in_node->GetType() == DATA) { | |||||
| GELOGI("Recognized atomic op %s from %s engine and input is DATA.", node->GetName().c_str(), op_info.engine.c_str()); | |||||
| return false; | |||||
| } | |||||
| } | |||||
| } | |||||
| GELOGI("Recognized atomic op %s from %s engine.", node->GetName().c_str(), op_info.engine.c_str()); | |||||
| hcom_node_vec_.push_back(node); | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| bool AtomicAddrCleanPass::IsOutputIndexPeerInputAtomic(const NodePtr &node, int64_t output_index) { | |||||
| auto out_data_anchor = node->GetAllOutDataAnchors().at(output_index); | |||||
| if (out_data_anchor == nullptr) { | |||||
| return false; | |||||
| } | |||||
| for (auto input_anchor : out_data_anchor->GetPeerInDataAnchors()) { | |||||
| auto output_node = input_anchor->GetOwnerNode(); | |||||
| // just hccl may mark atomic from ops kernel now, and hccl's atomic is for all inputs | |||||
| // hccl's attr ATOMIC_ATTR_INPUT_INDEX is marked in CalcOpRunningParam, so it can't be got here | |||||
| if (CheckAtomicFromOpsKernel(output_node)) { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| bool AtomicAddrCleanPass::CheckSkipInsertInLoopGraph(const NodePtr &node) { | |||||
| OpDescPtr op_desc = node->GetOpDesc(); | |||||
| std::map<string, std::map<int, int>> node_workspace_offset; | |||||
| bool has_atomic_input = op_desc->HasAttr(ATOMIC_ATTR_INPUT_INDEX); | |||||
| bool has_atomic_output = op_desc->HasAttr(ATOMIC_ATTR_OUTPUT_INDEX); | |||||
| node_workspace_offset = op_desc->TryGetExtAttr(EXT_ATTR_ATOMIC_WORKSPACE_OFFSET, node_workspace_offset); | |||||
| if (!has_atomic_input && has_atomic_output && node_workspace_offset.empty()) { | |||||
| std::vector<int64_t> atomic_output_index; | |||||
| (void) ge::AttrUtils::GetListInt(op_desc, ATOMIC_ATTR_OUTPUT_INDEX, atomic_output_index); | |||||
| bool is_all_output_peer_also_atomic = true; | |||||
| for (const auto &output_index : atomic_output_index) { | |||||
| if (!IsOutputIndexPeerInputAtomic(node, output_index)) { | |||||
| is_all_output_peer_also_atomic = false; | |||||
| break; | |||||
| } | |||||
| } | |||||
| if (is_all_output_peer_also_atomic) { | |||||
| GELOGI("all out peer node input atomic, skip this out atomic process, node name: %s", node->GetName().c_str()); | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| Status AtomicAddrCleanPass::HandleLoopGraph(ComputeGraphPtr &graph, const vector<NodePtr> &atomic_node_vec) { | Status AtomicAddrCleanPass::HandleLoopGraph(ComputeGraphPtr &graph, const vector<NodePtr> &atomic_node_vec) { | ||||
| // Loop graph , insert clean node follow atomic node | // Loop graph , insert clean node follow atomic node | ||||
| int index = 0; | int index = 0; | ||||
| for (const auto &node : atomic_node_vec) { | for (const auto &node : atomic_node_vec) { | ||||
| if (CheckSkipInsertInLoopGraph(node)) { | |||||
| continue; | |||||
| } | |||||
| // Insert atomic clean op | // Insert atomic clean op | ||||
| NodePtr clean_addr_node = InsertAtomicAddrCleanNode(graph); | NodePtr clean_addr_node = InsertAtomicAddrCleanNode(graph); | ||||
| if (clean_addr_node == nullptr) { | if (clean_addr_node == nullptr) { | ||||
| @@ -249,32 +326,10 @@ bool AtomicAddrCleanPass::IsAtomicOp(const NodePtr &node) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| // 1.Check if isAtomic attrs exist for HCOM | // 1.Check if isAtomic attrs exist for HCOM | ||||
| std::shared_ptr<GELib> instance_ptr = GELib::GetInstance(); | |||||
| if ((instance_ptr == nullptr) || (!instance_ptr->InitFlag())) { | |||||
| GELOGW("GELib not initialized"); | |||||
| return false; | |||||
| if (CheckAtomicFromOpsKernel(node)) { | |||||
| return true; | |||||
| } | } | ||||
| OpsKernelManager &ops_kernel_manager = instance_ptr->OpsKernelManagerObj(); | |||||
| vector<OpInfo> op_info_vec = ops_kernel_manager.GetOpsKernelInfo(op_desc->GetType()); | |||||
| for (const auto &op_info : op_info_vec) { | |||||
| if (op_info.isAtomic) { | |||||
| GELOGI("Recognized atomic op %s from DNN_HCCL engine.", op_desc->GetName().c_str()); | |||||
| // check peer input is DATA | |||||
| for (auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||||
| if (in_data_anchor->GetPeerOutAnchor() != nullptr && | |||||
| in_data_anchor->GetPeerOutAnchor()->GetOwnerNode() != nullptr) { | |||||
| auto peer_in_node = in_data_anchor->GetPeerOutAnchor()->GetOwnerNode(); | |||||
| if (peer_in_node->GetType() == DATA) { | |||||
| GELOGI("Recognized atomic op %s from DNN_HCCL engine and input is DATA.", op_desc->GetName().c_str()); | |||||
| return false; | |||||
| } | |||||
| } | |||||
| } | |||||
| hcom_node_vec_.push_back(node); | |||||
| return true; | |||||
| } | |||||
| } | |||||
| // 2.Check atomic attr in node | // 2.Check atomic attr in node | ||||
| std::map<string, std::map<int, int>> node_workspace_offset; | std::map<string, std::map<int, int>> node_workspace_offset; | ||||
| bool has_atomic_input = op_desc->HasAttr(ATOMIC_ATTR_INPUT_INDEX); | bool has_atomic_input = op_desc->HasAttr(ATOMIC_ATTR_INPUT_INDEX); | ||||
| @@ -84,6 +84,11 @@ class AtomicAddrCleanPass : public GraphPass { | |||||
| Status HandleDispersedAtomicNodes(ComputeGraphPtr &graph, const std::vector<NodePtr> &atomic_node_vec, | Status HandleDispersedAtomicNodes(ComputeGraphPtr &graph, const std::vector<NodePtr> &atomic_node_vec, | ||||
| std::vector<NodePtr> &common_atomic_nodes); | std::vector<NodePtr> &common_atomic_nodes); | ||||
| bool CheckAtomicFromOpsKernel(const NodePtr &node); | |||||
| bool IsOutputIndexPeerInputAtomic(const NodePtr &node, int64_t output_index); | |||||
| bool CheckSkipInsertInLoopGraph(const NodePtr &node); | |||||
| vector<NodePtr> hcom_node_vec_; | vector<NodePtr> hcom_node_vec_; | ||||
| bool is_loop_graph_ = false; | bool is_loop_graph_ = false; | ||||
| @@ -18,8 +18,6 @@ | |||||
| #include "ge/ge_api_types.h" | #include "ge/ge_api_types.h" | ||||
| #include "graph/common/omg_util.h" | #include "graph/common/omg_util.h" | ||||
| using std::string; | |||||
| namespace ge { | namespace ge { | ||||
| Status AttachStreamLabelPass::Run(ComputeGraphPtr graph) { | Status AttachStreamLabelPass::Run(ComputeGraphPtr graph) { | ||||
| GELOGD("AttachStreamLabelPass Enter."); | GELOGD("AttachStreamLabelPass Enter."); | ||||
| @@ -189,10 +187,21 @@ Status AttachStreamLabelPass::UpdateEnterNode() { | |||||
| } | } | ||||
| std::stack<NodePtr> enter_nodes; | std::stack<NodePtr> enter_nodes; | ||||
| std::string batch_label; | |||||
| for (const auto &enter_node : pair.second) { | for (const auto &enter_node : pair.second) { | ||||
| enter_nodes.emplace(enter_node); | enter_nodes.emplace(enter_node); | ||||
| std::string tmp_label; | |||||
| (void)AttrUtils::GetStr(enter_node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, tmp_label); | |||||
| if (!tmp_label.empty()) { | |||||
| if (batch_label.empty()) { | |||||
| batch_label = tmp_label; | |||||
| } else if (batch_label != tmp_label) { | |||||
| GELOGE(FAILED, "multi batch_label exist, label1=%s, label2=%s.", batch_label.c_str(), tmp_label.c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| } | } | ||||
| if (UpdateLoopBranch(enter_nodes, active_label_list[0]) != SUCCESS) { | |||||
| if (UpdateLoopBranch(enter_nodes, active_label_list[0], batch_label) != SUCCESS) { | |||||
| GELOGE(FAILED, "Update stream_label for loop_branch failed."); | GELOGE(FAILED, "Update stream_label for loop_branch failed."); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -217,7 +226,10 @@ Status AttachStreamLabelPass::SetEnterLabel(const std::vector<NodePtr> &enter_no | |||||
| } | } | ||||
| for (const auto &enter_node : enter_nodes) { | for (const auto &enter_node : enter_nodes) { | ||||
| GE_CHK_STATUS_RET(SetStreamLabel(enter_node, stream_label), "Set stream label failed."); | |||||
| GE_CHECK_NOTNULL(enter_node->GetOpDesc()); | |||||
| if (enter_node->GetOpDesc()->HasAttr(ATTR_NAME_STREAM_LABEL)) { | |||||
| GE_CHK_STATUS_RET(SetStreamLabel(enter_node, stream_label), "Set stream label failed."); | |||||
| } | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -229,7 +241,8 @@ Status AttachStreamLabelPass::SetEnterLabel(const std::vector<NodePtr> &enter_no | |||||
| /// @param [in] batch_label | /// @param [in] batch_label | ||||
| /// @return Status | /// @return Status | ||||
| /// | /// | ||||
| Status AttachStreamLabelPass::UpdateLoopBranch(const std::stack<NodePtr> &enter_nodes, const string &stream_label) { | |||||
| Status AttachStreamLabelPass::UpdateLoopBranch(const std::stack<NodePtr> &enter_nodes, const std::string &stream_label, | |||||
| const std::string &batch_label) { | |||||
| std::stack<NodePtr> nodes(enter_nodes); | std::stack<NodePtr> nodes(enter_nodes); | ||||
| NodePtr cur_node = nullptr; | NodePtr cur_node = nullptr; | ||||
| while (!nodes.empty()) { | while (!nodes.empty()) { | ||||
| @@ -238,6 +251,11 @@ Status AttachStreamLabelPass::UpdateLoopBranch(const std::stack<NodePtr> &enter_ | |||||
| for (const NodePtr &out_node : cur_node->GetOutAllNodes()) { | for (const NodePtr &out_node : cur_node->GetOutAllNodes()) { | ||||
| OpDescPtr out_desc = out_node->GetOpDesc(); | OpDescPtr out_desc = out_node->GetOpDesc(); | ||||
| GE_CHECK_NOTNULL(out_desc); | GE_CHECK_NOTNULL(out_desc); | ||||
| std::string tmp_label; | |||||
| (void)AttrUtils::GetStr(out_desc, ATTR_NAME_BATCH_LABEL, tmp_label); | |||||
| if (!tmp_label.empty() && (tmp_label != batch_label)) { | |||||
| continue; | |||||
| } | |||||
| std::string out_type = out_desc->GetType(); | std::string out_type = out_desc->GetType(); | ||||
| bool need_skip = | bool need_skip = | ||||
| out_desc->HasAttr(ATTR_NAME_STREAM_LABEL) || (out_type == ENTER) || (out_type == REFENTER) || | out_desc->HasAttr(ATTR_NAME_STREAM_LABEL) || (out_type == ENTER) || (out_type == REFENTER) || | ||||
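Taken together, the attach_stream_label_pass hunks above enforce one batch label per Enter group: the first non-empty ATTR_NAME_BATCH_LABEL wins, a second different non-empty label is an error, and loop-branch nodes carrying a different label are skipped. A dependency-free sketch of that rule, with illustrative names only:

```cpp
#include <string>
#include <vector>

// Derive the group's batch label: first non-empty label wins, a conflicting
// non-empty label is an error (mirrors the loop added in UpdateEnterNode).
bool DeriveBatchLabel(const std::vector<std::string> &labels, std::string &batch_label) {
  for (const std::string &label : labels) {
    if (label.empty()) {
      continue;
    }
    if (batch_label.empty()) {
      batch_label = label;
    } else if (batch_label != label) {
      return false;  // "multiple batch_label values exist"
    }
  }
  return true;
}

// Mirrors the filter added in UpdateLoopBranch: a node whose own non-empty
// label differs from the group's label is left untouched.
bool ShouldSkipNode(const std::string &node_label, const std::string &batch_label) {
  return !node_label.empty() && node_label != batch_label;
}
```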
| @@ -58,9 +58,11 @@ class AttachStreamLabelPass : public GraphPass { | |||||
| /// @brief Update stream_label for loop_branch | /// @brief Update stream_label for loop_branch | ||||
| /// @param [in] enter_nodes | /// @param [in] enter_nodes | ||||
| /// @param [in] stream_label | /// @param [in] stream_label | ||||
| /// @param [in] batch_label | |||||
| /// @return Status | /// @return Status | ||||
| /// | /// | ||||
| static Status UpdateLoopBranch(const std::stack<NodePtr> &enter_nodes, const std::string &stream_label); | |||||
| static Status UpdateLoopBranch(const std::stack<NodePtr> &enter_nodes, const std::string &stream_label, | |||||
| const std::string &batch_label); | |||||
| /// | /// | ||||
| /// @brief Update stream_label start with enter nodes | /// @brief Update stream_label start with enter nodes | ||||
| @@ -96,7 +96,7 @@ Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, std::unorder | |||||
| node->GetName().c_str(), node->GetType().c_str()); | node->GetName().c_str(), node->GetType().c_str()); | ||||
| continue; | continue; | ||||
| } | } | ||||
| if (node_to_re_pass->IsAllInNodesSeen(nodes_seen) || node_to_re_pass->GetType() == ENTER) { | |||||
| if (node_to_re_pass->IsAllInNodesSeen(nodes_seen)) { | |||||
| GELOGD("The node %s will be re-pass later", node_to_re_pass->GetName().c_str()); | GELOGD("The node %s will be re-pass later", node_to_re_pass->GetName().c_str()); | ||||
| nodes_re_pass.insert(node_to_re_pass); | nodes_re_pass.insert(node_to_re_pass); | ||||
| } else { | } else { | ||||
| @@ -58,8 +58,7 @@ std::string GetCseKey(const NodePtr &node) { | |||||
| /// To avoid delete wrong nodes(e.g. stateful nodes), | /// To avoid delete wrong nodes(e.g. stateful nodes), | ||||
| /// only nodes have folding kernel will be considered for the CSE process | /// only nodes have folding kernel will be considered for the CSE process | ||||
| bool IsNodeSupportCse(const NodePtr &node) { | bool IsNodeSupportCse(const NodePtr &node) { | ||||
| if (HostCpuEngine::CheckSupported(NodeUtils::GetNodeType(*node)) || node->GetType() == CONSTANT || | |||||
| node->GetType() == CONSTANTOP) { | |||||
| if (HostCpuEngine::CheckSupported(NodeUtils::GetNodeType(*node))) { | |||||
| return true; | return true; | ||||
| } | } | ||||
| return folding_pass::GetKernelByType(node) != nullptr; | return folding_pass::GetKernelByType(node) != nullptr; | ||||
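With this change, Constant/ConstantOp no longer qualify for CSE just by type; a node must be supported by the host CPU engine or have a folding kernel. For context, a generic sketch of the duplicate-elimination pattern such a support check feeds into (illustrative only, not the GE pass itself):

```cpp
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

struct FakeNode {
  std::string key;  // stands in for GetCseKey(node)
  bool supported;   // stands in for IsNodeSupportCse(node)
};

int main() {
  std::vector<FakeNode> nodes = {{"Add(x,y)", true}, {"Mul(x,y)", true}, {"Add(x,y)", true}};
  std::unordered_map<std::string, size_t> kept;  // canonical key -> surviving node index
  for (size_t i = 0; i < nodes.size(); ++i) {
    if (!nodes[i].supported) {
      continue;  // unsupported (e.g. stateful) nodes are never merged
    }
    auto it = kept.find(nodes[i].key);
    if (it == kept.end()) {
      kept.emplace(nodes[i].key, i);  // first occurrence survives
    } else {
      std::cout << "node " << i << " folded into node " << it->second << "\n";
    }
  }
  return 0;
}
```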
| @@ -1,55 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "graph/passes/const_pass.h" | |||||
| #include "graph/debug/ge_attr_define.h" | |||||
| #include "graph/utils/graph_utils.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "framework/common/debug/log.h" | |||||
| namespace ge { | |||||
| Status ConstPass::Run(NodePtr &node) { | |||||
| GE_CHECK_NOTNULL(node); | |||||
| if ((node->GetType() != CONSTANT) && (node->GetType() != CONSTANTOP)) { | |||||
| return SUCCESS; | |||||
| } | |||||
| GELOGD("ConstPass running, node: %s.", node->GetName().c_str()); | |||||
| // const has no control input | |||||
| if (node->GetInControlNodes().empty()) { | |||||
| auto out_ctrl_anchor = node->GetOutControlAnchor(); | |||||
| if (out_ctrl_anchor != nullptr) { | |||||
| GELOGD("Node: %s unlink all out control edge.", node->GetName().c_str()); | |||||
| out_ctrl_anchor->UnlinkAll(); | |||||
| } | |||||
| if (node->GetOutAllNodes().empty()) { | |||||
| // it is an isolated const, just remove it. | |||||
| GELOGD("Delete isolated const: %s.", node->GetName().c_str()); | |||||
| auto graph = node->GetOwnerComputeGraph(); | |||||
| if (GraphUtils::RemoveNodeWithoutRelink(graph, node) != GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "Remove const %s failed.", node->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| AddNodeDeleted(node); | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -1,29 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_GRAPH_PASSES_CONST_PASS_H_ | |||||
| #define GE_GRAPH_PASSES_CONST_PASS_H_ | |||||
| #include "graph/passes/base_pass.h" | |||||
| namespace ge { | |||||
| class ConstPass : public BaseNodePass { | |||||
| public: | |||||
| Status Run(NodePtr &node) override; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // GE_GRAPH_PASSES_CONST_PASS_H_ | |||||
| @@ -80,71 +80,7 @@ Status DimensionAdjustPass::Run(ge::NodePtr &node) { | |||||
| } | } | ||||
| } | } | ||||
| ret = DealWithInNodes(node); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "DealWithInNodes of %s failed.", node->GetName().c_str()); | |||||
| return ret; | |||||
| } | |||||
| std::vector<int> data_relink_io_map = {kDataInputIndex}; | std::vector<int> data_relink_io_map = {kDataInputIndex}; | ||||
| return IsolateAndDeleteNode(node, data_relink_io_map); | return IsolateAndDeleteNode(node, data_relink_io_map); | ||||
| } | } | ||||
| Status DimensionAdjustPass::DealWithInNodes(NodePtr &node) { | |||||
| GE_CHECK_NOTNULL(node); | |||||
| GE_CHECK_NOTNULL(node->GetOpDesc()); | |||||
| auto graph = node->GetOwnerComputeGraph(); | |||||
| auto in_data_anchors = node->GetAllInDataAnchors(); | |||||
| for (auto &in_data_anchor : in_data_anchors) { | |||||
| if (in_data_anchor == nullptr) { | |||||
| continue; | |||||
| } | |||||
| auto in_node_anchor = in_data_anchor->GetPeerOutAnchor(); | |||||
| if (in_node_anchor == nullptr) { | |||||
| continue; | |||||
| } | |||||
| auto in_node = in_node_anchor->GetOwnerNode(); | |||||
| if (in_node->GetType() == SWITCHN) { | |||||
| GELOGI("The in_node name is %s, and node type is %s.", in_node->GetName().c_str(), in_node->GetType().c_str()); | |||||
| auto identity_name = node->GetName() + "_ctrl_identity_" + std::to_string(in_data_anchor->GetIdx()); | |||||
| auto identity = | |||||
| AddIdentityNodeToGraph(identity_name, node->GetOpDesc()->GetInputDesc(in_data_anchor->GetIdx()), graph); | |||||
| GE_CHECK_NOTNULL(identity); | |||||
| GELOGI("Create new identity node[%s] success.", identity->GetName().c_str()); | |||||
| GE_CHK_STATUS_RET(GraphUtils::AddEdge(in_node_anchor, identity->GetInDataAnchor(0))) | |||||
| GE_CHECK_NOTNULL(identity->GetOutControlAnchor()); | |||||
| if (identity->GetOutControlAnchor()->IsLinkedWith(node->GetInControlAnchor())) { | |||||
| continue; | |||||
| } | |||||
| GE_CHK_STATUS_RET(GraphUtils::AddEdge(identity->GetOutControlAnchor(), node->GetInControlAnchor())) | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| NodePtr DimensionAdjustPass::AddIdentityNodeToGraph(const string &name, const GeTensorDesc &tensor, | |||||
| ComputeGraphPtr &graph) { | |||||
| if (graph == nullptr) { | |||||
| GELOGE(INTERNAL_ERROR, "Comput graph ptr is null in creating identity node."); | |||||
| return nullptr; | |||||
| } | |||||
| OpDescPtr desc = MakeShared<OpDesc>("", ""); | |||||
| if (desc == nullptr) { | |||||
| GELOGE(MEMALLOC_FAILED, "Failed to create op desc."); | |||||
| return nullptr; | |||||
| } | |||||
| desc->SetName(name); | |||||
| desc->SetType(IDENTITY); | |||||
| auto ret = desc->AddInputDesc(tensor); | |||||
| auto ret2 = desc->AddOutputDesc(tensor); | |||||
| if ((ret != GRAPH_SUCCESS) || (ret2 != GRAPH_SUCCESS)) { | |||||
| GELOGE(INTERNAL_ERROR, "Failed to add input/output desc in creating identity."); | |||||
| return nullptr; | |||||
| } | |||||
| return graph->AddNodeFront(desc); | |||||
| } | |||||
| } // namespace ge | } // namespace ge | ||||
| @@ -34,10 +34,6 @@ namespace ge { | |||||
| class DimensionAdjustPass : public BaseNodePass { | class DimensionAdjustPass : public BaseNodePass { | ||||
| public: | public: | ||||
| Status Run(ge::NodePtr &node) override; | Status Run(ge::NodePtr &node) override; | ||||
| private: | |||||
| Status DealWithInNodes(ge::NodePtr &node); | |||||
| NodePtr AddIdentityNodeToGraph(const std::string &name, const GeTensorDesc &tensor, ComputeGraphPtr &graph); | |||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -23,7 +23,6 @@ | |||||
| namespace { | namespace { | ||||
| const size_t kOutNodesNum = 1; | const size_t kOutNodesNum = 1; | ||||
| const size_t kInCtrlNodesNum = 1; | |||||
| } | } | ||||
| namespace ge { | namespace ge { | ||||
| @@ -56,7 +55,6 @@ Status EnterPass::Run(NodePtr &node) { | |||||
| if (out_ctrl_node == nullptr) { | if (out_ctrl_node == nullptr) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| GELOGD("Remove control edge from %s to %s.", node->GetName().c_str(), out_ctrl_node->GetName().c_str()); | |||||
| if (GraphUtils::RemoveEdge(node->GetOutControlAnchor(), out_ctrl_node->GetInControlAnchor()) != GRAPH_SUCCESS) { | if (GraphUtils::RemoveEdge(node->GetOutControlAnchor(), out_ctrl_node->GetInControlAnchor()) != GRAPH_SUCCESS) { | ||||
| GELOGE(FAILED, "Remove Enter ctrl output fail, %s->%s", node->GetName().c_str(), | GELOGE(FAILED, "Remove Enter ctrl output fail, %s->%s", node->GetName().c_str(), | ||||
| out_ctrl_node->GetName().c_str()); | out_ctrl_node->GetName().c_str()); | ||||
| @@ -64,12 +62,8 @@ Status EnterPass::Run(NodePtr &node) { | |||||
| } | } | ||||
| } | } | ||||
| } else { | } else { | ||||
| if (OptimizeEnterWithOnlyOutData(node, in_node) != SUCCESS) { | |||||
| GELOGE(FAILED, "Optimize enter node[%s] with only out data node failed.", node->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| if (UnlinkCtrlEdgeBeforeConst(node) != SUCCESS) { | |||||
| GELOGE(FAILED, "Unlink control edge before const of node[%s]'s out nodes failed.", node->GetName().c_str()); | |||||
| if (OptimizeEnter(node, in_node) != SUCCESS) { | |||||
| GELOGE(FAILED, "Optimize enter node[%s] failed.", node->GetName().c_str()); | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| } | } | ||||
| @@ -78,7 +72,7 @@ Status EnterPass::Run(NodePtr &node) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status EnterPass::OptimizeEnterWithOnlyOutData(NodePtr &node, NodePtr &in_node) { | |||||
| Status EnterPass::OptimizeEnter(NodePtr &node, NodePtr &in_node) { | |||||
| if ((in_node->GetOutAllNodes().size() != kOutNodesNum) || !node->GetOutControlNodes().empty()) { | if ((in_node->GetOutAllNodes().size() != kOutNodesNum) || !node->GetOutControlNodes().empty()) { | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -89,45 +83,17 @@ Status EnterPass::OptimizeEnterWithOnlyOutData(NodePtr &node, NodePtr &in_node) | |||||
| } | } | ||||
| GE_CHECK_NOTNULL(in_node->GetOutDataAnchor(0)); | GE_CHECK_NOTNULL(in_node->GetOutDataAnchor(0)); | ||||
| GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->Unlink(node->GetInDataAnchor(0))) | |||||
| GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->Unlink(node->GetInDataAnchor(0))); | |||||
| const auto &out_data_anchor = node->GetOutDataAnchor(0); | const auto &out_data_anchor = node->GetOutDataAnchor(0); | ||||
| GE_CHECK_NOTNULL(out_data_anchor); | GE_CHECK_NOTNULL(out_data_anchor); | ||||
| for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { | for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { | ||||
| GE_CHK_STATUS_RET(out_data_anchor->Unlink(peer_in_data_anchor)) | |||||
| GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->LinkTo(peer_in_data_anchor)) | |||||
| GE_CHK_STATUS_RET(out_data_anchor->Unlink(peer_in_data_anchor)); | |||||
| GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->LinkTo(peer_in_data_anchor)); | |||||
| } | } | ||||
| GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(node->GetOwnerComputeGraph(), node)) | |||||
| GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(node->GetOwnerComputeGraph(), node)); | |||||
| AddNodeDeleted(node); | AddNodeDeleted(node); | ||||
| AddRePassNodesWithInOut(in_node); | AddRePassNodesWithInOut(in_node); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status EnterPass::UnlinkCtrlEdgeBeforeConst(NodePtr &node) { | |||||
| auto out_ctrl_nodes = node->GetOutControlNodes(); | |||||
| if (out_ctrl_nodes.empty()) { | |||||
| return SUCCESS; | |||||
| } | |||||
| auto out_ctrl_anchor = node->GetOutControlAnchor(); | |||||
| GE_CHECK_NOTNULL(out_ctrl_anchor); | |||||
| for (auto &out_ctrl_node : out_ctrl_nodes) { | |||||
| GE_CHECK_NOTNULL(out_ctrl_node); | |||||
| if ((out_ctrl_node->GetType() != CONSTANT) && (out_ctrl_node->GetType() != CONSTANTOP)) { | |||||
| continue; | |||||
| } | |||||
| auto in_ctrl_nodes = out_ctrl_node->GetInControlNodes(); | |||||
| if (in_ctrl_nodes.size() != kInCtrlNodesNum) { | |||||
| continue; | |||||
| } | |||||
| GE_CHK_STATUS_RET(out_ctrl_anchor->Unlink(out_ctrl_node->GetInControlAnchor())) | |||||
| auto out_nodes_of_const = out_ctrl_node->GetOutAllNodes(); | |||||
| for (auto &out_node_of_const : out_nodes_of_const) { | |||||
| if (!out_ctrl_anchor->IsLinkedWith(out_node_of_const->GetInControlAnchor())) { | |||||
| GE_CHK_STATUS_RET(out_ctrl_anchor->LinkTo(out_node_of_const->GetInControlAnchor())) | |||||
| } | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace ge | } // namespace ge | ||||
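The surviving `OptimizeEnter` path above bypasses an Enter node whose producer has a single output consumer and which has no control edges: every consumer of the Enter's output is rewired to the producer, then the Enter node is removed. A toy sketch of that rewiring with plain pointers instead of GE anchors:

```cpp
#include <cstdio>
#include <vector>

// Toy node: a single output fan-out list is enough to show the rewiring.
struct MiniNode {
  const char *name;
  std::vector<MiniNode *> consumers;  // nodes reading this node's output
};

// Rewire every consumer of `enter` to `producer`, isolating `enter`
// (roughly: Unlink + LinkTo + RemoveNodeWithoutRelink in the hunk above).
void BypassEnter(MiniNode &producer, MiniNode &enter) {
  for (MiniNode *consumer : enter.consumers) {
    producer.consumers.push_back(consumer);
  }
  enter.consumers.clear();
}

int main() {
  MiniNode consumer{"next_op", {}};
  MiniNode enter{"enter", {&consumer}};
  MiniNode producer{"in_node", {&enter}};
  producer.consumers.clear();     // drop the Enter node from the producer's fan-out
  BypassEnter(producer, enter);   // splice the Enter node's consumers in
  std::printf("%s now feeds %s\n", producer.name, producer.consumers[0]->name);
  return 0;
}
```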
| @@ -25,8 +25,7 @@ class EnterPass : public BaseNodePass { | |||||
| Status Run(NodePtr &node) override; | Status Run(NodePtr &node) override; | ||||
| private: | private: | ||||
| Status OptimizeEnterWithOnlyOutData(NodePtr &node, NodePtr &in_node); | |||||
| Status UnlinkCtrlEdgeBeforeConst(NodePtr &node); | |||||
| Status OptimizeEnter(NodePtr &node, NodePtr &in_node); | |||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| #endif // GE_GRAPH_PASSES_ENTER_PASS_H_ | #endif // GE_GRAPH_PASSES_ENTER_PASS_H_ | ||||
| @@ -173,7 +173,10 @@ Status FoldingPass::DealWithInNodes(NodePtr &node) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| auto in_node = in_node_anchor->GetOwnerNode(); | auto in_node = in_node_anchor->GetOwnerNode(); | ||||
| if ((in_node->GetType() == SWITCH) || (in_node->GetType() == REFSWITCH) || (in_node->GetType() == SWITCHN)) { | |||||
| if (in_node == nullptr) { | |||||
| continue; | |||||
| } | |||||
| if ((in_node->GetType() == SWITCH) || (in_node->GetType() == REFSWITCH)) { | |||||
| GELOGI("The in_node name is %s, and node type is %s.", in_node->GetName().c_str(), in_node->GetType().c_str()); | GELOGI("The in_node name is %s, and node type is %s.", in_node->GetName().c_str(), in_node->GetType().c_str()); | ||||
| auto ret = in_node_anchor->Unlink(in_data_anchor); | auto ret = in_node_anchor->Unlink(in_data_anchor); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| @@ -89,6 +89,16 @@ Status MergeToStreamMergePass::ReplaceMergeNode(const ComputeGraphPtr &graph, co | |||||
| GE_CHK_STATUS_RET(SetNextIteration(stream_merge, next_iteration_name), "Set next iteration failed"); | GE_CHK_STATUS_RET(SetNextIteration(stream_merge, next_iteration_name), "Set next iteration failed"); | ||||
| } | } | ||||
| if (merge_op_desc->HasAttr(ATTR_NAME_BATCH_LABEL)) { | |||||
| string batch_label; | |||||
| (void)AttrUtils::GetStr(merge_op_desc, ATTR_NAME_BATCH_LABEL, batch_label); | |||||
| if (!batch_label.empty()) { | |||||
| auto stream_merge_desc = stream_merge->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(stream_merge_desc); | |||||
| (void)AttrUtils::SetStr(stream_merge_desc, ATTR_NAME_BATCH_LABEL, batch_label); | |||||
| } | |||||
| } | |||||
| return AddActiveNodes(graph, stream_merge); | return AddActiveNodes(graph, stream_merge); | ||||
| } | } | ||||
| @@ -19,8 +19,6 @@ | |||||
| #include "common/ge/ge_util.h" | #include "common/ge/ge_util.h" | ||||
| #include "graph/common/omg_util.h" | #include "graph/common/omg_util.h" | ||||
| using std::string; | |||||
| namespace ge { | namespace ge { | ||||
| Status NextIterationPass::Run(ComputeGraphPtr graph) { | Status NextIterationPass::Run(ComputeGraphPtr graph) { | ||||
| GELOGD("NextIterationPass Enter"); | GELOGD("NextIterationPass Enter"); | ||||
| @@ -37,6 +35,10 @@ Status NextIterationPass::Run(ComputeGraphPtr graph) { | |||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| } | } | ||||
| } | } | ||||
| if (GroupWithNoBatch(graph) != SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "Group enter_nodes failed without batch_label attr."); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| if (FindWhileGroups() != SUCCESS) { | if (FindWhileGroups() != SUCCESS) { | ||||
| GELOGE(INTERNAL_ERROR, "Find while groups failed."); | GELOGE(INTERNAL_ERROR, "Find while groups failed."); | ||||
| @@ -71,22 +73,75 @@ Status NextIterationPass::GroupEnterNode(const NodePtr &enter_node) { | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| string batch_label; | |||||
| if (ge::AttrUtils::GetStr(enter_desc, ATTR_NAME_BATCH_LABEL, batch_label)) { | |||||
| frame_name += batch_label; | |||||
| std::string batch_label; | |||||
| (void)ge::AttrUtils::GetStr(enter_desc, ATTR_NAME_BATCH_LABEL, batch_label); | |||||
| if (batch_label.empty()) { | |||||
| auto frame_iter = frame_enter_map_.find(frame_name); | |||||
| if (frame_iter == frame_enter_map_.end()) { | |||||
| std::vector<NodePtr> enter_nodes; | |||||
| enter_nodes.emplace_back(enter_node); | |||||
| frame_enter_map_[frame_name] = enter_nodes; | |||||
| } else { | |||||
| frame_iter->second.emplace_back(enter_node); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | } | ||||
| auto iter = loop_group_map_.find(frame_name); | |||||
| if (iter == loop_group_map_.end()) { | |||||
| auto group_iter = loop_group_map_.find(frame_name); | |||||
| if (group_iter == loop_group_map_.end()) { | |||||
| LoopCondGroupPtr loop_group = MakeShared<LoopCondGroup>(); | LoopCondGroupPtr loop_group = MakeShared<LoopCondGroup>(); | ||||
| if (loop_group == nullptr) { | if (loop_group == nullptr) { | ||||
| GELOGE(FAILED, "MakeShared for LoopCondGroup failed."); | GELOGE(FAILED, "MakeShared for LoopCondGroup failed."); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| loop_group->enter_nodes.emplace_back(enter_node); | loop_group->enter_nodes.emplace_back(enter_node); | ||||
| loop_group_map_[frame_name] = loop_group; | |||||
| loop_group_map_[frame_name][batch_label] = loop_group; | |||||
| } else { | } else { | ||||
| iter->second->enter_nodes.emplace_back(enter_node); | |||||
| auto batch_iter = group_iter->second.find(batch_label); | |||||
| if (batch_iter == group_iter->second.end()) { | |||||
| LoopCondGroupPtr loop_group = MakeShared<LoopCondGroup>(); | |||||
| if (loop_group == nullptr) { | |||||
| GELOGE(FAILED, "MakeShared for LoopCondGroup failed."); | |||||
| return FAILED; | |||||
| } | |||||
| loop_group->enter_nodes.emplace_back(enter_node); | |||||
| group_iter->second[batch_label] = loop_group; | |||||
| } else { | |||||
| batch_iter->second->enter_nodes.emplace_back(enter_node); | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
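Taken together with GroupWithNoBatch below, the reworked grouping is two-level: Enter nodes carrying ATTR_NAME_BATCH_LABEL are bucketed into loop_group_map_[frame_name][batch_label], while unlabeled ones are parked in frame_enter_map_ and merged in afterwards. A simplified sketch of that bucketing, using plain strings instead of NodePtr/LoopCondGroupPtr:

```cpp
#include <string>
#include <unordered_map>
#include <vector>

struct EnterInfo {
  std::string name;
  std::string frame_name;
  std::string batch_label;  // empty when the node has no batch_label attr
};

struct LoopCondGroup {
  std::vector<std::string> enter_nodes;
};

// map<frame_name, vector<enter_node>> for unlabeled Enter nodes
std::unordered_map<std::string, std::vector<std::string>> frame_enter_map;
// map<frame_name, map<batch_label, LoopCondGroup>>
std::unordered_map<std::string, std::unordered_map<std::string, LoopCondGroup>> loop_group_map;

void GroupEnter(const EnterInfo &enter) {
  if (enter.batch_label.empty()) {
    frame_enter_map[enter.frame_name].push_back(enter.name);  // handled by GroupWithNoBatch
    return;
  }
  loop_group_map[enter.frame_name][enter.batch_label].enter_nodes.push_back(enter.name);
}

// Unlabeled Enter nodes join every labeled group of the same frame,
// or form a group under the empty label when no labeled group exists.
void GroupWithNoBatch() {
  for (const auto &item : frame_enter_map) {
    auto iter = loop_group_map.find(item.first);
    if (iter == loop_group_map.end()) {
      loop_group_map[item.first][""].enter_nodes = item.second;
      continue;
    }
    for (auto &batch_item : iter->second) {
      batch_item.second.enter_nodes.insert(batch_item.second.enter_nodes.end(),
                                           item.second.begin(), item.second.end());
    }
  }
}
```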
| /// | |||||
| /// @brief Group Enter nodes without batch_label attr | |||||
| /// @param [in] compute_graph | |||||
| /// @return Status | |||||
| /// | |||||
| Status NextIterationPass::GroupWithNoBatch(const ComputeGraphPtr &graph) { | |||||
| if (frame_enter_map_.empty()) { | |||||
| GELOGI("All enter nodes in graph %s has batch_label attr.", graph->GetName().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| for (const auto &item : frame_enter_map_) { | |||||
| const std::string &frame_name = item.first; | |||||
| auto iter = loop_group_map_.find(frame_name); | |||||
| if (iter == loop_group_map_.end()) { | |||||
| LoopCondGroupPtr loop_group = MakeShared<LoopCondGroup>(); | |||||
| if (loop_group == nullptr) { | |||||
| GELOGE(FAILED, "MakeShared for LoopCondGroup failed."); | |||||
| return FAILED; | |||||
| } | |||||
| loop_group->enter_nodes = item.second; | |||||
| loop_group_map_[frame_name][""] = loop_group; | |||||
| } else { | |||||
| for (auto &batch_item : iter->second) { | |||||
| for (const auto &enter_node : item.second) { | |||||
| batch_item.second->enter_nodes.emplace_back(enter_node); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -99,39 +154,55 @@ Status NextIterationPass::GroupEnterNode(const NodePtr &enter_node) { | |||||
| Status NextIterationPass::FindWhileGroups() { | Status NextIterationPass::FindWhileGroups() { | ||||
| for (const auto &loop_group_iter : loop_group_map_) { | for (const auto &loop_group_iter : loop_group_map_) { | ||||
| const std::string &frame_name = loop_group_iter.first; | const std::string &frame_name = loop_group_iter.first; | ||||
| for (const auto &enter_node : loop_group_iter.second->enter_nodes) { | |||||
| for (const auto &out_node : enter_node->GetOutAllNodes()) { | |||||
| const string &type = out_node->GetType(); | |||||
| if ((type != MERGE) && (type != REFMERGE)) { | |||||
| continue; | |||||
| } | |||||
| NodePtr next_node = nullptr; | |||||
| if (FindTargetNode(out_node, NEXTITERATION, true, next_node) != SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "Get NextIteration node failed, frame_name: %s", frame_name.c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| loop_group_iter.second->merge_next_pairs.emplace_back(std::make_pair(out_node, next_node)); | |||||
| NodePtr switch_node = nullptr; | |||||
| if (FindTargetNode(out_node, SWITCH, false, switch_node) != SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "Get Switch node failed, frame_name: %s.", frame_name.c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| if (switch_node == nullptr) { | |||||
| continue; | |||||
| } | |||||
| NodePtr loop_cond = nullptr; | |||||
| if (FindTargetNode(switch_node, LOOPCOND, true, loop_cond) != SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "Get LoopCond node failed, frame_name: %s.", frame_name.c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| if (loop_group_iter.second->loop_cond == nullptr) { | |||||
| loop_group_iter.second->loop_cond = loop_cond; | |||||
| } else if (loop_group_iter.second->loop_cond != loop_cond) { | |||||
| GELOGE(FAILED, "Multi LoopCond nodes exist, frame_name: %s.", frame_name.c_str()); | |||||
| return FAILED; | |||||
| for (const auto &batch_iter : loop_group_iter.second) { | |||||
| const std::string &batch_label = batch_iter.first; | |||||
| for (const auto &enter_node : batch_iter.second->enter_nodes) { | |||||
| for (const auto &out_node : enter_node->GetOutAllNodes()) { | |||||
| GELOGI("Find while_group for enter_node %s, frame_name:%s, batch_label:%s.", enter_node->GetName().c_str(), | |||||
| frame_name.c_str(), batch_label.c_str()); | |||||
| if ((out_node->GetType() != MERGE) && (out_node->GetType() != REFMERGE)) { | |||||
| continue; | |||||
| } | |||||
| std::string tmp_label; | |||||
| GE_CHECK_NOTNULL(out_node->GetOpDesc()); | |||||
| (void)AttrUtils::GetStr(out_node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, tmp_label); | |||||
| bool need_skip = !(batch_label.empty() || tmp_label.empty() || (batch_label == tmp_label)); | |||||
| if (need_skip) { | |||||
| continue; | |||||
| } | |||||
| NodePtr next_node = nullptr; | |||||
| if (FindTargetNode(out_node, NEXTITERATION, true, batch_label, next_node) != SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, | |||||
| "Get NextIteration node failed: inputs of Merge should be Enter/NextIteration, current_Merge=%s", | |||||
| out_node->GetName().c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| batch_iter.second->merge_next_pairs.emplace_back(std::make_pair(out_node, next_node)); | |||||
| NodePtr switch_node = nullptr; | |||||
| if (FindTargetNode(out_node, SWITCH, false, batch_label, switch_node) != SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "Get Switch node failed: output of Merge should be Switch, current_Merge=%s", | |||||
| out_node->GetName().c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| if (switch_node == nullptr) { | |||||
| continue; | |||||
| } | |||||
| NodePtr loop_cond = nullptr; | |||||
| if (FindTargetNode(switch_node, LOOPCOND, true, batch_label, loop_cond) != SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, | |||||
| "Get LoopCond node failed: pred input of Switch should be LoopCond, current_Switch=%s", | |||||
| switch_node->GetName().c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| if (batch_iter.second->loop_cond == nullptr) { | |||||
| batch_iter.second->loop_cond = loop_cond; | |||||
| } else if (batch_iter.second->loop_cond != loop_cond) { | |||||
| GELOGE(FAILED, "Multi LoopCond nodes exist."); | |||||
| return FAILED; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -152,17 +223,19 @@ bool NextIterationPass::VerifyWhileGroup() { | |||||
| GELOGE(INTERNAL_ERROR, "Verify while group failed, frame_name is empty."); | GELOGE(INTERNAL_ERROR, "Verify while group failed, frame_name is empty."); | ||||
| return false; | return false; | ||||
| } | } | ||||
| if (loop_group_iter.second->loop_cond == nullptr) { | |||||
| GELOGE(INTERNAL_ERROR, "Verify while group failed, LoopCond is null, frame_name: %s.", frame_name.c_str()); | |||||
| return false; | |||||
| } | |||||
| for (const auto &pair_iter : loop_group_iter.second->merge_next_pairs) { | |||||
| if ((pair_iter.first == nullptr) || (pair_iter.second == nullptr)) { | |||||
| GELOGE(INTERNAL_ERROR, "Verify while group failed, merge_node/next_node is null, frame_name: %s.", | |||||
| frame_name.c_str()); | |||||
| for (const auto &batch_iter : loop_group_iter.second) { | |||||
| if (batch_iter.second->loop_cond == nullptr) { | |||||
| GELOGE(INTERNAL_ERROR, "Verify while group failed, LoopCond is null, frame_name: %s.", frame_name.c_str()); | |||||
| return false; | return false; | ||||
| } | } | ||||
| for (const auto &pair_iter : batch_iter.second->merge_next_pairs) { | |||||
| if ((pair_iter.first == nullptr) || (pair_iter.second == nullptr)) { | |||||
| GELOGE(INTERNAL_ERROR, "Verify while group failed, merge_node/next_node is null, frame_name: %s.", | |||||
| frame_name.c_str()); | |||||
| return false; | |||||
| } | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -176,53 +249,56 @@ bool NextIterationPass::VerifyWhileGroup() { | |||||
| /// | /// | ||||
| Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { | Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { | ||||
| for (const auto &loop_cond_iter : loop_group_map_) { | for (const auto &loop_cond_iter : loop_group_map_) { | ||||
| const std::string &cond_name = loop_cond_iter.second->loop_cond->GetName(); | |||||
| GELOGI("Handle while group, LoopCond node: %s.", cond_name.c_str()); | |||||
| // Create Active node, Enter->Active->Merge, NextIteration->Active->Merge | |||||
| NodePtr enter_active = CreateActiveNode(graph, cond_name + "_Enter_" + STREAMACTIVE); | |||||
| NodePtr next_active = CreateActiveNode(graph, cond_name + "_Next_" + STREAMACTIVE); | |||||
| if ((enter_active == nullptr) || (next_active == nullptr)) { | |||||
| GELOGE(INTERNAL_ERROR, "Create active node failed, cond_name: %s.", cond_name.c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| for (const auto &enter_node : loop_cond_iter.second->enter_nodes) { | |||||
| // Enter --> Active | |||||
| if (GraphUtils::AddEdge(enter_node->GetOutControlAnchor(), enter_active->GetInControlAnchor()) != GRAPH_SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "Add control edge from %s to %s failed.", enter_node->GetName().c_str(), | |||||
| enter_active->GetName().c_str()); | |||||
| for (const auto &batch_iter : loop_cond_iter.second) { | |||||
| const std::string &cond_name = batch_iter.second->loop_cond->GetName(); | |||||
| GELOGI("Handle while group, LoopCond node: %s.", cond_name.c_str()); | |||||
| // Create Active node, Enter->Active->Merge, NextIteration->Active->Merge | |||||
| NodePtr enter_active = CreateActiveNode(graph, cond_name + "_Enter_" + STREAMACTIVE); | |||||
| NodePtr next_active = CreateActiveNode(graph, cond_name + "_Next_" + STREAMACTIVE); | |||||
| if ((enter_active == nullptr) || (next_active == nullptr)) { | |||||
| GELOGE(INTERNAL_ERROR, "Create active node failed, cond_name: %s.", cond_name.c_str()); | |||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| } | } | ||||
| } | |||||
| for (const auto &pair : loop_cond_iter.second->merge_next_pairs) { | |||||
| NodePtr merge_node = pair.first; | |||||
| NodePtr next_node = pair.second; | |||||
| // Active --> Merge | |||||
| if (GraphUtils::AddEdge(enter_active->GetOutControlAnchor(), merge_node->GetInControlAnchor()) != GRAPH_SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "Add control edge failed."); | |||||
| return INTERNAL_ERROR; | |||||
| for (const auto &enter_node : batch_iter.second->enter_nodes) { | |||||
| // Enter --> Active | |||||
| if (GraphUtils::AddEdge(enter_node->GetOutControlAnchor(), enter_active->GetInControlAnchor()) != | |||||
| GRAPH_SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "Add control edge failed."); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| } | } | ||||
| // NextIteration --> Active | |||||
| if (GraphUtils::AddEdge(next_node->GetOutControlAnchor(), next_active->GetInControlAnchor()) != GRAPH_SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "Add control edge failed."); | |||||
| return INTERNAL_ERROR; | |||||
| for (const auto &pair : batch_iter.second->merge_next_pairs) { | |||||
| NodePtr merge_node = pair.first; | |||||
| NodePtr next_node = pair.second; | |||||
| // Active --> Merge | |||||
| if (GraphUtils::AddEdge(enter_active->GetOutControlAnchor(), merge_node->GetInControlAnchor()) != | |||||
| GRAPH_SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "Add control edge failed."); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| // NextIteration --> Active | |||||
| if (GraphUtils::AddEdge(next_node->GetOutControlAnchor(), next_active->GetInControlAnchor()) != GRAPH_SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "Add control edge failed."); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| // break link between NextIteration and Merge | |||||
| if (BreakNextIteration(next_node, merge_node) != SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "Break NextIteration failed"); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| } | } | ||||
| // break link between NextIteration and Merge | |||||
| if (BreakNextIteration(next_node, merge_node) != SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "Break NextIteration failed"); | |||||
| if ((SetActiveLabelList(enter_active, {cond_name}) != SUCCESS) || | |||||
| (SetActiveLabelList(next_active, {cond_name}) != SUCCESS)) { | |||||
| GELOGE(INTERNAL_ERROR, "Set attr ACTIVE_LABEL_LIST failed."); | |||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| } | } | ||||
| } | } | ||||
| if ((SetActiveLabelList(enter_active, {cond_name}) != SUCCESS) || | |||||
| (SetActiveLabelList(next_active, {cond_name}) != SUCCESS)) { | |||||
| GELOGE(INTERNAL_ERROR, "Set attr ACTIVE_LABEL_LIST failed."); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
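With the extra batch_label level, HandleWhileGroup now creates one Enter/Next StreamActive pair per (frame_name, batch_label) group instead of per frame. The edges added (or broken) by the loop above, sketched as comments; only edges visible in this hunk are shown:

```cpp
// Per LoopCondGroup, with <cond> = loop_cond->GetName():
//
//   Enter_i          --ctrl-->  <cond>_Enter_StreamActive  --ctrl-->  Merge_j
//   NextIteration_j  --ctrl-->  <cond>_Next_StreamActive
//   NextIteration_j  --x-->     Merge_j        (link removed by BreakNextIteration)
//
// Both StreamActive nodes get ACTIVE_LABEL_LIST = { <cond> }.
```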
| @@ -289,11 +365,12 @@ Status NextIterationPass::BreakNextIteration(const NodePtr &next_node, NodePtr & | |||||
| /// @param [in] node | /// @param [in] node | ||||
| /// @param [in] target_type | /// @param [in] target_type | ||||
| /// @param [in] is_input | /// @param [in] is_input | ||||
| /// @param [in] batch_label | |||||
| /// @param [out] target_node | /// @param [out] target_node | ||||
| /// @return Status | /// @return Status | ||||
| /// | /// | ||||
| Status NextIterationPass::FindTargetNode(const NodePtr &node, const std::string &target_type, bool is_input, | Status NextIterationPass::FindTargetNode(const NodePtr &node, const std::string &target_type, bool is_input, | ||||
| NodePtr &target_node) { | |||||
| const std::string &batch_label, NodePtr &target_node) { | |||||
| if (node == nullptr) { | if (node == nullptr) { | ||||
| GELOGE(PARAM_INVALID, "node is null."); | GELOGE(PARAM_INVALID, "node is null."); | ||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| @@ -310,6 +387,12 @@ Status NextIterationPass::FindTargetNode(const NodePtr &node, const std::string | |||||
| } | } | ||||
| for (const auto &tmp_node : nodes) { | for (const auto &tmp_node : nodes) { | ||||
| std::string tmp_label; | |||||
| (void)AttrUtils::GetStr(tmp_node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, tmp_label); | |||||
| bool need_skip = !(batch_label.empty() || tmp_label.empty() || (batch_label == tmp_label)); | |||||
| if (need_skip) { | |||||
| continue; | |||||
| } | |||||
| const std::string type = tmp_node->GetType(); | const std::string type = tmp_node->GetType(); | ||||
| if ((target_type == LOOPCOND) && (type == target_type)) { | if ((target_type == LOOPCOND) && (type == target_type)) { | ||||
| target_node = tmp_node; | target_node = tmp_node; | ||||
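The skip condition added here (and reused in FindWhileGroups) treats two nodes as belonging to the same while group when either side has no batch label or both labels match. As a standalone predicate with a few illustrative cases, assuming nothing beyond the expression in the diff:

```cpp
#include <cassert>
#include <string>

// need_skip = !(a.empty() || b.empty() || a == b)
bool BatchLabelCompatible(const std::string &a, const std::string &b) {
  return a.empty() || b.empty() || a == b;
}

int main() {
  assert(BatchLabelCompatible("", "batch_1"));         // unlabeled node matches anything
  assert(BatchLabelCompatible("batch_1", "batch_1"));  // identical labels match
  assert(!BatchLabelCompatible("batch_1", "batch_2")); // different labels are skipped
  return 0;
}
```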
| @@ -332,6 +415,7 @@ Status NextIterationPass::FindTargetNode(const NodePtr &node, const std::string | |||||
| /// @return SUCCESS | /// @return SUCCESS | ||||
| /// | /// | ||||
| Status NextIterationPass::ClearStatus() { | Status NextIterationPass::ClearStatus() { | ||||
| frame_enter_map_.clear(); | |||||
| loop_group_map_.clear(); | loop_group_map_.clear(); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -46,6 +46,13 @@ class NextIterationPass : public GraphPass { | |||||
| /// | /// | ||||
| Status GroupEnterNode(const NodePtr &enter_node); | Status GroupEnterNode(const NodePtr &enter_node); | ||||
| /// | |||||
| /// @brief Group Enter nodes without batch_label attr | |||||
| /// @param [in] compute_graph | |||||
| /// @return Status | |||||
| /// | |||||
| Status GroupWithNoBatch(const ComputeGraphPtr &graph); | |||||
| /// | /// | ||||
| /// @brief Find while groups | /// @brief Find while groups | ||||
| /// @return Status | /// @return Status | ||||
| @@ -90,10 +97,13 @@ class NextIterationPass : public GraphPass { | |||||
| /// @param [out] target_node | /// @param [out] target_node | ||||
| /// @return Status | /// @return Status | ||||
| /// | /// | ||||
| Status FindTargetNode(const NodePtr &node, const std::string &target_type, bool is_input, NodePtr &target_node); | |||||
| Status FindTargetNode(const NodePtr &node, const std::string &target_type, bool is_input, | |||||
| const std::string &batch_label, NodePtr &target_node); | |||||
| // map<frame_name, LoopCondGroup> | |||||
| std::unordered_map<std::string, LoopCondGroupPtr> loop_group_map_; | |||||
| // map<frame_name, vector<enter_node>> | |||||
| std::unordered_map<std::string, std::vector<NodePtr>> frame_enter_map_; | |||||
| // map<frame_name, map<batch_label, LoopCondGroup>> | |||||
| std::unordered_map<std::string, std::unordered_map<std::string, LoopCondGroupPtr>> loop_group_map_; | |||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| #endif // GE_GRAPH_PASSES_NEXT_ITERATION_PASS_H_ | #endif // GE_GRAPH_PASSES_NEXT_ITERATION_PASS_H_ | ||||
| @@ -149,10 +149,10 @@ Status SubgraphPass::SubgraphOutputNode(const ComputeGraphPtr &graph, const Node | |||||
| // 5. While->NetOutput in known subgraph | // 5. While->NetOutput in known subgraph | ||||
| std::string op_type; | std::string op_type; | ||||
| bool insert_flag = NodeUtils::GetConstOpType(in_node, op_type) || | bool insert_flag = NodeUtils::GetConstOpType(in_node, op_type) || | ||||
| IsAtomicRequired(in_node, peer_out_anchor->GetIdx()) || IsOutputContinuesRequired(in_node) || | |||||
| ((in_node->GetType() == DATA) && (kWhileOpTypes.count(graph->GetParentNode()->GetType()) == 0)) || | |||||
| (!graph->GetGraphUnknownFlag() && NodeUtils::IsDynamicShape(node) && | |||||
| (kWhileOpTypes.count(in_node->GetType()) != 0)); | |||||
| IsAtomicRequired(in_node, peer_out_anchor->GetIdx()) || IsOutputContinuesRequired(in_node) || | |||||
| ((in_node->GetType() == DATA) && (kWhileOpTypes.count(graph->GetParentNode()->GetType()) == 0)) || | |||||
| (!graph->GetGraphUnknownFlag() && NodeUtils::IsDynamicShape(node) && | |||||
| (kWhileOpTypes.count(in_node->GetType()) != 0)); | |||||
| if (insert_flag) { | if (insert_flag) { | ||||
| GELOGD("Insert MemcpyAsync node between %s and %s.", in_node->GetName().c_str(), node->GetName().c_str()); | GELOGD("Insert MemcpyAsync node between %s and %s.", in_node->GetName().c_str(), node->GetName().c_str()); | ||||
| std::string name = node->GetName() + "_input_" + std::to_string(in_data_anchor->GetIdx()) + "_Memcpy"; | std::string name = node->GetName() + "_input_" + std::to_string(in_data_anchor->GetIdx()) + "_Memcpy"; | ||||
| @@ -70,8 +70,10 @@ std::string TransOpBreadthFusionPass::GetNodeId(const int anchor_index, const No | |||||
| trans_data_type = true; | trans_data_type = true; | ||||
| trans_format = true; | trans_format = true; | ||||
| trans_shape = true; | trans_shape = true; | ||||
| } else if (node->GetType() == RESHAPE) { | |||||
| } else if (node->GetType() == RESHAPE || node->GetType() == EXPANDDIMS || node->GetType() == SQUEEZE) { | |||||
| trans_shape = true; | trans_shape = true; | ||||
| } else if (node->GetType() == REFORMAT) { | |||||
| trans_format = true; | |||||
| } | } | ||||
| id << node->GetType() << '-' << anchor_index; | id << node->GetType() << '-' << anchor_index; | ||||
| @@ -1621,7 +1621,8 @@ Status GraphPrepare::CheckUserInput(const std::vector<GeTensor> &user_input) { | |||||
| for (size_t i = 0; i < desc.GetShape().GetDimNum(); ++i) { | for (size_t i = 0; i < desc.GetShape().GetDimNum(); ++i) { | ||||
| if (desc.GetShape().GetDim(i) < 0) { | if (desc.GetShape().GetDim(i) < 0) { | ||||
| std::string situation = "data dim[" + std::to_string(i) + "][" + std::to_string(desc.GetShape().GetDim(i)) + "]" ; | |||||
| std::string situation = "data dim[" + std::to_string(i) + "][" + | |||||
| std::to_string(desc.GetShape().GetDim(i)) + "]" ; | |||||
| std::string reason = "it need >= 0"; | std::string reason = "it need >= 0"; | ||||
| ErrorManager::GetInstance().ATCReportErrMessage("E19025", {"situation", "reason"}, {situation, reason}); | ErrorManager::GetInstance().ATCReportErrMessage("E19025", {"situation", "reason"}, {situation, reason}); | ||||
| GELOGE(GE_GRAPH_INIT_FAILED, "data dim %zu is not supported, need >= 0, real:%ld.", i, | GELOGE(GE_GRAPH_INIT_FAILED, "data dim %zu is not supported, need >= 0, real:%ld.", i, | ||||
| @@ -44,8 +44,6 @@ | |||||
| using std::set; | using std::set; | ||||
| using std::string; | using std::string; | ||||
| using std::vector; | using std::vector; | ||||
| using std::map; | |||||
| using std::queue; | |||||
| namespace ge { | namespace ge { | ||||
| namespace multibatch { | namespace multibatch { | ||||
| @@ -59,15 +57,10 @@ const int kDataInIndex = 0; | |||||
| const int kMergeDataOutIndex = 0; | const int kMergeDataOutIndex = 0; | ||||
| const int kStaticOutput = -1; | const int kStaticOutput = -1; | ||||
| const int kDivisionConst = 2; | const int kDivisionConst = 2; | ||||
| const int32_t kOneInDataNode = 1; | |||||
| const int32_t kFindNoMatch = 0; | |||||
| inline bool IsDataLikeType(const std::string &node_type) { return (node_type == DATA) || (node_type == AIPP); } | inline bool IsDataLikeType(const std::string &node_type) { return (node_type == DATA) || (node_type == AIPP); } | ||||
| inline bool IsEnterType(const string &node_type) { return (node_type == ENTER) || (node_type == REFENTER); } | |||||
| const set<string> unchange_types({CONSTANT, CONSTANTOP, ENTER, REFENTER}); | |||||
| inline bool IsGetNextType(const NodePtr &node) { | inline bool IsGetNextType(const NodePtr &node) { | ||||
| std::string original_type; | std::string original_type; | ||||
| GE_IF_BOOL_EXEC(GetOriginalType(node, original_type) != SUCCESS, | GE_IF_BOOL_EXEC(GetOriginalType(node, original_type) != SUCCESS, | ||||
| @@ -225,6 +218,12 @@ Status MultiBatchGraphCopyer::CopyGraph() { | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| ret = InsertIdentityAfterSwitchN(); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "Failed to insert identity nodes after switchn node."); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| GELOGI("Begin to remove useless nodes by prune pass after copy process"); | GELOGI("Begin to remove useless nodes by prune pass after copy process"); | ||||
| PrunePass prune_pass; | PrunePass prune_pass; | ||||
| ret = prune_pass.Run(graph_); | ret = prune_pass.Run(graph_); | ||||
| @@ -241,18 +240,6 @@ Status MultiBatchGraphCopyer::Init() { | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| ret = RelinkConstCtrlEdge(); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(FAILED, "Relink const's control edge failed."); | |||||
| return FAILED; | |||||
| } | |||||
| ret = ExtractUnchangedStructureOutofCycle(); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(FAILED, "Extract unchanged structure out of cycle failed."); | |||||
| return FAILED; | |||||
| } | |||||
| for (auto &node : graph_->GetAllNodes()) { | for (auto &node : graph_->GetAllNodes()) { | ||||
| origin_all_nodes_.emplace_back(node); | origin_all_nodes_.emplace_back(node); | ||||
| if (IsDataLikeType(node->GetType())) { | if (IsDataLikeType(node->GetType())) { | ||||
| @@ -265,281 +252,6 @@ Status MultiBatchGraphCopyer::Init() { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status MultiBatchGraphCopyer::RelinkConstCtrlEdge() { | |||||
| for (auto &node : graph_->GetAllNodes()) { | |||||
| GE_CHECK_NOTNULL(node); | |||||
| if ((node->GetType() == CONSTANT) || (node->GetType() == CONSTANTOP)) { | |||||
| if (node->GetOutDataNodes().empty()) { | |||||
| continue; | |||||
| } | |||||
| if (!node->GetInControlNodes().empty()) { | |||||
| auto in_ctrl_nodes = node->GetInControlNodes(); | |||||
| auto out_nodes = node->GetOutAllNodes(); | |||||
| bool has_merge = false; | |||||
| for (const auto &out_node : out_nodes) { | |||||
| GE_CHECK_NOTNULL(out_node); | |||||
| if (out_node->GetType() == MERGE || out_node->GetType() == REFMERGE) { | |||||
| has_merge = true; | |||||
| break; | |||||
| } | |||||
| } | |||||
| if (has_merge) { | |||||
| continue; | |||||
| } | |||||
| auto in_ctrl_anchor = node->GetInControlAnchor(); | |||||
| GE_CHECK_NOTNULL(in_ctrl_anchor); | |||||
| in_ctrl_anchor->UnlinkAll(); | |||||
| for (auto &in_ctrl_node : in_ctrl_nodes) { | |||||
| auto out_ctrl_anchor_of_in_ctrl_node = in_ctrl_node->GetOutControlAnchor(); | |||||
| GE_CHECK_NOTNULL(out_ctrl_anchor_of_in_ctrl_node); | |||||
| for (auto &out_node : out_nodes) { | |||||
| if (IsEnterType(out_node->GetType())) { | |||||
| continue; | |||||
| } | |||||
| if (!out_ctrl_anchor_of_in_ctrl_node->IsLinkedWith(out_node->GetInControlAnchor())) { | |||||
| GE_CHK_STATUS_RET(out_ctrl_anchor_of_in_ctrl_node->LinkTo(out_node->GetInControlAnchor())) | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| auto out_ctrl_anchor = node->GetOutControlAnchor(); | |||||
| if (out_ctrl_anchor != nullptr) { | |||||
| out_ctrl_anchor->UnlinkAll(); | |||||
| } | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status MultiBatchGraphCopyer::ExtractUnchangedStructureOutofCycle() { | |||||
| map<string, vector<NodePtr>> frame_enter; | |||||
| if (GetEnterNodesGroupByFrame(frame_enter) != SUCCESS) { | |||||
| GELOGE(FAILED, "Get enter nodes grouped by frame_name failed."); | |||||
| return FAILED; | |||||
| } | |||||
| queue<NodePtr> nodes_to_extract; | |||||
| if (GetNodeNeedExtract(frame_enter, nodes_to_extract) != SUCCESS) { | |||||
| GELOGE(FAILED, "Get nodes needed to extract failed."); | |||||
| return FAILED; | |||||
| } | |||||
| while (!nodes_to_extract.empty()) { | |||||
| auto node = nodes_to_extract.front(); | |||||
| nodes_to_extract.pop(); | |||||
| OpDescPtr enter_desc = nullptr; | |||||
| if (MoveInEntersInDataAnchorDown(node, enter_desc) != SUCCESS) { | |||||
| GELOGE(FAILED, "Move in enter nodes' in data anchors down of %s failed.", node->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| set<NodePtr> out_nodes; | |||||
| if (InsertEnterAfterNode(node, enter_desc, out_nodes) != SUCCESS) { | |||||
| GELOGE(FAILED, "Insert enter node after %s failed.", node->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| if (MoveCtrlEdgeToOutNodes(node, out_nodes) != SUCCESS) { | |||||
| GELOGE(FAILED, "Move %s's control edge to out nodes failed.", node->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| for (auto &out_node : out_nodes) { | |||||
| GE_CHECK_NOTNULL(out_node); | |||||
| if (AllInDataNodesUnchangeAndNoMergeOut(out_node)) { | |||||
| nodes_to_extract.push(out_node); | |||||
| } | |||||
| } | |||||
| } | |||||
| if (DeleteEnterWithoutDataOut() != SUCCESS) { | |||||
| GELOGE(FAILED, "Delete enter node without out data nodes failed."); | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status MultiBatchGraphCopyer::GetEnterNodesGroupByFrame(map<string, vector<NodePtr>> &frame_enter) { | |||||
| for (auto &node : graph_->GetAllNodes()) { | |||||
| GE_CHECK_NOTNULL(node); | |||||
| if (IsEnterType(node->GetType())) { | |||||
| if (!node->GetInControlNodes().empty() || !node->GetOutControlNodes().empty()) { | |||||
| continue; | |||||
| } | |||||
| auto op_desc = node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| string frame_name; | |||||
| if (!AttrUtils::GetStr(op_desc, ENTER_ATTR_FRAME_NAME, frame_name)) { | |||||
| GELOGE(FAILED, "Get attr frame_name of enter[%] failed.", node->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| frame_enter[frame_name].emplace_back(node); | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status MultiBatchGraphCopyer::GetNodeNeedExtract(const map<string, vector<NodePtr>> &frame_enter, | |||||
| queue<NodePtr> &nodes_to_extract) { | |||||
| for (const auto &one_group : frame_enter) { | |||||
| auto enters = one_group.second; | |||||
| for (const auto &enter : enters) { | |||||
| auto out_data_nodes = enter->GetOutDataNodes(); | |||||
| for (const auto &out_data_node : out_data_nodes) { | |||||
| GE_CHECK_NOTNULL(out_data_node); | |||||
| if (AllInDataNodesUnchangeAndNoMergeOut(out_data_node)) { | |||||
| nodes_to_extract.push(out_data_node); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| bool MultiBatchGraphCopyer::AllInDataNodesUnchangeAndNoMergeOut(const NodePtr &node) { | |||||
| auto out_data_nodes = node->GetOutDataNodes(); | |||||
| for (const auto &out_data_node : out_data_nodes) { | |||||
| if (out_data_node == nullptr) { | |||||
| return false; | |||||
| } | |||||
| if (out_data_node->GetType() == MERGE || out_data_node->GetType() == REFMERGE) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| auto in_data_nodes = node->GetInDataNodes(); | |||||
| if (in_data_nodes.size() == kOneInDataNode) { | |||||
| return true; | |||||
| } | |||||
| for (const auto &in_data_node : in_data_nodes) { | |||||
| if (in_data_node == nullptr) { | |||||
| return false; | |||||
| } | |||||
| if (unchange_types.count(in_data_node->GetType()) == kFindNoMatch) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| Status MultiBatchGraphCopyer::MoveInEntersInDataAnchorDown(NodePtr &node, OpDescPtr &enter_desc) { | |||||
| auto in_data_anchors = node->GetAllInDataAnchors(); | |||||
| for (auto &in_data_anchor : in_data_anchors) { | |||||
| auto peer_out_data_anchor = in_data_anchor->GetPeerOutAnchor(); | |||||
| GE_CHECK_NOTNULL(peer_out_data_anchor); | |||||
| auto peer_in_data_node = peer_out_data_anchor->GetOwnerNode(); | |||||
| if (IsEnterType(peer_in_data_node->GetType())) { | |||||
| GE_CHK_STATUS_RET(peer_out_data_anchor->Unlink(in_data_anchor)) | |||||
| GELOGD("Unlink data edge from %s to %s.", peer_in_data_node->GetName().c_str(), node->GetName().c_str()); | |||||
| auto enter_in_data_anchors = peer_in_data_node->GetAllInDataAnchors(); | |||||
| for (auto &enter_in_data_anchor : enter_in_data_anchors) { | |||||
| auto peer_out_data_anchor_of_enter = enter_in_data_anchor->GetPeerOutAnchor(); | |||||
| GE_CHECK_NOTNULL(peer_out_data_anchor_of_enter); | |||||
| if (peer_out_data_anchor_of_enter->IsLinkedWith(in_data_anchor)) { | |||||
| continue; | |||||
| } | |||||
| GE_CHK_STATUS_RET(peer_out_data_anchor_of_enter->LinkTo(in_data_anchor)) | |||||
| GELOGD("Relink data edge from %s to %s.", peer_out_data_anchor_of_enter->GetOwnerNode()->GetName().c_str(), | |||||
| node->GetName().c_str()); | |||||
| } | |||||
| enter_desc = peer_in_data_node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(enter_desc); | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status MultiBatchGraphCopyer::InsertEnterAfterNode(NodePtr &node, const OpDescPtr ©_desc, set<NodePtr> &out_nodes) { | |||||
| if (copy_desc == nullptr) { | |||||
| return SUCCESS; | |||||
| } | |||||
| map<OutDataAnchorPtr, vector<std::pair<InDataAnchorPtr, NodePtr>>> outanchors_inanchors_nodes; | |||||
| auto out_data_anchors = node->GetAllOutDataAnchors(); | |||||
| for (auto &out_data_anchor : out_data_anchors) { | |||||
| auto peer_in_data_anchors = out_data_anchor->GetPeerInDataAnchors(); | |||||
| for (auto peer_in_data_anchor : peer_in_data_anchors) { | |||||
| GE_CHECK_NOTNULL(peer_in_data_anchor); | |||||
| auto peer_in_data_node = peer_in_data_anchor->GetOwnerNode(); | |||||
| out_nodes.emplace(peer_in_data_node); | |||||
| outanchors_inanchors_nodes[out_data_anchor].emplace_back(std::make_pair(peer_in_data_anchor, peer_in_data_node)); | |||||
| } | |||||
| } | |||||
| int32_t i = 0; | |||||
| auto node_desc = node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(node_desc); | |||||
| // Insert one enter node after node's per out data anchor | |||||
| for (auto &outanchor_inanchors_nodes : outanchors_inanchors_nodes) { | |||||
| string name = node->GetName() + "_" + ENTER + "_" + std::to_string(i++); | |||||
| GELOGD("Create Enter op %s after %s.", name.c_str(), node->GetName().c_str()); | |||||
| auto enter_desc = AttrUtils::CopyOpDesc(copy_desc); | |||||
| enter_desc->SetName(name); | |||||
| GE_CHK_STATUS_RET( | |||||
| enter_desc->UpdateInputDesc("x", node_desc->GetOutputDesc(outanchor_inanchors_nodes.first->GetIdx()))) | |||||
| GE_CHK_STATUS_RET( | |||||
| enter_desc->UpdateOutputDesc("y", node_desc->GetOutputDesc(outanchor_inanchors_nodes.first->GetIdx()))) | |||||
| auto enter_node = graph_->AddNode(enter_desc); | |||||
| GE_CHECK_NOTNULL(enter_node); | |||||
| GE_CHK_STATUS_RET(outanchor_inanchors_nodes.first->LinkTo(enter_node->GetInDataAnchor(kDataInIndex))) | |||||
| GE_CHECK_NOTNULL(enter_node->GetOutDataAnchor(kDataInIndex)); | |||||
| for (auto &inanchor_node : outanchor_inanchors_nodes.second) { | |||||
| GE_CHK_STATUS_RET(outanchor_inanchors_nodes.first->Unlink(inanchor_node.first)) | |||||
| GE_CHK_STATUS_RET(enter_node->GetOutDataAnchor(kDataInIndex)->LinkTo(inanchor_node.first)) | |||||
| GELOGD("Unlink from %s to %s, link from %s to %s then to %s.", node->GetName().c_str(), | |||||
| inanchor_node.second->GetName().c_str(), node->GetName().c_str(), enter_node->GetName().c_str(), | |||||
| inanchor_node.second->GetName().c_str()); | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| // Move node's in control edges to out data nodes | |||||
| Status MultiBatchGraphCopyer::MoveCtrlEdgeToOutNodes(NodePtr &node, set<NodePtr> &out_nodes) { | |||||
| auto in_ctrl_anchor = node->GetInControlAnchor(); | |||||
| GE_CHECK_NOTNULL(in_ctrl_anchor); | |||||
| auto peer_out_ctrl_anchors = in_ctrl_anchor->GetPeerOutControlAnchors(); | |||||
| for (auto &peer_out_ctrl_anchor : peer_out_ctrl_anchors) { | |||||
| GE_CHK_STATUS_RET(peer_out_ctrl_anchor->Unlink(in_ctrl_anchor)) | |||||
| GELOGD("Unlink control edge from %s to %s.", peer_out_ctrl_anchor->GetOwnerNode()->GetName().c_str(), | |||||
| node->GetName().c_str()); | |||||
| for (auto &out_node : out_nodes) { | |||||
| auto in_ctrl_anchor_of_out_node = out_node->GetInControlAnchor(); | |||||
| GE_CHECK_NOTNULL(in_ctrl_anchor_of_out_node); | |||||
| if (!peer_out_ctrl_anchor->IsLinkedWith(in_ctrl_anchor_of_out_node)) { | |||||
| GE_CHK_STATUS_RET(peer_out_ctrl_anchor->LinkTo(in_ctrl_anchor_of_out_node)) | |||||
| GELOGD("Link control edge from %s to %s.", peer_out_ctrl_anchor->GetOwnerNode()->GetName().c_str(), | |||||
| out_node->GetName().c_str()); | |||||
| } | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status MultiBatchGraphCopyer::DeleteEnterWithoutDataOut() { | |||||
| for (auto &node : graph_->GetAllNodes()) { | |||||
| GE_CHECK_NOTNULL(node); | |||||
| if (IsEnterType(node->GetType())) { | |||||
| auto out_nodes = node->GetOutAllNodes(); | |||||
| if (out_nodes.empty()) { | |||||
| GELOGD("Delete enter node: %s which has no output.", node->GetName().c_str()); | |||||
| GE_CHK_STATUS_RET(GraphUtils::IsolateNode(node, {})) | |||||
| GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(graph_, node)) | |||||
| } | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| void MultiBatchGraphCopyer::LabelStatusForData(const NodePtr &data) { | void MultiBatchGraphCopyer::LabelStatusForData(const NodePtr &data) { | ||||
| auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape(); | auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape(); | ||||
| GELOGI("Label status for %s, shape_dims is %s.", data->GetName().c_str(), | GELOGI("Label status for %s, shape_dims is %s.", data->GetName().c_str(), | ||||
| @@ -585,9 +297,6 @@ Status MultiBatchGraphCopyer::LabelInBatchBranchStatus() { | |||||
| LabelStatusForGetNextSink(data); | LabelStatusForGetNextSink(data); | ||||
| } | } | ||||
| } | } | ||||
| map<string, vector<NodePtr>> frame_enters; | |||||
| InitStatus(frame_enters); | |||||
| bool changed = true; | bool changed = true; | ||||
| // If anyone of in node is kNodeInBatchBranch, it is also kNodeInBatchBranch | // If anyone of in node is kNodeInBatchBranch, it is also kNodeInBatchBranch | ||||
| while (changed) { | while (changed) { | ||||
| @@ -597,13 +306,12 @@ Status MultiBatchGraphCopyer::LabelInBatchBranchStatus() { | |||||
| if (iter != origin_nodes_status_.end()) { | if (iter != origin_nodes_status_.end()) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| for (auto &in_node : node->GetInDataNodes()) { | |||||
| if (origin_nodes_status_.find(in_node.get()) != origin_nodes_status_.end()) { | |||||
| if (origin_nodes_status_.find(node.get()) == origin_nodes_status_.end()) { | |||||
| origin_nodes_status_[node.get()] == kNodeInBatchBranch; | |||||
| ResetEnterStatus(frame_enters, node); | |||||
| changed = true; | |||||
| } | |||||
| for (auto &in_node : node->GetInAllNodes()) { | |||||
| bool is_in_batch = origin_nodes_status_.find(in_node.get()) != origin_nodes_status_.end() && | |||||
| origin_nodes_status_[in_node.get()] == kNodeInBatchBranch; | |||||
| if (is_in_batch) { | |||||
| origin_nodes_status_[node.get()] = kNodeInBatchBranch; | |||||
| changed = true; | |||||
| break; | break; | ||||
| } | } | ||||
| } | } | ||||
| @@ -612,45 +320,6 @@ Status MultiBatchGraphCopyer::LabelInBatchBranchStatus() { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| void MultiBatchGraphCopyer::InitStatus(map<string, vector<NodePtr>> &frame_enters) { | |||||
| for (const auto &node : origin_all_nodes_) { | |||||
| if (!IsEnterType(node->GetType())) { | |||||
| continue; | |||||
| } | |||||
| auto op_desc = node->GetOpDesc(); | |||||
| if (op_desc == nullptr) { | |||||
| continue; | |||||
| } | |||||
| string frame_name; | |||||
| if (AttrUtils::GetStr(op_desc, ENTER_ATTR_FRAME_NAME, frame_name)) { | |||||
| frame_enters[frame_name].emplace_back(node); | |||||
| } | |||||
| } | |||||
| for (const auto &data : origin_data_nodes_) { | |||||
| auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape(); | |||||
| if (!IsAllDimsPositive(data_shape.GetDims())) { | |||||
| origin_nodes_status_[data.get()] = kNodeInBatchBranch; | |||||
| } | |||||
| } | |||||
| } | |||||
| void MultiBatchGraphCopyer::ResetEnterStatus(map<string, vector<NodePtr>> &frame_enters, const NodePtr &node) { | |||||
| if (!IsEnterType(node->GetType())) { | |||||
| return; | |||||
| } | |||||
| for (const auto &frame_enter : frame_enters) { | |||||
| auto &enters = frame_enter.second; | |||||
| if (std::find(enters.begin(), enters.end(), node) != enters.end()) { | |||||
| for (const auto &enter : enters) { | |||||
| origin_nodes_status_[enter.get()] = kNodeInBatchBranch; | |||||
| } | |||||
| break; | |||||
| } | |||||
| } | |||||
| } | |||||
| Status MultiBatchGraphCopyer::LabelStatus() { | Status MultiBatchGraphCopyer::LabelStatus() { | ||||
| if (LabelInBatchBranchStatus() != SUCCESS) { | if (LabelInBatchBranchStatus() != SUCCESS) { | ||||
| GELOGE(PARAM_INVALID, "Failed to label no in batch branch"); | GELOGE(PARAM_INVALID, "Failed to label no in batch branch"); | ||||
| @@ -1691,6 +1360,52 @@ Status MultiBatchGraphCopyer::LinkToNodeOutBranch(const NodePtr &node) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status MultiBatchGraphCopyer::InsertIdentityAfterSwitchN() { | |||||
| for (auto &node : graph_->GetAllNodes()) { | |||||
| if (node->GetType() != SWITCHN) { | |||||
| continue; | |||||
| } | |||||
| auto switchn_desc = node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(switchn_desc); | |||||
| size_t i = 0; | |||||
| for (auto &out_data_anchor : node->GetAllOutDataAnchors()) { | |||||
| for (auto &in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { | |||||
| auto out_node = in_data_anchor->GetOwnerNode(); | |||||
| auto op_desc = out_node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| if ((out_node->GetType() == MERGE) && (op_desc->HasAttr(ATTR_INSERT_BY_MBATCH))) { | |||||
| GELOGD("No need to insert identity between %s and %s.", node->GetName().c_str(), out_node->GetName().c_str()); | |||||
| continue; | |||||
| } | |||||
| auto identity_desc = MakeShared<OpDesc>(node->GetName() + "_identity_" + std::to_string(i), IDENTITY); | |||||
| GE_CHECK_NOTNULL(identity_desc); | |||||
| string batch_label; | |||||
| if (AttrUtils::GetStr(op_desc, ATTR_NAME_BATCH_LABEL, batch_label)) { | |||||
| if (!AttrUtils::SetStr(identity_desc, ATTR_NAME_BATCH_LABEL, batch_label)) { | |||||
| GELOGE(FAILED, "Set attr ATTR_NAME_BATCH_LABEL failed, node:%s.", identity_desc->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| auto data_desc = switchn_desc->GetOutputDesc(i); | |||||
| i++; | |||||
| GE_CHK_STATUS_RET(identity_desc->AddInputDesc("x", data_desc)); | |||||
| GE_CHK_STATUS_RET(identity_desc->AddOutputDesc("y", data_desc)); | |||||
| auto identity_node = graph_->AddNode(identity_desc); | |||||
| GE_CHECK_NOTNULL(identity_node); | |||||
| GE_CHK_STATUS_RET(out_data_anchor->LinkTo(identity_node->GetInDataAnchor(0))); | |||||
| GE_CHECK_NOTNULL(identity_node->GetOutControlAnchor()); | |||||
| GE_CHK_STATUS_RET(identity_node->GetOutControlAnchor()->LinkTo(out_node->GetInControlAnchor())); | |||||
| } | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
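InsertIdentityAfterSwitchN hangs an Identity off each SwitchN data-output consumer (skipping Merge nodes the multi-batch pass inserted itself) and control-links the Identity to that consumer; the original data edge is left in place. Sketched as comments, under that reading of the loop above:

```cpp
// Before:  SwitchN:out_i  --data-->  consumer
// After:   SwitchN:out_i  --data-->  consumer                    (unchanged)
//          SwitchN:out_i  --data-->  SwitchN_identity_k
//          SwitchN_identity_k  --ctrl-->  consumer
//
// The Identity inherits ATTR_NAME_BATCH_LABEL from the consumer when present,
// so later passes can still associate it with the right batch branch.
```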
| Status ProcessMultiBatch(ComputeGraphPtr &graph) { | Status ProcessMultiBatch(ComputeGraphPtr &graph) { | ||||
| const char *multi_batch_with_case = std::getenv("MULTI_BATCH_WITH_CASE"); | const char *multi_batch_with_case = std::getenv("MULTI_BATCH_WITH_CASE"); | ||||
| if (multi_batch_with_case != nullptr) { | if (multi_batch_with_case != nullptr) { | ||||
| @@ -18,7 +18,6 @@ | |||||
| #include <map> | #include <map> | ||||
| #include <queue> | #include <queue> | ||||
| #include <vector> | #include <vector> | ||||
| #include <set> | |||||
| #include "external/ge/ge_api_error_codes.h" | #include "external/ge/ge_api_error_codes.h" | ||||
| @@ -65,26 +64,12 @@ class MultiBatchGraphCopyer { | |||||
| private: | private: | ||||
| Status Init(); | Status Init(); | ||||
| Status CheckArguments(); | Status CheckArguments(); | ||||
| Status RelinkConstCtrlEdge(); | |||||
| Status ExtractUnchangedStructureOutofCycle(); | |||||
| Status GetEnterNodesGroupByFrame(std::map<std::string, std::vector<NodePtr>> &frame_enter); | |||||
| Status GetNodeNeedExtract(const std::map<std::string, std::vector<NodePtr>> &frame_enter, | |||||
| std::queue<NodePtr> &nodes_to_extract); | |||||
| bool AllInDataNodesUnchangeAndNoMergeOut(const NodePtr &node); | |||||
| Status MoveInEntersInDataAnchorDown(NodePtr &node, OpDescPtr &enter_desc); | |||||
| Status InsertEnterAfterNode(NodePtr &node, const OpDescPtr &enter_desc, std::set<NodePtr> &out_nodes); | |||||
| Status MoveCtrlEdgeToOutNodes(NodePtr &node, std::set<NodePtr> &out_nodes); | |||||
| Status DeleteEnterWithoutDataOut(); | |||||
| // label status for origin_all_nodes_ | // label status for origin_all_nodes_ | ||||
| Status LabelStatus(); | Status LabelStatus(); | ||||
| Status LabelInBatchBranchStatus(); | Status LabelInBatchBranchStatus(); | ||||
| void LabelStatusForData(const NodePtr &data); | void LabelStatusForData(const NodePtr &data); | ||||
| void LabelStatusForGetNextSink(const NodePtr &data); | void LabelStatusForGetNextSink(const NodePtr &data); | ||||
| void InitStatus(std::map<std::string, std::vector<NodePtr>> &frame_enters); | |||||
| void ResetEnterStatus(std::map<std::string, std::vector<NodePtr>> &frame_enters, const NodePtr &node); | |||||
| // add nodes functions | // add nodes functions | ||||
| Status CreateNewNodes(); | Status CreateNewNodes(); | ||||
| @@ -96,6 +81,7 @@ class MultiBatchGraphCopyer { | |||||
| Status InsertSwitchNForData(const NodePtr &node, const size_t &out_anchor_index, const size_t &peer_in_anchor_index, | Status InsertSwitchNForData(const NodePtr &node, const size_t &out_anchor_index, const size_t &peer_in_anchor_index, | ||||
| std::vector<std::pair<Node *, NodePtr>> &dynamic_out_to_switchn); | std::vector<std::pair<Node *, NodePtr>> &dynamic_out_to_switchn); | ||||
| Status InsertIdentityAfterSwitchN(); | |||||
| Status UpdateMaxShapeToData(const NodePtr &node, size_t out_anchor_index); | Status UpdateMaxShapeToData(const NodePtr &node, size_t out_anchor_index); | ||||
| Status UpdateShapeOfShapeNode(const NodePtr &node, size_t out_anchor_index); | Status UpdateShapeOfShapeNode(const NodePtr &node, size_t out_anchor_index); | ||||
| @@ -180,8 +180,12 @@ Status SsdPriorboxKernel::SetVariance(const vector<float> &variance, const int d | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status SsdPriorboxKernel::GetNumPriorAndDimSize(uint32_t aspect_ratios_size, uint32_t min_sizes_size, uint32_t max_sizes_size, | |||||
| int layer_width, int layer_height, int &num_priors, | |||||
| Status SsdPriorboxKernel::GetNumPriorAndDimSize(uint32_t aspect_ratios_size, | |||||
| uint32_t min_sizes_size, | |||||
| uint32_t max_sizes_size, | |||||
| int layer_width, | |||||
| int layer_height, | |||||
| int &num_priors, | |||||
| int &dim_size) const { | int &dim_size) const { | ||||
| if (ge::CheckUint32MulOverflow(min_sizes_size, aspect_ratios_size) != SUCCESS) { | if (ge::CheckUint32MulOverflow(min_sizes_size, aspect_ratios_size) != SUCCESS) { | ||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| @@ -379,11 +379,13 @@ Status HybridModelAsyncExecutor::Execute(const std::vector<DataBuffer> &inputs, | |||||
| } | } | ||||
| if (output_real_size > 0) { | if (output_real_size > 0) { | ||||
| if (outputs[i].length < static_cast<uint64_t>(output_real_size)) { | if (outputs[i].length < static_cast<uint64_t>(output_real_size)) { | ||||
| GELOGE(FAILED, "output idx[%zu], the memory size of output[%lu] given by user should be greater than or equal to the real size of output[%ld]", | |||||
| GELOGE(FAILED, "output idx[%zu], the memory size of output[%lu] given by " | |||||
| "user should be greater than or equal to the real size of output[%ld]", | |||||
| i, outputs[i].length, output_real_size); | i, outputs[i].length, output_real_size); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| GE_CHK_RT_RET(rtMemcpy(outputs[i].data, outputs[i].length, args.outputs[i].GetData(), output_real_size, RT_MEMCPY_DEVICE_TO_DEVICE)); | |||||
| GE_CHK_RT_RET(rtMemcpy(outputs[i].data, outputs[i].length, | |||||
| args.outputs[i].GetData(), output_real_size, RT_MEMCPY_DEVICE_TO_DEVICE)); | |||||
| } | } | ||||
| outputs[i].length = output_real_size; | outputs[i].length = output_real_size; | ||||
| } | } | ||||
| @@ -62,7 +62,8 @@ Status ShapeInferenceEngine::InferShape(NodeState &node_state) { | |||||
| { | { | ||||
| std::lock_guard<std::mutex> lk(mu_); | std::lock_guard<std::mutex> lk(mu_); | ||||
| RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[InferShapeAndType] Start"); | RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[InferShapeAndType] Start"); | ||||
| GE_CHK_STATUS_RET(ShapeRefiner::InferShapeAndTypeForRunning(node_item.node, true), "Invoke InferShapeAndType failed."); | |||||
| GE_CHK_STATUS_RET(ShapeRefiner::InferShapeAndTypeForRunning(node_item.node, true), | |||||
| "Invoke InferShapeAndType failed."); | |||||
| RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[InferShapeAndType] End"); | RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[InferShapeAndType] End"); | ||||
| } | } | ||||
| // Check again to make sure shape is valid after shape inference | // Check again to make sure shape is valid after shape inference | ||||
| @@ -176,7 +176,8 @@ Status HybridModel::GetInputOutputDescInfo(vector<InputOutputDescInfo> &input_de | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| void HybridModel::SetInputDimsAndShapeRangesInfo(const vector<int64_t> &model_input_dims, std::vector<std::pair<int64_t,int64_t>> &shape_ranges, | |||||
| void HybridModel::SetInputDimsAndShapeRangesInfo(const vector<int64_t> &model_input_dims, | |||||
| std::vector<std::pair<int64_t, int64_t>> &shape_ranges, | |||||
| InputOutputDescInfo &input) { | InputOutputDescInfo &input) { | ||||
| for (auto model_input_dim : model_input_dims) { | for (auto model_input_dim : model_input_dims) { | ||||
| input.shape_info.dims.push_back(model_input_dim); | input.shape_info.dims.push_back(model_input_dim); | ||||
| @@ -245,7 +246,8 @@ Status HybridModel::GetInputDescInfo(vector<InputOutputDescInfo> &input_desc, st | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| void HybridModel::CreateOutput(ConstGeTensorDescPtr &output_desc, InputOutputDescInfo &output_desc_info, uint32_t &format_result) { | |||||
| void HybridModel::CreateOutput(ConstGeTensorDescPtr &output_desc, | |||||
| InputOutputDescInfo &output_desc_info, uint32_t &format_result) { | |||||
| GE_IF_BOOL_EXEC(output_desc == nullptr, GELOGE(FAILED, "output desc ptr is nullptr"); return ); | GE_IF_BOOL_EXEC(output_desc == nullptr, GELOGE(FAILED, "output desc ptr is nullptr"); return ); | ||||
| Format format = output_desc->GetFormat(); | Format format = output_desc->GetFormat(); | ||||
| GeShape shape = output_desc->GetShape(); | GeShape shape = output_desc->GetShape(); | ||||
| @@ -283,7 +285,8 @@ void HybridModel::CreateOutput(ConstGeTensorDescPtr &output_desc, InputOutputDes | |||||
| Status HybridModel::GetOutputDescInfo(vector<InputOutputDescInfo> &output_desc, std::vector<uint32_t> &formats) { | Status HybridModel::GetOutputDescInfo(vector<InputOutputDescInfo> &output_desc, std::vector<uint32_t> &formats) { | ||||
| std::vector<ConstGeTensorDescPtr> output_desc_list; | std::vector<ConstGeTensorDescPtr> output_desc_list; | ||||
| GE_CHK_STATUS_RET(root_graph_item_->GetOutputDescList(output_desc_list), "get output desc info failed"); // output_desc_list contains vaild input desc | |||||
| // output_desc_list contains valid input desc | |||||
| GE_CHK_STATUS_RET(root_graph_item_->GetOutputDescList(output_desc_list), "get output desc info failed"); | |||||
| vector<std::string> out_node_names; | vector<std::string> out_node_names; | ||||
| (void)ge::AttrUtils::GetListStr(ge_root_model_->GetRootGraph(), ATTR_MODEL_OUT_NODES_NAME, out_node_names); | (void)ge::AttrUtils::GetListStr(ge_root_model_->GetRootGraph(), ATTR_MODEL_OUT_NODES_NAME, out_node_names); | ||||
| @@ -293,7 +296,8 @@ Status HybridModel::GetOutputDescInfo(vector<InputOutputDescInfo> &output_desc, | |||||
| GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
| auto out_size = static_cast<uint32_t>(op_desc->GetInputsSize()); | auto out_size = static_cast<uint32_t>(op_desc->GetInputsSize()); | ||||
| GE_CHK_BOOL_RET_STATUS(out_size == output_desc_list.size(), FAILED, "output size[%u] not match output_desc_list size[%zu]", out_size, output_desc_list.size()); | |||||
| GE_CHK_BOOL_RET_STATUS(out_size == output_desc_list.size(), | |||||
| FAILED, "output size[%u] not match output_desc_list size[%zu]", out_size, output_desc_list.size()); | |||||
| for (uint32_t index = 0; index < out_size; ++index) { | for (uint32_t index = 0; index < out_size; ++index) { | ||||
| string output_name; | string output_name; | ||||
| @@ -301,9 +305,11 @@ Status HybridModel::GetOutputDescInfo(vector<InputOutputDescInfo> &output_desc, | |||||
| std::vector<int64_t> src_index = op_desc->GetSrcIndex(); | std::vector<int64_t> src_index = op_desc->GetSrcIndex(); | ||||
| if (out_size == out_node_names.size()) { | if (out_size == out_node_names.size()) { | ||||
| bool contains_colon = out_node_names[index].find(":") != std::string::npos; | bool contains_colon = out_node_names[index].find(":") != std::string::npos; | ||||
| output_name = contains_colon ? out_node_names[index] : out_node_names[index] + ":" + std::to_string(src_index[index]); | |||||
| output_name = contains_colon ? out_node_names[index] : out_node_names[index] + | |||||
| ":" + std::to_string(src_index[index]); | |||||
| } else { | } else { | ||||
| output_name = std::string("output_") + std::to_string(index) + "_" + src_name[index] + "_" + std::to_string(src_index[index]); | |||||
| output_name = std::string("output_") + std::to_string(index) + "_" + src_name[index] + | |||||
| "_" + std::to_string(src_index[index]); | |||||
| } | } | ||||
| InputOutputDescInfo output_desc_info; | InputOutputDescInfo output_desc_info; | ||||
| @@ -104,7 +104,8 @@ class HybridModel { | |||||
| void SetModelDescVersion(bool is_new_model_desc) { is_new_model_desc_ = is_new_model_desc; } | void SetModelDescVersion(bool is_new_model_desc) { is_new_model_desc_ = is_new_model_desc; } | ||||
| void SetInputDimsAndShapeRangesInfo(const vector<int64_t> &model_input_dims, std::vector<std::pair<int64_t, int64_t>> &shape_ranges, | |||||
| void SetInputDimsAndShapeRangesInfo(const vector<int64_t> &model_input_dims, | |||||
| std::vector<std::pair<int64_t, int64_t>> &shape_ranges, | |||||
| InputOutputDescInfo &input); | InputOutputDescInfo &input); | ||||
| private: | private: | ||||
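The signature above now wraps, but the parameters are unchanged: a dims vector plus a vector of int64_t pairs for the shape ranges. A purely illustrative way to populate such arguments; the pairing convention, presumably (min, max) per dimension, is an assumption and not taken from this diff:

```cpp
#include <cstdint>
#include <utility>
#include <vector>

int main() {
  // Illustrative values only: a 4-D input whose batch dimension is dynamic
  // (-1 in the dims vector), with each dimension's assumed (min, max) span
  // recorded in shape_ranges.
  std::vector<int64_t> model_input_dims = {-1, 3, 224, 224};
  std::vector<std::pair<int64_t, int64_t>> shape_ranges = {
      {1, 16}, {3, 3}, {224, 224}, {224, 224}};
  (void)model_input_dims;
  (void)shape_ranges;
  return 0;
}
```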
| @@ -36,7 +36,6 @@ | |||||
| #include "model/ge_model.h" | #include "model/ge_model.h" | ||||
| #include "graph/shape_refiner.h" | #include "graph/shape_refiner.h" | ||||
| #include "graph/opsproto_manager.h" | #include "graph/opsproto_manager.h" | ||||
| #include "graph/utils/type_utils.h" | |||||
| using std::string; | using std::string; | ||||
| using namespace std; | using namespace std; | ||||
| @@ -50,11 +49,8 @@ const std::string IR_OPTION_LOG_LEVEL_DEFAULT = "default"; | |||||
| const std::string IR_OPTION_BUFFER_OPTIMIZE_DEFAULT = "l2_optimize"; | const std::string IR_OPTION_BUFFER_OPTIMIZE_DEFAULT = "l2_optimize"; | ||||
| const std::string IR_OPTION_DISABLE_REUSE_MEMORY_DEFAULT = "0"; | const std::string IR_OPTION_DISABLE_REUSE_MEMORY_DEFAULT = "0"; | ||||
| const std::string IR_OPTION_ENABLE_COMPRESS_WEIGHT_DEFAULT = "false"; | const std::string IR_OPTION_ENABLE_COMPRESS_WEIGHT_DEFAULT = "false"; | ||||
| const std::string kInputShape = "input_shape"; | const std::string kInputShape = "input_shape"; | ||||
| const std::string kInputFormat = "input_format"; | const std::string kInputFormat = "input_format"; | ||||
| const std::string kReUseMemEnable = "1"; | |||||
| const std::string kReUseMemDisEnable = "0"; | |||||
| } // namespace | } // namespace | ||||
| static graphStatus CheckGlobalOptions(std::map<std::string, std::string> &global_options) { | static graphStatus CheckGlobalOptions(std::map<std::string, std::string> &global_options) { | ||||
| @@ -232,12 +228,12 @@ class Impl { | |||||
| graphStatus CheckOptions(const std::map<std::string, std::string> &options); | graphStatus CheckOptions(const std::map<std::string, std::string> &options); | ||||
| graphStatus CreateInputsForIRBuild(const ge::Graph &graph, vector<ge::GeTensor> &inputs); | graphStatus CreateInputsForIRBuild(const ge::Graph &graph, vector<ge::GeTensor> &inputs); | ||||
| graphStatus GetDefaultInputShape(const Graph &graph, string &default_shape); | graphStatus GetDefaultInputShape(const Graph &graph, string &default_shape); | ||||
| graphStatus UpdateDataOpAttr(const Graph &graph); | |||||
| graphStatus Init(const Graph &graph, const std::map<std::string, std::string> &options); | graphStatus Init(const Graph &graph, const std::map<std::string, std::string> &options); | ||||
| graphStatus BuildModel(const Graph &graph, const std::map<std::string, std::string> &options, | graphStatus BuildModel(const Graph &graph, const std::map<std::string, std::string> &options, | ||||
| ModelBufferData &ge_models); | ModelBufferData &ge_models); | ||||
| graphStatus InitDomiOmgContext(const string &input_shape, const string &input_format, const string &net_format, | graphStatus InitDomiOmgContext(const string &input_shape, const string &input_format, const string &net_format, | ||||
| bool is_dynamic_input); | bool is_dynamic_input); | ||||
| graphStatus UpdateDataOpAttr(const Graph &graph); | |||||
| void SetRtSocVersion(); | void SetRtSocVersion(); | ||||
| void UpdateThreadContext(); | void UpdateThreadContext(); | ||||
| void LoadOpsProto(); | void LoadOpsProto(); | ||||
| @@ -429,6 +425,7 @@ graphStatus Impl::Init(const Graph &graph, const std::map<std::string, std::stri | |||||
| // for IR builder.Only support om mode, so here fixed; | // for IR builder.Only support om mode, so here fixed; | ||||
| options_.insert(std::pair<string, string>(string(IR_OPTION_MODE), to_string(0))); | options_.insert(std::pair<string, string>(string(IR_OPTION_MODE), to_string(0))); | ||||
| options_.insert(std::pair<string, string>(string(IR_OPTION_TARGET), "mini")); | |||||
| options_.insert(std::pair<string, string>(string(ge::RUN_FLAG), to_string(0))); | options_.insert(std::pair<string, string>(string(ge::RUN_FLAG), to_string(0))); | ||||
| options_.insert(std::pair<string, string>(string(ge::TRAIN_FLAG), to_string(0))); | options_.insert(std::pair<string, string>(string(ge::TRAIN_FLAG), to_string(0))); | ||||
| options_.insert(std::pair<string, string>(string(ge::SAVE_ORIGINAL_MODEL), to_string(0))); | options_.insert(std::pair<string, string>(string(ge::SAVE_ORIGINAL_MODEL), to_string(0))); | ||||
| @@ -468,52 +465,39 @@ void Impl::UpdateThreadContext() { | |||||
| graphStatus Impl::CreateInputsForIRBuild(const ge::Graph &graph, vector<ge::GeTensor> &inputs) { | graphStatus Impl::CreateInputsForIRBuild(const ge::Graph &graph, vector<ge::GeTensor> &inputs) { | ||||
| auto compute_graph = ge::GraphUtils::GetComputeGraph(graph); | auto compute_graph = ge::GraphUtils::GetComputeGraph(graph); | ||||
| GE_CHECK_NOTNULL(compute_graph); | GE_CHECK_NOTNULL(compute_graph); | ||||
| int64_t index = 0; | |||||
| for (ge::NodePtr &input_node : compute_graph->GetDirectNode()) { | for (ge::NodePtr &input_node : compute_graph->GetDirectNode()) { | ||||
| GE_CHECK_NOTNULL(input_node); | GE_CHECK_NOTNULL(input_node); | ||||
| ge::OpDescPtr op = input_node->GetOpDesc(); | ge::OpDescPtr op = input_node->GetOpDesc(); | ||||
| GE_CHECK_NOTNULL(op); | GE_CHECK_NOTNULL(op); | ||||
| if (op->GetType() == DATA) { | if (op->GetType() == DATA) { | ||||
| (void)AttrUtils::SetInt(op, ATTR_NAME_INDEX, index++); | |||||
| GELOGD("Data op inputDesc size: %zu", op->GetAllInputsDesc().size()); | GELOGD("Data op inputDesc size: %zu", op->GetAllInputsDesc().size()); | ||||
| auto tensor = op->MutableInputDesc(0); | |||||
| GE_CHECK_NOTNULL(tensor); | |||||
| ge::GeTensorDesc tensor = op->GetInputDesc(0); | |||||
| string data_op_name = op->GetName(); | string data_op_name = op->GetName(); | ||||
| GELOGD("Data op name: %s", data_op_name.c_str()); | GELOGD("Data op name: %s", data_op_name.c_str()); | ||||
| ge::GeShape data_shape; | ge::GeShape data_shape; | ||||
| auto iter = omg_context_.input_dims.find(data_op_name); | auto iter = omg_context_.input_dims.find(data_op_name); | ||||
| if (iter != omg_context_.input_dims.end()) { | if (iter != omg_context_.input_dims.end()) { | ||||
| data_shape = ge::GeShape(iter->second); | data_shape = ge::GeShape(iter->second); | ||||
| GELOGD("Data op get shape from Context and update [%s] shape info", data_op_name.c_str()); | |||||
| GELOGD("Data op get shape from Context."); | |||||
| } else { | } else { | ||||
| data_shape = tensor->GetShape(); | |||||
| data_shape = tensor.GetShape(); | |||||
| GELOGD("Data op get shape from InputDesc in ge ir graph."); | GELOGD("Data op get shape from InputDesc in ge ir graph."); | ||||
| } | } | ||||
| // If user point input format, do work for all data ops; else do according to tensor_desc | // If user point input format, do work for all data ops; else do according to tensor_desc | ||||
| auto data_format = omg_context_.format != domi::DOMI_TENSOR_ND ? | auto data_format = omg_context_.format != domi::DOMI_TENSOR_ND ? | ||||
| ge::TypeUtils::DomiFormatToFormat(omg_context_.format) : tensor->GetFormat(); | |||||
| ge::DataType data_type = tensor->GetDataType(); | |||||
| ge::TypeUtils::DomiFormatToFormat(omg_context_.format) : tensor.GetFormat(); | |||||
| ge::DataType data_type = tensor.GetDataType(); | |||||
| string data_type_str = ge::TypeUtils::DataTypeToSerialString(data_type); | string data_type_str = ge::TypeUtils::DataTypeToSerialString(data_type); | ||||
| GELOGD("Data op get data type:%s from InputDesc in ge ir graph.", data_type_str.c_str()); | GELOGD("Data op get data type:%s from InputDesc in ge ir graph.", data_type_str.c_str()); | ||||
| ge::GeTensor inputTensor; | ge::GeTensor inputTensor; | ||||
| ge::GeTensorDesc desc(data_shape, ge::Format(data_format), data_type); | ge::GeTensorDesc desc(data_shape, ge::Format(data_format), data_type); | ||||
| inputTensor.SetTensorDesc(desc); | inputTensor.SetTensorDesc(desc); | ||||
| int64_t index = 0; | |||||
| if (AttrUtils::GetInt(op, ATTR_NAME_INDEX, index)) { | |||||
| AttrUtils::SetInt(desc, ATTR_NAME_INDEX, index); | |||||
| } else { | |||||
| GELOGE(GRAPH_PARAM_INVALID, "Get attr name idx failed!"); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| inputs.emplace_back(inputTensor); | |||||
| inputs.push_back(inputTensor); | |||||
| } | } | ||||
| } | } | ||||
| std::sort(inputs.begin(), inputs.end(), [](ge::GeTensor a, ge::GeTensor b) { | |||||
| int64_t data_idx_a = 0; | |||||
| int64_t data_idx_b = 0; | |||||
| AttrUtils::GetInt(a.MutableTensorDesc(), ATTR_NAME_INDEX, data_idx_a); | |||||
| AttrUtils::GetInt(b.MutableTensorDesc(), ATTR_NAME_INDEX, data_idx_b); | |||||
| return data_idx_a <= data_idx_b; | |||||
| }); | |||||
| GELOGD("CreateInputsForIRBuild, inputs size: %zu", inputs.size()); | GELOGD("CreateInputsForIRBuild, inputs size: %zu", inputs.size()); | ||||
| return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
| } | } | ||||
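The hunk above changes how the input order is established: instead of reading ATTR_NAME_INDEX back from each Data op and sorting the collected tensors afterwards, the running index is now stamped onto the Data op during graph traversal and the tensors are appended in that same order. A simplified sketch of the new flow, using stand-in types instead of the real GE interfaces:

```cpp
#include <cstdint>
#include <string>
#include <vector>

// Stand-in types; the real code walks ge::NodePtr / ge::OpDescPtr and builds ge::GeTensor.
struct FakeOp {
  std::string type;
  int64_t index_attr;   // stands in for ATTR_NAME_INDEX on the Data op
};
struct FakeTensor {
  int64_t data_index;
};

// Every Data op met while walking the graph gets the next running index and
// its tensor is appended immediately, so the result is already in traversal
// order and no post-hoc sort on ATTR_NAME_INDEX is needed (the old code read
// the attribute back and sorted the collected tensors).
static void CollectInputs(std::vector<FakeOp> &nodes, std::vector<FakeTensor> &inputs) {
  int64_t index = 0;
  for (auto &op : nodes) {
    if (op.type != "Data") {
      continue;
    }
    op.index_attr = index;
    inputs.push_back(FakeTensor{index});
    ++index;
  }
}
```

A side effect worth noting: the removed comparator returned data_idx_a <= data_idx_b, a non-strict ordering that std::sort does not actually permit, so dropping the sort also removes that hazard.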
| @@ -606,7 +590,7 @@ graphStatus aclgrphSaveModel(const string &output_file, const ModelBufferData &m | |||||
| GELOGE(GRAPH_PARAM_INVALID, "input model is illegal"); | GELOGE(GRAPH_PARAM_INVALID, "input model is illegal"); | ||||
| return GRAPH_PARAM_INVALID; | return GRAPH_PARAM_INVALID; | ||||
| } | } | ||||
| return FileSaver::SaveToFile((output_file + ".om"), reinterpret_cast<void*>(model.data.get()), | |||||
| return FileSaver::SaveToFile((output_file + ".om"), reinterpret_cast<void *>(model.data.get()), | |||||
| static_cast<uint32_t>(model.length)); | static_cast<uint32_t>(model.length)); | ||||
| } | } | ||||
| @@ -621,7 +605,7 @@ graphStatus aclgrphSaveModel(const char *output_file, const ModelBufferData &mod | |||||
| return GRAPH_PARAM_INVALID; | return GRAPH_PARAM_INVALID; | ||||
| } | } | ||||
| std::string str_output_file = output_file; | std::string str_output_file = output_file; | ||||
| return FileSaver::SaveToFile((str_output_file + ".om"), reinterpret_cast<void*>(model.data.get()), | |||||
| return FileSaver::SaveToFile((str_output_file + ".om"), reinterpret_cast<void *>(model.data.get()), | |||||
| static_cast<uint32_t>(model.length)); | static_cast<uint32_t>(model.length)); | ||||
| } | } | ||||
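Both overloads above only change the spacing of the reinterpret_cast; behaviour is unchanged. For orientation, a minimal usage sketch of the save path, assuming the std::string-keyed overloads declared in ge_ir_build.h (the include path depends on the install layout, and error handling is reduced to propagating the status):

```cpp
#include <map>
#include <string>

#include "ge_ir_build.h"  // adjust to your packaging, e.g. external/ge/ge_ir_build.h

// Build a graph with the std::string-keyed overload of aclgrphBuildModel and
// persist it; aclgrphSaveModel appends the ".om" suffix itself (see the
// SaveToFile call above), so only the base path is passed.
ge::graphStatus BuildAndSave(const ge::Graph &graph,
                             const std::map<std::string, std::string> &build_options) {
  ge::ModelBufferData model;
  ge::graphStatus ret = ge::aclgrphBuildModel(graph, build_options, model);
  if (ret != ge::GRAPH_SUCCESS) {
    return ret;
  }
  return ge::aclgrphSaveModel(std::string("my_model"), model);  // writes my_model.om
}
```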
| @@ -74,22 +74,22 @@ target_link_libraries(atc PRIVATE | |||||
| -ldl | -ldl | ||||
| ) | ) | ||||
| ############ atc.bin ############ | |||||
| add_executable(atc.bin ${SRC_LIST} ${PROTO_HDRS}) | |||||
| ############ atc_atc.bin ############ | |||||
| add_executable(atc_atc.bin ${SRC_LIST} ${PROTO_HDRS}) | |||||
| target_compile_options(atc.bin PRIVATE | |||||
| target_compile_options(atc_atc.bin PRIVATE | |||||
| -Werror | -Werror | ||||
| -O2 | -O2 | ||||
| -Wno-deprecated-declarations | -Wno-deprecated-declarations | ||||
| ) | ) | ||||
| target_compile_definitions(atc.bin PRIVATE | |||||
| target_compile_definitions(atc_atc.bin PRIVATE | |||||
| PROTOBUF_INLINE_NOT_IN_HEADERS=0 | PROTOBUF_INLINE_NOT_IN_HEADERS=0 | ||||
| COMPILE_OMG_PACKAGE | COMPILE_OMG_PACKAGE | ||||
| google=ascend_private | google=ascend_private | ||||
| ) | ) | ||||
| target_include_directories(atc.bin PRIVATE | |||||
| target_include_directories(atc_atc.bin PRIVATE | |||||
| ${CMAKE_CURRENT_LIST_DIR} | ${CMAKE_CURRENT_LIST_DIR} | ||||
| ${GE_CODE_DIR} | ${GE_CODE_DIR} | ||||
| ${GE_CODE_DIR}/ge | ${GE_CODE_DIR}/ge | ||||
| @@ -115,7 +115,7 @@ target_include_directories(atc.bin PRIVATE | |||||
| ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain | ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain | ||||
| ) | ) | ||||
| target_link_libraries(atc.bin PRIVATE | |||||
| target_link_libraries(atc_atc.bin PRIVATE | |||||
| $<BUILD_INTERFACE:intf_pub> | $<BUILD_INTERFACE:intf_pub> | ||||
| ascend_protobuf | ascend_protobuf | ||||
| ge_common | ge_common | ||||
| @@ -134,6 +134,11 @@ target_link_libraries(atc.bin PRIVATE | |||||
| -ldl | -ldl | ||||
| ) | ) | ||||
| set_target_properties(atc_atc.bin PROPERTIES | |||||
| OUTPUT_NAME atc.bin | |||||
| RUNTIME_OUTPUT_DIRECTORY atclib | |||||
| ) | |||||
| ############ fwk_atc.bin ############ | ############ fwk_atc.bin ############ | ||||
| add_executable(fwk_atc.bin ${SRC_LIST} ${PROTO_HDRS}) | add_executable(fwk_atc.bin ${SRC_LIST} ${PROTO_HDRS}) | ||||
| @@ -194,10 +199,23 @@ target_link_libraries(fwk_atc.bin PRIVATE | |||||
| -ldl | -ldl | ||||
| ) | ) | ||||
| set_target_properties(fwk_atc.bin PROPERTIES | |||||
| OUTPUT_NAME atc.bin | |||||
| RUNTIME_OUTPUT_DIRECTORY fwkacl | |||||
| ) | |||||
| ############ install ############ | ############ install ############ | ||||
| set(INSTALL_BASE_DIR "") | set(INSTALL_BASE_DIR "") | ||||
| set(INSTALL_LIBRARY_DIR lib) | set(INSTALL_LIBRARY_DIR lib) | ||||
| install(TARGETS atc atc.bin fwk_atc.bin OPTIONAL | |||||
| install(TARGETS atc OPTIONAL | |||||
| LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR} | LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR} | ||||
| ) | ) | ||||
| install(TARGETS atc_atc.bin OPTIONAL | |||||
| RUNTIME DESTINATION ${INSTALL_LIBRARY_DIR}/atclib | |||||
| ) | |||||
| install(TARGETS fwk_atc.bin OPTIONAL | |||||
| RUNTIME DESTINATION ${INSTALL_LIBRARY_DIR}/fwkacl | |||||
| ) | |||||
| @@ -4,7 +4,12 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd. All rights reserved. | # Copyright 2020 Huawei Technologies Co., Ltd. All rights reserved. | ||||
| #------------------------------------------------------------------- | #------------------------------------------------------------------- | ||||
| LOCAL_PATH=$(cd "$(dirname "$0")"; pwd) | |||||
| real_path=$(readlink "$0") | |||||
| if [ $? -eq 0 ]; then | |||||
| LOCAL_PATH=$(cd "$(dirname "$real_path")"; pwd) | |||||
| else | |||||
| LOCAL_PATH=$(cd "$(dirname "$0")"; pwd) | |||||
| fi | |||||
| PKG_PATH=$(cd ${LOCAL_PATH}/..; pwd) | PKG_PATH=$(cd ${LOCAL_PATH}/..; pwd) | ||||
| LIB_P="/lib64" | LIB_P="/lib64" | ||||
| PYTHON_P="/python/site-packages" | PYTHON_P="/python/site-packages" | ||||
| @@ -13,8 +18,4 @@ PYTHON_PATH="${PKG_PATH}${PYTHON_P}" | |||||
| export LD_LIBRARY_PATH="${LIB64_PATH}:${LD_LIBRARY_PATH}" | export LD_LIBRARY_PATH="${LIB64_PATH}:${LD_LIBRARY_PATH}" | ||||
| export PYTHONPATH="${PYTHON_PATH}:${PYTHONPATH}" | export PYTHONPATH="${PYTHON_PATH}:${PYTHONPATH}" | ||||
| if [ -f "${PKG_PATH}/bin/atc.bin" ];then | |||||
| ${PKG_PATH}/bin/atc.bin/atc.bin $@ | |||||
| else | |||||
| ${PKG_PATH}/bin/atc.bin/fwk_atc.bin $@ | |||||
| fi | |||||
| ${PKG_PATH}/bin/atc.bin "$@" | |||||
| @@ -56,7 +56,7 @@ include $(BUILD_HOST_EXECUTABLE) | |||||
| include $(CLEAR_VARS) | include $(CLEAR_VARS) | ||||
| LOCAL_MODULE := atc.bin | |||||
| LOCAL_MODULE := atclib/atc.bin | |||||
| LOCAL_CFLAGS += -Werror -Wno-deprecated-declarations | LOCAL_CFLAGS += -Werror -Wno-deprecated-declarations | ||||
| LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DCOMPILE_OMG_PACKAGE -O2 -Dgoogle=ascend_private | LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DCOMPILE_OMG_PACKAGE -O2 -Dgoogle=ascend_private | ||||
| @@ -109,7 +109,7 @@ include $(BUILD_HOST_EXECUTABLE) | |||||
| include $(CLEAR_VARS) | include $(CLEAR_VARS) | ||||
| LOCAL_MODULE := fwk_atc.bin | |||||
| LOCAL_MODULE := fwkacl/atc.bin | |||||
| LOCAL_CFLAGS += -Werror -Wno-deprecated-declarations | LOCAL_CFLAGS += -Werror -Wno-deprecated-declarations | ||||
| LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DCOMPILE_OMG_PACKAGE -O2 -Dgoogle=ascend_private | LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DCOMPILE_OMG_PACKAGE -O2 -Dgoogle=ascend_private | ||||
| @@ -27,6 +27,7 @@ | |||||
| #include "common/ge_inner_error_codes.h" | #include "common/ge_inner_error_codes.h" | ||||
| #include "framework/common/util.h" | #include "framework/common/util.h" | ||||
| #include "graph/utils/tensor_utils.h" | #include "graph/utils/tensor_utils.h" | ||||
| #include "graph/utils/type_utils.h" | |||||
| #include "graph/utils/op_desc_utils.h" | #include "graph/utils/op_desc_utils.h" | ||||
| #include "graph/operator_factory_impl.h" | #include "graph/operator_factory_impl.h" | ||||
| @@ -176,6 +177,7 @@ T GetValue(const map<string, T> &dict, string &key, T default_val) { | |||||
| } | } | ||||
| void from_json(const Json &j, SingleOpTensorDesc &desc) { | void from_json(const Json &j, SingleOpTensorDesc &desc) { | ||||
| bool is_tensor_valid = true; | |||||
| desc.dims = j.at(kKeyShape).get<vector<int64_t>>(); | desc.dims = j.at(kKeyShape).get<vector<int64_t>>(); | ||||
| auto it = j.find(kKeyShapeRange); | auto it = j.find(kKeyShapeRange); | ||||
| if (it != j.end()) { | if (it != j.end()) { | ||||
| @@ -189,9 +191,12 @@ void from_json(const Json &j, SingleOpTensorDesc &desc) { | |||||
| string type_str = j.at(kKeyType).get<string>(); | string type_str = j.at(kKeyType).get<string>(); | ||||
| desc.format = GetValue(kFormatDict, format_str, FORMAT_RESERVED); | desc.format = GetValue(kFormatDict, format_str, FORMAT_RESERVED); | ||||
| desc.type = GetValue(kDataTypeDict, type_str, DT_UNDEFINED); | desc.type = GetValue(kDataTypeDict, type_str, DT_UNDEFINED); | ||||
| is_tensor_valid = is_tensor_valid && ge::TypeUtils::IsFormatValid(format_str); | |||||
| is_tensor_valid = is_tensor_valid && ge::TypeUtils::IsDataTypeValid(type_str); | |||||
| it = j.find(kKeyOriginFormat); | it = j.find(kKeyOriginFormat); | ||||
| if (it != j.end()) { | if (it != j.end()) { | ||||
| string origin_format_str = j.at(kKeyOriginFormat).get<string>(); | string origin_format_str = j.at(kKeyOriginFormat).get<string>(); | ||||
| is_tensor_valid = is_tensor_valid && ge::TypeUtils::IsFormatValid(origin_format_str); | |||||
| desc.ori_format = GetValue(kFormatDict, origin_format_str, FORMAT_RESERVED); | desc.ori_format = GetValue(kFormatDict, origin_format_str, FORMAT_RESERVED); | ||||
| } | } | ||||
| auto tensor_name = j.find(kKeyName); | auto tensor_name = j.find(kKeyName); | ||||
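The added lines flag a tensor description as invalid as soon as the format or data-type string fails validation, instead of silently falling back to FORMAT_RESERVED / DT_UNDEFINED. A rough standalone equivalent of the accumulation pattern, with hypothetical validators standing in for ge::TypeUtils::IsFormatValid / IsDataTypeValid:

```cpp
#include <set>
#include <string>

// Hypothetical validators: a lookup in a whitelist of known serialized names.
static bool IsKnownFormat(const std::string &s) {
  static const std::set<std::string> kFormats = {"NCHW", "NHWC", "ND"};
  return kFormats.count(s) != 0;
}
static bool IsKnownDataType(const std::string &s) {
  static const std::set<std::string> kTypes = {"float16", "float32", "int32"};
  return kTypes.count(s) != 0;
}

// The flag starts true and is only ever and-ed with further checks, so one
// bad field is enough to mark the whole tensor description invalid.
static bool CheckTensorStrings(const std::string &format_str, const std::string &type_str) {
  bool is_tensor_valid = true;
  is_tensor_valid = is_tensor_valid && IsKnownFormat(format_str);
  is_tensor_valid = is_tensor_valid && IsKnownDataType(type_str);
  return is_tensor_valid;
}
```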
| @@ -202,6 +207,9 @@ void from_json(const Json &j, SingleOpTensorDesc &desc) { | |||||
| if (dynamic_input_name != j.end()) { | if (dynamic_input_name != j.end()) { | ||||
| desc.dynamic_input_name = dynamic_input_name->get<string>(); | desc.dynamic_input_name = dynamic_input_name->get<string>(); | ||||
| } | } | ||||
| if (!is_tensor_valid) { | |||||
| desc.SetValidFlag(is_tensor_valid); | |||||
| } | |||||
| } | } | ||||
| void from_json(const Json &j, SingleOpAttr &attr) { | void from_json(const Json &j, SingleOpAttr &attr) { | ||||
| @@ -305,6 +313,12 @@ bool SingleOpParser::Validate(const SingleOpDesc &op_desc) { | |||||
| int index = 0; | int index = 0; | ||||
| for (auto &tensor_desc : op_desc.input_desc) { | for (auto &tensor_desc : op_desc.input_desc) { | ||||
| if (!tensor_desc.GetValidFlag()) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "type", "index"}, | |||||
| {"intput", "datatype or format", std::to_string(index)}); | |||||
| GELOGE(PARAM_INVALID, "Input's dataType or format is invalid when the index is %d", index); | |||||
| return false; | |||||
| } | |||||
| if ((tensor_desc.type == DT_UNDEFINED && tensor_desc.format != FORMAT_RESERVED) || | if ((tensor_desc.type == DT_UNDEFINED && tensor_desc.format != FORMAT_RESERVED) || | ||||
| (tensor_desc.type != DT_UNDEFINED && tensor_desc.format == FORMAT_RESERVED)){ | (tensor_desc.type != DT_UNDEFINED && tensor_desc.format == FORMAT_RESERVED)){ | ||||
| ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "type", "index"}, | ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "type", "index"}, | ||||
| @@ -317,6 +331,12 @@ bool SingleOpParser::Validate(const SingleOpDesc &op_desc) { | |||||
| index = 0; | index = 0; | ||||
| for (auto &tensor_desc : op_desc.output_desc) { | for (auto &tensor_desc : op_desc.output_desc) { | ||||
| if (!tensor_desc.GetValidFlag()) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "type", "index"}, | |||||
| {"output", "datatype", std::to_string(index)}); | |||||
| GELOGE(PARAM_INVALID, "Output's dataType is invalid when the index is %d", index); | |||||
| return false; | |||||
| } | |||||
| if (tensor_desc.type == DT_UNDEFINED) { | if (tensor_desc.type == DT_UNDEFINED) { | ||||
| ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "type", "index"}, | ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "type", "index"}, | ||||
| {"output", "datatype", std::to_string(index)}); | {"output", "datatype", std::to_string(index)}); | ||||
| @@ -28,6 +28,10 @@ | |||||
| namespace ge { | namespace ge { | ||||
| struct SingleOpTensorDesc { | struct SingleOpTensorDesc { | ||||
| public: | |||||
| bool GetValidFlag() const { return is_valid_; } | |||||
| void SetValidFlag(bool is_valid) { is_valid_ = is_valid; } | |||||
| public: | |||||
| std::string name; | std::string name; | ||||
| std::vector<int64_t> dims; | std::vector<int64_t> dims; | ||||
| std::vector<int64_t> ori_dims; | std::vector<int64_t> ori_dims; | ||||
| @@ -36,6 +40,8 @@ struct SingleOpTensorDesc { | |||||
| ge::Format ori_format = ge::FORMAT_RESERVED; | ge::Format ori_format = ge::FORMAT_RESERVED; | ||||
| ge::DataType type = ge::DT_UNDEFINED; | ge::DataType type = ge::DT_UNDEFINED; | ||||
| std::string dynamic_input_name; | std::string dynamic_input_name; | ||||
| private: | |||||
| bool is_valid_ = true; | |||||
| }; | }; | ||||
| struct SingleOpAttr { | struct SingleOpAttr { | ||||
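The struct gains a private is_valid_ flag with trivial accessors, so from_json can record a bad format or data-type string and SingleOpParser::Validate can reject the descriptor before the later field-level checks run. A condensed sketch of the flag plus a consuming loop, with the real field list trimmed away:

```cpp
#include <cstdio>
#include <vector>

// Trimmed-down stand-in for SingleOpTensorDesc: only the validity flag and
// the accessors added in the diff are kept; the real struct also carries
// name, dims, format, data type, and so on.
struct TensorDescSketch {
 public:
  bool GetValidFlag() const { return is_valid_; }
  void SetValidFlag(bool is_valid) { is_valid_ = is_valid; }

 private:
  bool is_valid_ = true;  // stays true unless JSON parsing spots a bad field
};

// Mirrors the early return added to Validate(): the first invalid descriptor
// aborts validation and reports its index.
static bool ValidateAll(const std::vector<TensorDescSketch> &descs) {
  int index = 0;
  for (const auto &d : descs) {
    if (!d.GetValidFlag()) {
      std::printf("tensor %d has an invalid datatype or format\n", index);
      return false;
    }
    ++index;
  }
  return true;
}
```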
| @@ -175,8 +175,8 @@ Status OpsKernelManager::ParsePluginOptions(const map<string, string> &options, | |||||
| } else if (flag == 1) { | } else if (flag == 1) { | ||||
| enable_flag = true; | enable_flag = true; | ||||
| } else { | } else { | ||||
| GELOGE(GE_GRAPH_OPTIONS_INVALID, "option_key:%s, its value %s is invalid, it must be 0 or 1.", plugin_name.c_str(), | |||||
| iter->second.c_str()); | |||||
| GELOGE(GE_GRAPH_OPTIONS_INVALID, "option_key:%s, its value %s is invalid, it must be 0 or 1.", | |||||
| plugin_name.c_str(), iter->second.c_str()); | |||||
| return GE_GRAPH_OPTIONS_INVALID; | return GE_GRAPH_OPTIONS_INVALID; | ||||
| } | } | ||||
| } catch (std::invalid_argument &) { | } catch (std::invalid_argument &) { | ||||
| @@ -188,8 +188,8 @@ Status OpsKernelManager::ParsePluginOptions(const map<string, string> &options, | |||||
| iter->second.c_str()); | iter->second.c_str()); | ||||
| return GE_GRAPH_OPTIONS_INVALID; | return GE_GRAPH_OPTIONS_INVALID; | ||||
| } catch (...) { | } catch (...) { | ||||
| GELOGE(GE_GRAPH_OPTIONS_INVALID, "option_key:%s, its value %s is invalid, it must be 0 or 1.", plugin_name.c_str(), | |||||
| iter->second.c_str()); | |||||
| GELOGE(GE_GRAPH_OPTIONS_INVALID, "option_key:%s, its value %s is invalid, it must be 0 or 1.", | |||||
| plugin_name.c_str(), iter->second.c_str()); | |||||
| return GE_GRAPH_OPTIONS_INVALID; | return GE_GRAPH_OPTIONS_INVALID; | ||||
| } | } | ||||
| } else { | } else { | ||||
| @@ -644,7 +644,8 @@ Status ParseOutNodes(const string &out_nodes) { | |||||
| if (!domi::GetContext().user_out_nodes_top_vec.empty()) { | if (!domi::GetContext().user_out_nodes_top_vec.empty()) { | ||||
| ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, | ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, | ||||
| {"--out_nodes", out_nodes, "is not all index or top_name"}); | {"--out_nodes", out_nodes, "is not all index or top_name"}); | ||||
| GELOGE(PARAM_INVALID, "This out_nodes str must be all index or top_name, while the actual input is %s", out_nodes.c_str()); | |||||
| GELOGE(PARAM_INVALID, | |||||
| "This out_nodes str must be all index or top_name, while the actual input is %s", out_nodes.c_str()); | |||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| // stoi: The method may throw an exception: invalid_argument/out_of_range | // stoi: The method may throw an exception: invalid_argument/out_of_range | ||||
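For context, the reformatted error above fires when --out_nodes mixes index-style entries (name:index) with plain top names. A simplified illustration of that consistency rule, not the actual parser:

```cpp
#include <cstddef>
#include <sstream>
#include <string>
#include <vector>

// Every ';'-separated entry must either carry a ":<index>" suffix (index
// style) or none of them may (top-name style); mixing the two is rejected.
static bool OutNodesStyleConsistent(const std::string &out_nodes) {
  std::vector<std::string> entries;
  std::stringstream ss(out_nodes);
  std::string item;
  while (std::getline(ss, item, ';')) {
    entries.push_back(item);
  }
  std::size_t with_index = 0;
  for (const auto &e : entries) {
    if (e.find(':') != std::string::npos) {
      ++with_index;
    }
  }
  return with_index == 0 || with_index == entries.size();
}
```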
| @@ -111,7 +111,8 @@ Status SingleOp::ValidateArgs(const std::vector<DataBuffer> &inputs, const std:: | |||||
| auto num_outputs = outputs.size(); | auto num_outputs = outputs.size(); | ||||
| if (num_outputs != output_sizes_.size()) { | if (num_outputs != output_sizes_.size()) { | ||||
| GELOGE(ACL_ERROR_GE_PARAM_INVALID, "output num mismatch. model expect %zu, but given %zu", output_sizes_.size(), outputs.size()); | |||||
| GELOGE(ACL_ERROR_GE_PARAM_INVALID, "output num mismatch. model expect %zu, but given %zu", | |||||
| output_sizes_.size(), outputs.size()); | |||||
| return ACL_ERROR_GE_PARAM_INVALID; | return ACL_ERROR_GE_PARAM_INVALID; | ||||
| } | } | ||||
| @@ -268,7 +268,8 @@ Status SingleOpModel::BuildTaskList(StreamResource *stream_resource, SingleOp &s | |||||
| ParseArgTable(task, single_op); | ParseArgTable(task, single_op); | ||||
| single_op.tasks_.emplace_back(task); | single_op.tasks_.emplace_back(task); | ||||
| } else { | } else { | ||||
| GELOGE(ACL_ERROR_GE_OP_KERNEL_TYPE_INVALID, "Only TBE, AI_CPU, CUST_AI_CPU kernel are supported, but got %u", context.kernel_type()); | |||||
| GELOGE(ACL_ERROR_GE_OP_KERNEL_TYPE_INVALID, | |||||
| "Only TBE, AI_CPU, CUST_AI_CPU kernel are supported, but got %u", context.kernel_type()); | |||||
| return ACL_ERROR_GE_OP_KERNEL_TYPE_INVALID; | return ACL_ERROR_GE_OP_KERNEL_TYPE_INVALID; | ||||
| } | } | ||||
| } else if (task_type == RT_MODEL_TASK_KERNEL_EX) { | } else if (task_type == RT_MODEL_TASK_KERNEL_EX) { | ||||
| @@ -173,7 +173,8 @@ Status TbeTaskBuilder::RegisterKernel(TbeOpTask &task, const SingleOpModelParam | |||||
| auto tbe_kernel = GetTbeKernel(op_desc_); | auto tbe_kernel = GetTbeKernel(op_desc_); | ||||
| if (tbe_kernel == nullptr) { | if (tbe_kernel == nullptr) { | ||||
| GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "OP EXT ATTR NAME TBE_KERNEL not found. op = %s", op_desc_->GetName().c_str()); | |||||
| GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "OP EXT ATTR NAME TBE_KERNEL not found. op = %s", | |||||
| op_desc_->GetName().c_str()); | |||||
| return ACL_ERROR_GE_INTERNAL_ERROR; | return ACL_ERROR_GE_INTERNAL_ERROR; | ||||
| } | } | ||||
| @@ -21,7 +21,7 @@ | |||||
| namespace ge { | namespace ge { | ||||
| #define CC_FUSION_OP_MAX 32 | |||||
| const int CC_FUSION_OP_MAX = 32; | |||||
| typedef enum tagCcStatus { | typedef enum tagCcStatus { | ||||
| CC_STATUS_SUCCESS = 0, /**< succ */ | CC_STATUS_SUCCESS = 0, /**< succ */ | ||||
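Replacing the CC_FUSION_OP_MAX macro with a typed constant keeps the value scoped inside namespace ge and typed for the debugger; the two are interchangeable for existing uses. The idiom in isolation:

```cpp
namespace ge {
// Before: #define CC_FUSION_OP_MAX 32  (textual substitution, no type, no scope)
// After: a typed, namespace-scoped constant; const at namespace scope already
// has internal linkage, so it is safe to keep in a header.
const int CC_FUSION_OP_MAX = 32;
}  // namespace ge
```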
| @@ -1 +1 @@ | |||||
| Subproject commit d19c9c5c92f21a0335c18681dcceed44f3a54ddc | |||||
| Subproject commit bd2cfdfa85a3d9dcbd7dc825f5759c7f8b3ffa9a | |||||