| @@ -954,8 +954,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::CallMs | |||||
| static_cast<void *>(&reporter_data), sizeof(ReporterData)); | static_cast<void *>(&reporter_data), sizeof(ReporterData)); | ||||
| } | } | ||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::GetOpInputOutputInfo( | |||||
| const OpDescPtr &op, TaskDescInfo &task_desc_info) const { | |||||
| void ProfilingManager::GetOpInputInfo(const OpDescPtr &op, TaskDescInfo &task_desc_info) const { | |||||
| std::vector<Format> input_format; | std::vector<Format> input_format; | ||||
| std::vector<std::vector<int64_t>> input_shape; | std::vector<std::vector<int64_t>> input_shape; | ||||
| std::vector<DataType> input_data_type; | std::vector<DataType> input_data_type; | ||||
| @@ -968,6 +967,16 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::GetOpInp | |||||
| input_shape.emplace_back(input_tensor_desc->GetShape().GetDims()); | input_shape.emplace_back(input_tensor_desc->GetShape().GetDims()); | ||||
| input_data_type.emplace_back(input_tensor_desc->GetDataType()); | input_data_type.emplace_back(input_tensor_desc->GetDataType()); | ||||
| } | } | ||||
| std::vector<Format> format_default = { FORMAT_NULL }; | |||||
| std::vector<std::vector<int64_t>> shape_default = { {0} }; | |||||
| std::vector<DataType> data_type_default = { DT_UNDEFINED }; | |||||
| task_desc_info.input_format = input_format.empty() ? format_default : input_format; | |||||
| task_desc_info.input_shape = input_shape.empty() ? shape_default : input_shape; | |||||
| task_desc_info.input_data_type = input_data_type.empty() ? data_type_default : input_data_type; | |||||
| } | |||||
| void ProfilingManager::GetOpOutputInfo(const OpDescPtr &op, TaskDescInfo &task_desc_info) const { | |||||
| std::vector<Format> output_format; | std::vector<Format> output_format; | ||||
| std::vector<std::vector<int64_t>> output_shape; | std::vector<std::vector<int64_t>> output_shape; | ||||
| std::vector<DataType> output_data_type; | std::vector<DataType> output_data_type; | ||||
| @@ -984,14 +993,17 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::GetOpInp | |||||
| std::vector<Format> format_default = { FORMAT_NULL }; | std::vector<Format> format_default = { FORMAT_NULL }; | ||||
| std::vector<std::vector<int64_t>> shape_default = { {0} }; | std::vector<std::vector<int64_t>> shape_default = { {0} }; | ||||
| std::vector<DataType> data_type_default = { DT_UNDEFINED }; | std::vector<DataType> data_type_default = { DT_UNDEFINED }; | ||||
| task_desc_info.input_format = input_format.empty() ? format_default : input_format; | |||||
| task_desc_info.input_shape = input_shape.empty() ? shape_default : input_shape; | |||||
| task_desc_info.input_data_type = input_data_type.empty() ? data_type_default : input_data_type; | |||||
| task_desc_info.output_format = output_format.empty() ? format_default : output_format; | task_desc_info.output_format = output_format.empty() ? format_default : output_format; | ||||
| task_desc_info.output_shape = output_shape.empty() ? shape_default : output_shape; | task_desc_info.output_shape = output_shape.empty() ? shape_default : output_shape; | ||||
| task_desc_info.output_data_type = output_data_type.empty() ? data_type_default : output_data_type; | task_desc_info.output_data_type = output_data_type.empty() ? data_type_default : output_data_type; | ||||
| } | } | ||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::GetOpInputOutputInfo( | |||||
| const OpDescPtr &op, TaskDescInfo &task_desc_info) const { | |||||
| GetOpInputInfo(op, task_desc_info); | |||||
| GetOpOutputInfo(op, task_desc_info); | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::GetFpBpPoint( | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::GetFpBpPoint( | ||||
| std::string &fp_point, std::string &bp_point) { | std::string &fp_point, std::string &bp_point) { | ||||
| // Env or options mode, fp_point_/bp_point_ have initiliazed on profiling init | // Env or options mode, fp_point_/bp_point_ have initiliazed on profiling init | ||||
| @@ -111,6 +111,8 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ProfilingManager { | |||||
| uint64_t GetProfilingModule(); | uint64_t GetProfilingModule(); | ||||
| void UpdateDeviceIdModuleMap(string prof_type, uint64_t module, const vector<int32_t> &device_list); | void UpdateDeviceIdModuleMap(string prof_type, uint64_t module, const vector<int32_t> &device_list); | ||||
| void UpdateSubscribeDeviceModuleMap(std::string prof_type, uint32_t device_id, uint64_t module); | void UpdateSubscribeDeviceModuleMap(std::string prof_type, uint32_t device_id, uint64_t module); | ||||
| void GetOpInputInfo(const OpDescPtr &op, TaskDescInfo &task_desc_info) const; | |||||
| void GetOpOutputInfo(const OpDescPtr &op, TaskDescInfo &task_desc_info) const; | |||||
| bool is_load_profiling_; | bool is_load_profiling_; | ||||
| bool is_execute_profiling_; | bool is_execute_profiling_; | ||||
| @@ -154,10 +154,7 @@ bool HcclTask::SetSecondaryStream() { | |||||
| return false; | return false; | ||||
| } | } | ||||
| stream = std::make_shared<HcclTask::StreamGuard>(rt_model_handle_, new_stream); | stream = std::make_shared<HcclTask::StreamGuard>(rt_model_handle_, new_stream); | ||||
| if (stream == nullptr) { | |||||
| GELOGE(FAILED, "MakeShared failed."); | |||||
| return false; | |||||
| } | |||||
| GE_RT_FALSE_CHECK_NOTNULL(stream); | |||||
| secondary_stream_vec[index] = stream; | secondary_stream_vec[index] = stream; | ||||
| } | } | ||||
| secondary_stream_list_.push_back(stream); | secondary_stream_list_.push_back(stream); | ||||
| @@ -854,7 +854,7 @@ Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &in | |||||
| op_desc->GetName().c_str()); | op_desc->GetName().c_str()); | ||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| OmgContext &omg_context = (impl_ == nullptr) ? domi::GetContext() : impl_->omg_context_; | |||||
| OmgContext &omg_context = impl_->omg_context_; | |||||
| omg_context.is_dynamic_input = ContainsDynamicInpus(*op_desc); | omg_context.is_dynamic_input = ContainsDynamicInpus(*op_desc); | ||||
| if (op_desc->HasAttr(ATTR_NAME_UNREGST_OPPATH)) { | if (op_desc->HasAttr(ATTR_NAME_UNREGST_OPPATH)) { | ||||
| @@ -869,11 +869,7 @@ Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &in | |||||
| if (!HasShapeRange(inputs) && compile_flag == kFuzzBuildPattern) { | if (!HasShapeRange(inputs) && compile_flag == kFuzzBuildPattern) { | ||||
| fuzz_compile_flag = true; | fuzz_compile_flag = true; | ||||
| } | } | ||||
| if (!AttrUtils::SetBool(op_desc, ATTR_NAME_FUZZ_BUILD, fuzz_compile_flag)) { | |||||
| REPORT_CALL_ERROR("E19999", "set ATTR_NAME_FUZZ_BUILD failed for %s.", op_desc->GetName().c_str()); | |||||
| GELOGE(FAILED, "[Set][ATTR_NAME_FUZZ_BUILD] Failed to set attr for %s.", op_desc->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| (void)AttrUtils::SetBool(op_desc, ATTR_NAME_FUZZ_BUILD, fuzz_compile_flag); | |||||
| impl_->omg_context_.fuzz_compile_flag = fuzz_compile_flag; | impl_->omg_context_.fuzz_compile_flag = fuzz_compile_flag; | ||||
| // 1. Create ComputeGraph. | // 1. Create ComputeGraph. | ||||
| @@ -543,7 +543,6 @@ Status GraphMemoryAssigner::UpdateRefOpOffsetReverse(const NodePtr &node) { | |||||
| } | } | ||||
| Status GraphMemoryAssigner::ReAssignContinuousMemory(bool is_loop_graph) { | Status GraphMemoryAssigner::ReAssignContinuousMemory(bool is_loop_graph) { | ||||
| Status ret; | |||||
| // Stored nodes which need assign continuous input memory in `reverse topo order` | // Stored nodes which need assign continuous input memory in `reverse topo order` | ||||
| std::vector<NodePtr> nodes_stack; | std::vector<NodePtr> nodes_stack; | ||||
| std::map<NodePtr, uint32_t> node_2_continuous_type; | std::map<NodePtr, uint32_t> node_2_continuous_type; | ||||
| @@ -579,11 +578,8 @@ Status GraphMemoryAssigner::ReAssignContinuousMemory(bool is_loop_graph) { | |||||
| if (continuous_output) { | if (continuous_output) { | ||||
| GE_CHK_STATUS_RET(GetNodeMemoryType(node, memory_type, "output"), | GE_CHK_STATUS_RET(GetNodeMemoryType(node, memory_type, "output"), | ||||
| "[Get][MemType]fail for node:%s", node->GetName().c_str()); | "[Get][MemType]fail for node:%s", node->GetName().c_str()); | ||||
| ret = AssignContinuousOutputMemory(node, memory_type, continuous_type); | |||||
| if (ret != ge::SUCCESS) { | |||||
| GELOGE(ret, "[Assign][Memory:Continuous:Ouput]fail for node:%s", node->GetName().c_str()); | |||||
| return ret; | |||||
| } | |||||
| GE_CHK_STATUS_RET(AssignContinuousOutputMemory(node, memory_type, continuous_type), | |||||
| "[Assign][Memory:Continuous:Output]fail for node:%s", node->GetName().c_str()); | |||||
| } | } | ||||
| } | } | ||||
| // Assign continuous input memory in `reverse topo order` which stored before | // Assign continuous input memory in `reverse topo order` which stored before | ||||
| @@ -612,6 +608,61 @@ Status GraphMemoryAssigner::ReAssignContinuousMemory(bool is_loop_graph) { | |||||
| return ge::SUCCESS; | return ge::SUCCESS; | ||||
| } | } | ||||
| Status GraphMemoryAssigner::SetMemOffset(const ge::NodePtr &node, const InDataAnchorPtr &in_data_anchor, | |||||
| bool reverse_refresh, int64_t &mem_offset, int64_t &continuous_mem_start) { | |||||
| auto op_desc = node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| vector<int64_t> output_list_this = op_desc->GetOutputOffset(); | |||||
| if (output_list_this.empty()) { | |||||
| REPORT_INNER_ERROR("E19999", "No output offset in node :%s, not expected", | |||||
| node->GetName().c_str()); | |||||
| GELOGE(FAILED, "[Get][OutputOffset] empty is invalid, node:%s", node->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| auto peer_out_data_anchor = in_data_anchor->GetPeerOutAnchor(); | |||||
| auto peer_op_desc = peer_out_data_anchor->GetOwnerNode()->GetOpDesc(); | |||||
| vector<int64_t> output_list = peer_op_desc->GetOutputOffset(); | |||||
| if (peer_out_data_anchor->GetIdx() >= static_cast<int>(output_list.size())) { | |||||
| std::string error = "peer node:" + FmtToStr(peer_op_desc->GetName()) + | |||||
| " anchor_index:" + FmtToStr(peer_out_data_anchor->GetIdx()) + | |||||
| " is out of range:" + FmtToStr(output_list.size()); | |||||
| GE_ERRORLOG_AND_ERRORMSG(FAILED, error.c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| // when continuous input has been allocated first input is beginning offset | |||||
| bool is_continuous_input_allocated = false; | |||||
| (void) ge::AttrUtils::GetBool(op_desc, ATTR_NAME_CONTINUOUS_INPUT_ALLOC, is_continuous_input_allocated); | |||||
| bool is_allocated_first_input = is_continuous_input_allocated && (in_data_anchor->GetIdx() == 0); | |||||
| if (is_allocated_first_input) { | |||||
| std::map<int32_t, int32_t> out2ins; | |||||
| GE_CHK_STATUS_RET(TryGetNodeRefIndexes(node, out2ins), "[Get][RefIndexes]fail for node: %s", | |||||
| node->GetName().c_str()); | |||||
| // output is beginning offset, set offset for input; only support this case now | |||||
| if ((out2ins.size() == 1) && (out2ins.begin()->second == 0) && (reverse_refresh)) { | |||||
| auto peer_output_offset = output_list.at(peer_out_data_anchor->GetIdx()); | |||||
| output_list.at(peer_out_data_anchor->GetIdx()) = output_list_this.at(out2ins.begin()->first); | |||||
| peer_op_desc->SetOutputOffset(output_list); | |||||
| GELOGI("[Update][Offset]Node %s out %d ref in %d input node %s, use output offset %ld update %ld", | |||||
| node->GetName().c_str(), out2ins.begin()->first, out2ins.begin()->second, | |||||
| peer_op_desc->GetName().c_str(), output_list_this.at(out2ins.begin()->first), peer_output_offset); | |||||
| } else { | |||||
| GELOGD("Node %s out %d ref in %d input node %s with total ref numbers %zu.", node->GetName().c_str(), | |||||
| out2ins.begin()->first, out2ins.begin()->second, peer_op_desc->GetName().c_str(), out2ins.size()); | |||||
| } | |||||
| // first input is beginning offset | |||||
| mem_offset = output_list.at(peer_out_data_anchor->GetIdx()); | |||||
| continuous_mem_start = output_list.at(peer_out_data_anchor->GetIdx()); | |||||
| } else { | |||||
| // set offset for input | |||||
| output_list.at(peer_out_data_anchor->GetIdx()) = mem_offset; | |||||
| peer_op_desc->SetOutputOffset(output_list); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status GraphMemoryAssigner::AssignContinuousInputMemory(const ge::NodePtr &node, int64_t &continuous_mem_start, | Status GraphMemoryAssigner::AssignContinuousInputMemory(const ge::NodePtr &node, int64_t &continuous_mem_start, | ||||
| int64_t &continuous_mem_size, int64_t memory_type, uint32_t continuous_type, bool reverse_refresh) { | int64_t &continuous_mem_size, int64_t memory_type, uint32_t continuous_type, bool reverse_refresh) { | ||||
| GELOGI("[Assign][Memory:Input:Continuous]start for Current node %s", node->GetName().c_str()); | GELOGI("[Assign][Memory:Input:Continuous]start for Current node %s", node->GetName().c_str()); | ||||
| @@ -631,13 +682,6 @@ Status GraphMemoryAssigner::AssignContinuousInputMemory(const ge::NodePtr &node, | |||||
| bool is_continuous_input_allocated = false; | bool is_continuous_input_allocated = false; | ||||
| auto op_desc = node->GetOpDesc(); | auto op_desc = node->GetOpDesc(); | ||||
| GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
| vector<int64_t> output_list_this = op_desc->GetOutputOffset(); | |||||
| if (output_list_this.empty()) { | |||||
| REPORT_INNER_ERROR("E19999", "No output offset in node :%s, not expected", | |||||
| node->GetName().c_str()); | |||||
| GELOGE(FAILED, "[Get][OutputOffset] empty is invalid, node:%s", node->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| (void) ge::AttrUtils::GetBool(op_desc, ATTR_NAME_CONTINUOUS_INPUT_ALLOC, is_continuous_input_allocated); | (void) ge::AttrUtils::GetBool(op_desc, ATTR_NAME_CONTINUOUS_INPUT_ALLOC, is_continuous_input_allocated); | ||||
| for (auto &in_data_anchor : node->GetAllInDataAnchors()) { | for (auto &in_data_anchor : node->GetAllInDataAnchors()) { | ||||
| GE_IF_BOOL_EXEC(in_data_anchor == nullptr, continue); | GE_IF_BOOL_EXEC(in_data_anchor == nullptr, continue); | ||||
| @@ -669,45 +713,12 @@ Status GraphMemoryAssigner::AssignContinuousInputMemory(const ge::NodePtr &node, | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| } | } | ||||
| bool is_nopadding = ((continuous_type & kTypeInputNoPadding) != 0) || lx_fusion; | |||||
| vector<int64_t> output_list = peer_op_desc->GetOutputOffset(); | |||||
| if (peer_out_data_anchor->GetIdx() >= static_cast<int>(output_list.size())) { | |||||
| std::string error = "peer node:" + FmtToStr(peer_op_desc->GetName()) + | |||||
| " anchor_index:" + FmtToStr(peer_out_data_anchor->GetIdx()) + | |||||
| " is out of range:" + FmtToStr(output_list.size()); | |||||
| GE_ERRORLOG_AND_ERRORMSG(FAILED, error.c_str()); | |||||
| if (SetMemOffset(node, in_data_anchor, reverse_refresh, mem_offset, continuous_mem_start) != ge::SUCCESS) { | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| // when continuous input has been allocated first input is beginning offset | |||||
| bool is_allocated_first_input = is_continuous_input_allocated && (in_data_anchor->GetIdx() == 0); | |||||
| if (is_allocated_first_input) { | |||||
| std::map<int32_t, int32_t> out2ins; | |||||
| GE_CHK_STATUS_RET(TryGetNodeRefIndexes(node, out2ins), "[Get][RefIndexes]fail for node: %s", | |||||
| node->GetName().c_str()); | |||||
| // output is beginning offset, set offset for input; only support this case now | |||||
| if ((out2ins.size() == 1) && (out2ins.begin()->second == 0) && (reverse_refresh)) { | |||||
| auto peer_output_offset = output_list.at(peer_out_data_anchor->GetIdx()); | |||||
| output_list.at(peer_out_data_anchor->GetIdx()) = output_list_this.at(out2ins.begin()->first); | |||||
| peer_op_desc->SetOutputOffset(output_list); | |||||
| GELOGI("[Update][Offset]Node %s out %d ref in %d input node %s, use output offset %ld update %ld", | |||||
| node->GetName().c_str(), out2ins.begin()->first, out2ins.begin()->second, | |||||
| peer_op_desc->GetName().c_str(), output_list_this.at(out2ins.begin()->first), peer_output_offset); | |||||
| } else { | |||||
| GELOGD("Node %s out %d ref in %d input node %s with total ref numbers %zu.", node->GetName().c_str(), | |||||
| out2ins.begin()->first, out2ins.begin()->second, peer_op_desc->GetName().c_str(), out2ins.size()); | |||||
| } | |||||
| // first input is beginning offset | |||||
| mem_offset = output_list.at(peer_out_data_anchor->GetIdx()); | |||||
| continuous_mem_start = output_list.at(peer_out_data_anchor->GetIdx()); | |||||
| } else { | |||||
| // set offset for input | |||||
| output_list.at(peer_out_data_anchor->GetIdx()) = mem_offset; | |||||
| peer_op_desc->SetOutputOffset(output_list); | |||||
| } | |||||
| int64_t align_size = tensor_desc_size; | int64_t align_size = tensor_desc_size; | ||||
| bool is_nopadding = ((continuous_type & kTypeInputNoPadding) != 0) || lx_fusion; | |||||
| if (is_nopadding) { | if (is_nopadding) { | ||||
| mem_offset += nopadding_size; | mem_offset += nopadding_size; | ||||
| extra_memory_size += (tensor_desc_size - nopadding_size); | extra_memory_size += (tensor_desc_size - nopadding_size); | ||||
| @@ -719,7 +730,7 @@ Status GraphMemoryAssigner::AssignContinuousInputMemory(const ge::NodePtr &node, | |||||
| extra_memory_size = MEM_ALIGN_SIZE; | extra_memory_size = MEM_ALIGN_SIZE; | ||||
| real_size = tensor_desc_size; | real_size = tensor_desc_size; | ||||
| } | } | ||||
| vector<int64_t> output_list = peer_op_desc->GetOutputOffset(); | |||||
| GELOGI("[IMAS]Continuous input : Set %s name[%s] optype[%s] output[%d] offset to [%zu] stream_id[%ld] memtype[%ld] " | GELOGI("[IMAS]Continuous input : Set %s name[%s] optype[%s] output[%d] offset to [%zu] stream_id[%ld] memtype[%ld] " | ||||
| "size[%zu] realsize[%ld] nopadding size[%d]", node->GetOwnerComputeGraph()->GetName().c_str(), | "size[%zu] realsize[%ld] nopadding size[%d]", node->GetOwnerComputeGraph()->GetName().c_str(), | ||||
| peer_op_desc->GetName().c_str(), node->GetType().c_str(), peer_out_data_anchor->GetIdx(), | peer_op_desc->GetName().c_str(), node->GetType().c_str(), peer_out_data_anchor->GetIdx(), | ||||
| @@ -146,6 +146,9 @@ class GraphMemoryAssigner { | |||||
| ge::Status FilterAtomicNodesForMemoryAssign(map<string, map<NodePtr, vector<NodePtr>>> &normal_atomic_nodes_map, | ge::Status FilterAtomicNodesForMemoryAssign(map<string, map<NodePtr, vector<NodePtr>>> &normal_atomic_nodes_map, | ||||
| map<string, vector<NodePtr>> &connecting_output_atomic_nodes); | map<string, vector<NodePtr>> &connecting_output_atomic_nodes); | ||||
| Status SetMemOffset(const ge::NodePtr &node, const InDataAnchorPtr &in_data_anchor, bool reverse_refresh, | |||||
| int64_t &mem_offset, int64_t &continuous_mem_start); | |||||
| ge::Status AssignContinuousInputMemory(const ge::NodePtr &node, int64_t &continuous_mem_start, | ge::Status AssignContinuousInputMemory(const ge::NodePtr &node, int64_t &continuous_mem_start, | ||||
| int64_t &continuous_mem_size, int64_t memory_type, uint32_t continuous_type, | int64_t &continuous_mem_size, int64_t memory_type, uint32_t continuous_type, | ||||
| bool reverse_refresh = false); | bool reverse_refresh = false); | ||||
| @@ -1212,7 +1212,8 @@ Status StreamAllocator::SetActiveStreamsForLoop() { | |||||
| for (const auto &node : whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag())) { | for (const auto &node : whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag())) { | ||||
| GE_CHECK_NOTNULL(node->GetOpDesc()); | GE_CHECK_NOTNULL(node->GetOpDesc()); | ||||
| bool is_loop_active = false; | bool is_loop_active = false; | ||||
| if (AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_IS_LOOP_ACTIVE, is_loop_active) && is_loop_active) { | |||||
| (void)AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_IS_LOOP_ACTIVE, is_loop_active); | |||||
| if (is_loop_active) { | |||||
| vector<string> activated_label_list; | vector<string> activated_label_list; | ||||
| NodePtr pre_switch_node = FindSwitchNodeBeforeLoopActiveNode(node); | NodePtr pre_switch_node = FindSwitchNodeBeforeLoopActiveNode(node); | ||||
| @@ -1668,42 +1668,23 @@ Status ModelManager::LaunchKernelCheckAicpuOp(std::vector<std::string> &aicpu_op | |||||
| }; | }; | ||||
| GE_MAKE_GUARD(release, callback); | GE_MAKE_GUARD(release, callback); | ||||
| // malloc sysOpInfoList in SysOpCheckInfo | // malloc sysOpInfoList in SysOpCheckInfo | ||||
| status = rtMalloc(&d_req_op_list, op_nums * sizeof(SysOpInfo), RT_MEMORY_HBM); | |||||
| if (status != RT_ERROR_NONE) { | |||||
| REPORT_CALL_ERROR("E19999", "Call rtMalloc fail, size:%zu, ret = 0x%X", op_nums * sizeof(SysOpInfo), status); | |||||
| GELOGE(RT_FAILED, "[Call][RtMalloc] fail, size:%zu, ret = 0x%X", op_nums * sizeof(SysOpInfo), status); | |||||
| return RT_ERROR_TO_GE_STATUS(status); | |||||
| } | |||||
| GE_CHK_RT_RET(rtMalloc(&d_req_op_list, op_nums * sizeof(SysOpInfo), RT_MEMORY_HBM)); | |||||
| allocated_mem.push_back(d_req_op_list); | allocated_mem.push_back(d_req_op_list); | ||||
| // malloc sysOpInfoList in SysOpCheckResp | // malloc sysOpInfoList in SysOpCheckResp | ||||
| status = rtMalloc(&d_res_op_list, op_nums * sizeof(SysOpInfo), RT_MEMORY_HBM); | |||||
| if (status != RT_ERROR_NONE) { | |||||
| REPORT_CALL_ERROR("E19999", "Call rtMalloc fail, size:%zu, ret = 0x%X", op_nums * sizeof(SysOpInfo), status); | |||||
| GELOGE(RT_FAILED, "[Call][RtMalloc] fail, size:%zu, ret = 0x%X", op_nums * sizeof(SysOpInfo), status); | |||||
| return RT_ERROR_TO_GE_STATUS(status); | |||||
| } | |||||
| GE_CHK_RT_RET(rtMalloc(&d_res_op_list, op_nums * sizeof(SysOpInfo), RT_MEMORY_HBM)); | |||||
| allocated_mem.push_back(d_res_op_list); | allocated_mem.push_back(d_res_op_list); | ||||
| // malloc returnCodeList in SysOpCheckResp | // malloc returnCodeList in SysOpCheckResp | ||||
| status = rtMalloc(&d_ret_code_list, op_nums * sizeof(ReturnCode), RT_MEMORY_HBM); | |||||
| if (status != RT_ERROR_NONE) { | |||||
| REPORT_CALL_ERROR("E19999", "Call rtMalloc fail, size:%zu, ret = 0x%X", op_nums * sizeof(ReturnCode), status); | |||||
| GELOGE(RT_FAILED, "[Call][RtMalloc] fail, size:%zu, ret = 0x%X", op_nums * sizeof(ReturnCode), status); | |||||
| return RT_ERROR_TO_GE_STATUS(status); | |||||
| } | |||||
| GE_CHK_RT_RET(rtMalloc(&d_ret_code_list, op_nums * sizeof(ReturnCode), RT_MEMORY_HBM)); | |||||
| allocated_mem.push_back(d_ret_code_list); | allocated_mem.push_back(d_ret_code_list); | ||||
| for (const auto &op_type : aicpu_optype_list) { | for (const auto &op_type : aicpu_optype_list) { | ||||
| SysOpInfo op_info; | SysOpInfo op_info; | ||||
| // malloc op_type name in SysOpInfo | // malloc op_type name in SysOpInfo | ||||
| void *d_op_type_name = nullptr; | void *d_op_type_name = nullptr; | ||||
| status = rtMalloc(&d_op_type_name, op_type.length(), RT_MEMORY_HBM); | |||||
| if (status != RT_ERROR_NONE) { | |||||
| REPORT_CALL_ERROR("E19999", "Call rtMalloc fail, size:%lu, ret = 0x%X", op_type.length(), status); | |||||
| GELOGE(RT_FAILED, "[Call][RtMalloc] fail, size:%lu, ret = 0x%X", op_type.length(), status); | |||||
| return RT_ERROR_TO_GE_STATUS(status); | |||||
| } | |||||
| GE_CHK_RT_RET(rtMalloc(&d_op_type_name, op_type.length(), RT_MEMORY_HBM)); | |||||
| allocated_mem.push_back(d_op_type_name); | allocated_mem.push_back(d_op_type_name); | ||||
| GE_CHK_RT(rtMemcpy(d_op_type_name, op_type.length(), op_type.c_str(), op_type.length(), RT_MEMCPY_HOST_TO_DEVICE)); | GE_CHK_RT(rtMemcpy(d_op_type_name, op_type.length(), op_type.c_str(), op_type.length(), RT_MEMCPY_HOST_TO_DEVICE)); | ||||
| op_info.opType = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(d_op_type_name)); | op_info.opType = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(d_op_type_name)); | ||||
| @@ -1716,12 +1697,8 @@ Status ModelManager::LaunchKernelCheckAicpuOp(std::vector<std::string> &aicpu_op | |||||
| SysOpInfo op_info; | SysOpInfo op_info; | ||||
| // malloc op_type name in SysOpInfo | // malloc op_type name in SysOpInfo | ||||
| void *d_op_type_name = nullptr; | void *d_op_type_name = nullptr; | ||||
| status = rtMalloc(&d_op_type_name, op_type.size(), RT_MEMORY_HBM); | |||||
| if (status != RT_ERROR_NONE) { | |||||
| REPORT_CALL_ERROR("E19999", "Call rtMalloc fail, size:%lu, ret = 0x%X", op_type.length(), status); | |||||
| GELOGE(RT_FAILED, "[Call][RtMalloc] fail, size:%lu, ret = 0x%X", op_type.size(), status); | |||||
| return RT_ERROR_TO_GE_STATUS(status); | |||||
| } | |||||
| GE_CHK_RT_RET(rtMalloc(&d_op_type_name, op_type.length(), RT_MEMORY_HBM)); | |||||
| allocated_mem.push_back(d_op_type_name); | allocated_mem.push_back(d_op_type_name); | ||||
| GE_CHK_RT(rtMemcpy(d_op_type_name, op_type.size(), op_type.c_str(), op_type.size(), RT_MEMCPY_HOST_TO_DEVICE)); | GE_CHK_RT(rtMemcpy(d_op_type_name, op_type.size(), op_type.c_str(), op_type.size(), RT_MEMCPY_HOST_TO_DEVICE)); | ||||
| op_info.opType = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(d_op_type_name)); | op_info.opType = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(d_op_type_name)); | ||||
| @@ -1745,12 +1722,8 @@ Status ModelManager::LaunchKernelCheckAicpuOp(std::vector<std::string> &aicpu_op | |||||
| op_check_info_res.sysOpInfoList = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(d_res_op_list)); | op_check_info_res.sysOpInfoList = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(d_res_op_list)); | ||||
| uint32_t args_size = sizeof(SysOpCheckInfo) + sizeof(SysOpCheckResp); | uint32_t args_size = sizeof(SysOpCheckInfo) + sizeof(SysOpCheckResp); | ||||
| status = rtMalloc(&args, args_size, RT_MEMORY_HBM); | |||||
| if (status != RT_ERROR_NONE) { | |||||
| REPORT_CALL_ERROR("E19999", "Call rtMalloc fail, size:%u, ret = 0x%X", args_size, status); | |||||
| GELOGE(RT_FAILED, "[Call][RtMalloc] fail, size:%u, ret = 0x%X", args_size, status); | |||||
| return RT_ERROR_TO_GE_STATUS(status); | |||||
| } | |||||
| GE_CHK_RT_RET(rtMalloc(&args, args_size, RT_MEMORY_HBM)); | |||||
| allocated_mem.push_back(args); | allocated_mem.push_back(args); | ||||
| GE_CHK_RT(rtMemcpy(args, sizeof(SysOpCheckInfo), reinterpret_cast<void *>(&op_check_info_req), sizeof(SysOpCheckInfo), | GE_CHK_RT(rtMemcpy(args, sizeof(SysOpCheckInfo), reinterpret_cast<void *>(&op_check_info_req), sizeof(SysOpCheckInfo), | ||||
| RT_MEMCPY_HOST_TO_DEVICE)); | RT_MEMCPY_HOST_TO_DEVICE)); | ||||
| @@ -3533,9 +3533,8 @@ Status GraphManager::OptimizeSubgraph(const GraphNodePtr &graph_node, ComputeGra | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| GE_TIMESTAMP_EVENT_END(SetSubgraph, "OptimizeSubgraph::SetSubGraph"); | GE_TIMESTAMP_EVENT_END(SetSubgraph, "OptimizeSubgraph::SetSubGraph"); | ||||
| if ((options_.build_mode == BUILD_MODE_TUNING) && | |||||
| (options_.build_step == BUILD_STEP_BEFORE_UB_MATCH || options_.build_step == BUILD_STEP_AFTER_BUILDER || | |||||
| options_.build_step == BUILD_STEP_AFTER_BUILDER_SUB)) { | |||||
| std::set<string> build_steps = {BUILD_STEP_BEFORE_UB_MATCH, BUILD_STEP_AFTER_BUILDER, BUILD_STEP_AFTER_BUILDER_SUB}; | |||||
| if ((options_.build_mode == BUILD_MODE_TUNING) && (build_steps.count(options_.build_step) > 0)) { | |||||
| GE_TIMESTAMP_START(ConvertGraphToFile); | GE_TIMESTAMP_START(ConvertGraphToFile); | ||||
| std::string tuning_path; | std::string tuning_path; | ||||
| (void) GetContext().GetOption(TUNING_PATH, tuning_path); | (void) GetContext().GetOption(TUNING_PATH, tuning_path); | ||||
| @@ -743,12 +743,10 @@ Status GraphOptimize::HandleMemoryRWConflict(ComputeGraphPtr &compute_graph) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| // ignore data / netoutput of subgraph | // ignore data / netoutput of subgraph | ||||
| if (node->GetType() == DATA && AttrUtils::HasAttr(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX)) { | |||||
| continue; | |||||
| } | |||||
| if (node->GetType() == NETOUTPUT && AttrUtils::HasAttr(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX)) { | |||||
| if (IsSubgraphInputNode(node) || IsSubgraphOutputNode(node)) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| bool identity_reserved = false; | bool identity_reserved = false; | ||||
| AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_CANNOT_BE_DELETED, identity_reserved); | AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_CANNOT_BE_DELETED, identity_reserved); | ||||
| if (identity_reserved) { | if (identity_reserved) { | ||||
| @@ -366,11 +366,8 @@ graphStatus ge::GraphPartitioner::AddPlaceHolderEndInSrcDstGraph(const AnchorPtr | |||||
| // link input -> end | // link input -> end | ||||
| string end_name = kEndType + std::to_string(graph_info_.num_of_pld_end_); | string end_name = kEndType + std::to_string(graph_info_.num_of_pld_end_); | ||||
| auto end_op_desc = MakeShared<OpDesc>(end_graph->GetName() + "_" + end_name, END); | auto end_op_desc = MakeShared<OpDesc>(end_graph->GetName() + "_" + end_name, END); | ||||
| if (end_op_desc == nullptr) { | |||||
| REPORT_CALL_ERROR("E19999", "New Memory for OpDesc failed."); | |||||
| GELOGE(GRAPH_PARAM_INVALID, "[New][Memory] for OpDesc failed, pld_op_desc is nullptr."); | |||||
| return FAILED; | |||||
| } | |||||
| GE_CHECK_NOTNULL(end_op_desc); | |||||
| GE_IF_BOOL_EXEC(!AttrUtils::SetInt(end_op_desc, "peerIndex", graph_info_.num_of_pld_end_), | GE_IF_BOOL_EXEC(!AttrUtils::SetInt(end_op_desc, "peerIndex", graph_info_.num_of_pld_end_), | ||||
| GELOGW("SetInt peerIndex failed");) | GELOGW("SetInt peerIndex failed");) | ||||
| GE_IF_BOOL_EXEC(!AttrUtils::SetStr(end_op_desc, "parentOpType", dst_node->GetType()), | GE_IF_BOOL_EXEC(!AttrUtils::SetStr(end_op_desc, "parentOpType", dst_node->GetType()), | ||||
| @@ -429,11 +426,8 @@ graphStatus ge::GraphPartitioner::AddPlaceHolderEndInSrcDstGraph(const AnchorPtr | |||||
| int64_t node_id = src_node_opdesc->GetId(); | int64_t node_id = src_node_opdesc->GetId(); | ||||
| const string pld_name = kPlaceHolderType + std::to_string(graph_info_.num_of_pld_end_); | const string pld_name = kPlaceHolderType + std::to_string(graph_info_.num_of_pld_end_); | ||||
| auto pld_op_desc = MakeShared<OpDesc>(pld_graph->GetName() + "_" + pld_name, PLACEHOLDER); | auto pld_op_desc = MakeShared<OpDesc>(pld_graph->GetName() + "_" + pld_name, PLACEHOLDER); | ||||
| if (pld_op_desc == nullptr) { | |||||
| REPORT_CALL_ERROR("E19999", "New Memory for OpDesc failed."); | |||||
| GELOGE(GRAPH_PARAM_INVALID, "[New][Memory] for OpDesc failed."); | |||||
| return FAILED; | |||||
| } | |||||
| GE_CHECK_NOTNULL(pld_op_desc); | |||||
| GE_IF_BOOL_EXEC(!AttrUtils::SetInt(pld_op_desc, "peerIndex", graph_info_.num_of_pld_end_), | GE_IF_BOOL_EXEC(!AttrUtils::SetInt(pld_op_desc, "peerIndex", graph_info_.num_of_pld_end_), | ||||
| GELOGW("SetInt peerIndex failed");) | GELOGW("SetInt peerIndex failed");) | ||||
| GE_IF_BOOL_EXEC(!AttrUtils::SetStr(pld_op_desc, "_peerNodeName", new_end_node->GetName()), | GE_IF_BOOL_EXEC(!AttrUtils::SetStr(pld_op_desc, "_peerNodeName", new_end_node->GetName()), | ||||
| @@ -199,6 +199,24 @@ void ClearOption(NamesToPass names_to_pass) { | |||||
| name_to_pass.second->ClearOptions(); | name_to_pass.second->ClearOptions(); | ||||
| } | } | ||||
| } | } | ||||
| bool CheckNode(const NodePtr &node, const DuringPassNodeSets &during_pass_node_set) { | |||||
| if (node == nullptr) { | |||||
| GELOGW("node is null"); | |||||
| return false; | |||||
| } | |||||
| if (during_pass_node_set.nodes_deleted.count(node) > 0) { | |||||
| GELOGD("The node %s was deleted before, skip it.", node->GetName().c_str()); | |||||
| return false; | |||||
| } | |||||
| if (during_pass_node_set.nodes_suspend.count(node) > 0) { | |||||
| GELOGD("The node %s has been added to suspend-iteration nodes list, the iteration of it will be suspend.", | |||||
| node->GetName().c_str()); | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| } // namespace | } // namespace | ||||
| Status BaseNodePass::IsolateAndDeleteNode(NodePtr &node, const std::vector<int> &io_map) { | Status BaseNodePass::IsolateAndDeleteNode(NodePtr &node, const std::vector<int> &io_map) { | ||||
| @@ -277,17 +295,9 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) { | |||||
| nodes.pop_front(); | nodes.pop_front(); | ||||
| (void)during_pass_node_set.nodes_re_pass.erase(node); | (void)during_pass_node_set.nodes_re_pass.erase(node); | ||||
| GE_IF_BOOL_EXEC(node == nullptr, GELOGW("node is null"); continue); | |||||
| if (during_pass_node_set.nodes_deleted.count(node) > 0) { | |||||
| GELOGD("The node %s was deleted before, skip it.", node->GetName().c_str()); | |||||
| if (!CheckNode(node, during_pass_node_set)) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| if (during_pass_node_set.nodes_suspend.count(node) > 0) { | |||||
| GELOGD("The node %s has been added to suspend-iteration nodes list, the iteration of it will be suspend.", | |||||
| node->GetName().c_str()); | |||||
| continue; | |||||
| } | |||||
| AddNextIterNodes(node->GetOutNodes(), nodes, during_pass_node_set); | AddNextIterNodes(node->GetOutNodes(), nodes, during_pass_node_set); | ||||
| auto ret = RunPasses(node, names_to_passes, during_pass_node_set); | auto ret = RunPasses(node, names_to_passes, during_pass_node_set); | ||||
| @@ -70,9 +70,9 @@ Status FlowCtrlPass::Run(ComputeGraphPtr compute_graph) { | |||||
| } | } | ||||
| GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue); | GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue); | ||||
| bool need_cycle_flag = false; | bool need_cycle_flag = false; | ||||
| bool is_found = AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_STREAM_CYCLE_EVENT_FLAG, need_cycle_flag); | |||||
| (void)AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_STREAM_CYCLE_EVENT_FLAG, need_cycle_flag); | |||||
| // small cycle flag is need_stream_cycle_event == true | // small cycle flag is need_stream_cycle_event == true | ||||
| if (is_found && need_cycle_flag) { | |||||
| if (need_cycle_flag) { | |||||
| Status ret = AddSpecialNodeIteratorCtrl(compute_graph, node); | Status ret = AddSpecialNodeIteratorCtrl(compute_graph, node); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(ret, "[Add][SpecialNodeIteratorCtrl] failed, node:%s, graph:%s.", | GELOGE(ret, "[Add][SpecialNodeIteratorCtrl] failed, node:%s, graph:%s.", | ||||
| @@ -166,26 +166,7 @@ Status SubgraphConstMigrationPass::ClassifyGraphNodes(const ComputeGraphPtr &gra | |||||
| GELOGD("%s, index: %u, Data: %s", subgraph->GetName().c_str(), parent_index, node->GetName().c_str()); | GELOGD("%s, index: %u, Data: %s", subgraph->GetName().c_str(), parent_index, node->GetName().c_str()); | ||||
| } else if ((node->GetType() == CONSTANT) && (node->GetOutDataAnchor(kZeroIndex) != nullptr)) { | } else if ((node->GetType() == CONSTANT) && (node->GetOutDataAnchor(kZeroIndex) != nullptr)) { | ||||
| set<string> peer_name_list; | set<string> peer_name_list; | ||||
| const auto &out_anchor = node->GetOutDataAnchor(kZeroIndex); | |||||
| for (const auto &in_anchor : out_anchor->GetPeerInDataAnchors()) { | |||||
| const auto &peer_node = in_anchor->GetOwnerNode(); | |||||
| // Trim subgraph node name prefix. | |||||
| string node_full_name = peer_node->GetName(); | |||||
| size_t pos = node_full_name.find(kMbatchNodeNameMark); | |||||
| if (pos == string::npos) { | |||||
| GELOGI("Can not find: %s of multi-batch in node: %s", kMbatchNodeNameMark.c_str(), node_full_name.c_str()); | |||||
| continue; | |||||
| } | |||||
| string fixed_name = node_full_name.substr(0, pos); | |||||
| pos = node_full_name.find("_", pos + kMbatchNodeNameMark.length()); | |||||
| if (pos != string::npos) { | |||||
| fixed_name += node_full_name.substr(pos); | |||||
| } | |||||
| peer_name_list.insert(fixed_name + ":" + std::to_string(in_anchor->GetIdx())); | |||||
| } | |||||
| GetPeerNameList(node, peer_name_list); | |||||
| if (peer_name_list.empty()) { | if (peer_name_list.empty()) { | ||||
| GELOGI("%s, Const: %s, no data output", subgraph->GetName().c_str(), node->GetName().c_str()); | GELOGI("%s, Const: %s, no data output", subgraph->GetName().c_str(), node->GetName().c_str()); | ||||
| const auto in_all_nodes = node->GetInAllNodes(); | const auto in_all_nodes = node->GetInAllNodes(); | ||||
| @@ -216,6 +197,28 @@ Status SubgraphConstMigrationPass::ClassifyGraphNodes(const ComputeGraphPtr &gra | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| void SubgraphConstMigrationPass::GetPeerNameList(const NodePtr &node, set<string> &peer_name_list) { | |||||
| const auto &out_anchor = node->GetOutDataAnchor(kZeroIndex); | |||||
| for (const auto &in_anchor : out_anchor->GetPeerInDataAnchors()) { | |||||
| const auto &peer_node = in_anchor->GetOwnerNode(); | |||||
| // Trim subgraph node name prefix. | |||||
| string node_full_name = peer_node->GetName(); | |||||
| size_t pos = node_full_name.find(kMbatchNodeNameMark); | |||||
| if (pos == string::npos) { | |||||
| GELOGI("Can not find: %s of multi-batch in node: %s", kMbatchNodeNameMark.c_str(), node_full_name.c_str()); | |||||
| continue; | |||||
| } | |||||
| string fixed_name = node_full_name.substr(0, pos); | |||||
| pos = node_full_name.find("_", pos + kMbatchNodeNameMark.length()); | |||||
| if (pos != string::npos) { | |||||
| fixed_name += node_full_name.substr(pos); | |||||
| } | |||||
| peer_name_list.insert(fixed_name + ":" + std::to_string(in_anchor->GetIdx())); | |||||
| } | |||||
| } | |||||
| /// | /// | ||||
| /// @ingroup ge | /// @ingroup ge | ||||
| /// @brief Get parent_index for Const node migration. | /// @brief Get parent_index for Const node migration. | ||||
| @@ -133,6 +133,8 @@ class SubgraphConstMigrationPass : public GraphPass { | |||||
| /// | /// | ||||
| Status AttachParallelNode(const ComputeGraphPtr &graph, const NodePtr &func_node, | Status AttachParallelNode(const ComputeGraphPtr &graph, const NodePtr &func_node, | ||||
| const NodePtr &const_node, uint32_t parent_index); | const NodePtr &const_node, uint32_t parent_index); | ||||
| void GetPeerNameList(const NodePtr &node, set<string> &peer_name_list); | |||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| #endif // GE_COMMON_SUBGRAPH_CONST_MIGRATION_H_ | #endif // GE_COMMON_SUBGRAPH_CONST_MIGRATION_H_ | ||||
| @@ -64,16 +64,19 @@ std::string TransOpBreadthFusionPass::GetNodeId(const int anchor_index, const No | |||||
| GE_IF_BOOL_EXEC(node == nullptr || node->GetOpDesc() == nullptr, | GE_IF_BOOL_EXEC(node == nullptr || node->GetOpDesc() == nullptr, | ||||
| REPORT_INNER_ERROR("E19999", "Param node or its op_desc is nullptr, check invalid"); | REPORT_INNER_ERROR("E19999", "Param node or its op_desc is nullptr, check invalid"); | ||||
| GELOGE(FAILED, "[Check][Param] Param node or its op_desc is nullptr"); return ""); | GELOGE(FAILED, "[Check][Param] Param node or its op_desc is nullptr"); return ""); | ||||
| std::set<std::string> trans_shapes = { RESHAPE, EXPANDDIMS, SQUEEZE }; | |||||
| std::set<std::string> trans_shape_and_format = { TRANSPOSE, TRANSPOSED, EXPANDDIMS }; | |||||
| if (node->GetType() == CAST) { | if (node->GetType() == CAST) { | ||||
| trans_data_type = true; | trans_data_type = true; | ||||
| } else if (node->GetType() == TRANSPOSE || node->GetType() == TRANSPOSED || node->GetType() == EXPANDDIMS) { | |||||
| } else if (trans_shape_and_format.count(node->GetType()) > 0) { | |||||
| trans_format = true; | trans_format = true; | ||||
| trans_shape = true; | trans_shape = true; | ||||
| } else if (node->GetType() == TRANSDATA) { | } else if (node->GetType() == TRANSDATA) { | ||||
| trans_data_type = true; | trans_data_type = true; | ||||
| trans_format = true; | trans_format = true; | ||||
| trans_shape = true; | trans_shape = true; | ||||
| } else if (node->GetType() == RESHAPE || node->GetType() == EXPANDDIMS || node->GetType() == SQUEEZE) { | |||||
| } else if (trans_shapes.count(node->GetType()) > 0) { | |||||
| trans_shape = true; | trans_shape = true; | ||||
| } else if (node->GetType() == REFORMAT) { | } else if (node->GetType() == REFORMAT) { | ||||
| trans_format = true; | trans_format = true; | ||||
| @@ -1423,6 +1423,25 @@ Status GraphPrepare::AdjustDataOpOutput(const NodePtr &node) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status GraphPrepare::CheckInternalFormat(const NodePtr &input_node, const GeTensorDesc &desc, bool tune_flag) { | |||||
| auto format = desc.GetFormat(); | |||||
| auto origin_format = desc.GetOriginFormat(); | |||||
| bool need_check_internal_format = (!IsTansDataOpData(input_node)) && (!options_.is_single_op) && (!tune_flag); | |||||
| if (need_check_internal_format) { | |||||
| bool is_internal = TypeUtils::IsInternalFormat(format) || TypeUtils::IsInternalFormat(origin_format); | |||||
| if (is_internal) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E19025", {"situation", "reason"}, {"Input format[" + | |||||
| TypeUtils::FormatToSerialString(format) + "] or origin_format[" + | |||||
| TypeUtils::FormatToSerialString(origin_format) + "]", | |||||
| "it is not support"}); | |||||
| GELOGE(PARAM_INVALID, "[Check][Param] Input format %s or origin_format %s is not support.", | |||||
| TypeUtils::FormatToSerialString(format).c_str(), TypeUtils::FormatToSerialString(origin_format).c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status GraphPrepare::UpdateInput(const std::vector<GeTensor> &user_input, | Status GraphPrepare::UpdateInput(const std::vector<GeTensor> &user_input, | ||||
| const std::map<string, string> &graph_option) { | const std::map<string, string> &graph_option) { | ||||
| // Get shape range of input in dynamic_execute mode | // Get shape range of input in dynamic_execute mode | ||||
| @@ -1454,24 +1473,13 @@ Status GraphPrepare::UpdateInput(const std::vector<GeTensor> &user_input, | |||||
| continue; | continue; | ||||
| } | } | ||||
| GeTensorDesc desc(user_input[index].GetTensorDesc()); | GeTensorDesc desc(user_input[index].GetTensorDesc()); | ||||
| auto format = desc.GetFormat(); | |||||
| auto origin_format = desc.GetOriginFormat(); | |||||
| // data maybe internal format [FRACTAL_NZ] at singleop process such as GEMM. | // data maybe internal format [FRACTAL_NZ] at singleop process such as GEMM. | ||||
| auto tune_flag = (options_.build_mode == BUILD_MODE_TUNING) && (options_.build_step == BUILD_STEP_AFTER_BUILDER); | auto tune_flag = (options_.build_mode == BUILD_MODE_TUNING) && (options_.build_step == BUILD_STEP_AFTER_BUILDER); | ||||
| bool need_check_internal_format = (!IsTansDataOpData(input_node)) && (!options_.is_single_op) && (!tune_flag); | |||||
| if (need_check_internal_format) { | |||||
| bool is_internal = TypeUtils::IsInternalFormat(format) || TypeUtils::IsInternalFormat(origin_format); | |||||
| if (is_internal) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E19025", {"situation", "reason"}, | |||||
| {"Input format[" + TypeUtils::FormatToSerialString(format) + "] or origin_format[" + | |||||
| TypeUtils::FormatToSerialString(origin_format) + "]", "it is not support"}); | |||||
| GELOGE(PARAM_INVALID, "[Check][Param] Input format %s or origin_format %s is not support.", | |||||
| TypeUtils::FormatToSerialString(format).c_str(), | |||||
| TypeUtils::FormatToSerialString(origin_format).c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| ret = CheckInternalFormat(input_node, desc, tune_flag); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "[Check][InternalFormat] on %s failed", op->GetName().c_str()); | |||||
| return ret; | |||||
| } | } | ||||
| auto data_type = desc.GetDataType(); | auto data_type = desc.GetDataType(); | ||||
| uint32_t length = 1; | uint32_t length = 1; | ||||
| bool type_ret = TypeUtils::GetDataTypeLength(data_type, length); | bool type_ret = TypeUtils::GetDataTypeLength(data_type, length); | ||||
| @@ -63,6 +63,7 @@ class GraphPrepare { | |||||
| Status CheckRefOp(); | Status CheckRefOp(); | ||||
| Status SetRtContext(rtContext_t rt_context, rtCtxMode_t mode); | Status SetRtContext(rtContext_t rt_context, rtCtxMode_t mode); | ||||
| Status AdjustDataOpOutput(const NodePtr &node); | Status AdjustDataOpOutput(const NodePtr &node); | ||||
| Status CheckInternalFormat(const NodePtr &input_node, const GeTensorDesc &desc, bool tune_flag); | |||||
| Status UpdateInput(const std::vector<GeTensor> &user_input, const std::map<string, string> &graph_option); | Status UpdateInput(const std::vector<GeTensor> &user_input, const std::map<string, string> &graph_option); | ||||
| Status CheckAndUpdateInput(const std::vector<GeTensor> &user_input, const std::map<string, string> &graph_option); | Status CheckAndUpdateInput(const std::vector<GeTensor> &user_input, const std::map<string, string> &graph_option); | ||||
| Status CheckConstOp(); | Status CheckConstOp(); | ||||
| @@ -71,15 +71,13 @@ Status SliceKernel::Compute(const OpDescPtr attr, const std::vector<ConstGeTenso | |||||
| GELOGW("The number of input for slice must be %zu.", kSliceInputSize); | GELOGW("The number of input for slice must be %zu.", kSliceInputSize); | ||||
| return NOT_CHANGED; | return NOT_CHANGED; | ||||
| } | } | ||||
| ConstGeTensorPtr x_ = input[kSliceInputIndexX]; | ConstGeTensorPtr x_ = input[kSliceInputIndexX]; | ||||
| ConstGeTensorPtr begin = input[kSliceInputIndexBegin]; | ConstGeTensorPtr begin = input[kSliceInputIndexBegin]; | ||||
| ConstGeTensorPtr size = input[kSliceInputIndexSize]; | ConstGeTensorPtr size = input[kSliceInputIndexSize]; | ||||
| if (x_ == nullptr || begin == nullptr || size == nullptr) { | |||||
| GELOGW("input tensor is nullptr."); | |||||
| return NOT_CHANGED; | |||||
| Status ret = CheckInput(x_, begin, size); | |||||
| if (ret != SUCCESS) { | |||||
| return ret; | |||||
| } | } | ||||
| // data type in input_x | // data type in input_x | ||||
| auto data_type = x_->GetTensorDesc().GetDataType(); | auto data_type = x_->GetTensorDesc().GetDataType(); | ||||
| // check supported | // check supported | ||||
| @@ -92,11 +90,7 @@ Status SliceKernel::Compute(const OpDescPtr attr, const std::vector<ConstGeTenso | |||||
| if (!is_success) { | if (!is_success) { | ||||
| return NOT_CHANGED; | return NOT_CHANGED; | ||||
| } | } | ||||
| // check data type of begin and size | |||||
| if (begin->GetTensorDesc().GetDataType() != DT_INT32 || size->GetTensorDesc().GetDataType() != DT_INT32) { | |||||
| GELOGW("Data type of begin and size for slice are not DT_INT32."); | |||||
| return NOT_CHANGED; | |||||
| } | |||||
| void *data = reinterpret_cast<void *>(const_cast<uint8_t *>(x_->GetData().data())); | void *data = reinterpret_cast<void *>(const_cast<uint8_t *>(x_->GetData().data())); | ||||
| int32_t *begin_data = const_cast<int32_t *>(reinterpret_cast<const int32_t *>(begin->GetData().GetData())); | int32_t *begin_data = const_cast<int32_t *>(reinterpret_cast<const int32_t *>(begin->GetData().GetData())); | ||||
| @@ -145,7 +139,7 @@ Status SliceKernel::Compute(const OpDescPtr attr, const std::vector<ConstGeTenso | |||||
| return NOT_CHANGED; | return NOT_CHANGED; | ||||
| } | } | ||||
| Status ret = CheckOutputDims(output_dims, attr); | |||||
| ret = CheckOutputDims(output_dims, attr); | |||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| @@ -161,6 +155,20 @@ Status SliceKernel::Compute(const OpDescPtr attr, const std::vector<ConstGeTenso | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status SliceKernel::CheckInput(const ConstGeTensorPtr &x_, const ConstGeTensorPtr &begin, | |||||
| const ConstGeTensorPtr &size) { | |||||
| if (x_ == nullptr || begin == nullptr || size == nullptr) { | |||||
| GELOGW("input tensor is nullptr."); | |||||
| return NOT_CHANGED; | |||||
| } | |||||
| // check data type of begin and size | |||||
| if (begin->GetTensorDesc().GetDataType() != DT_INT32 || size->GetTensorDesc().GetDataType() != DT_INT32) { | |||||
| GELOGW("Data type of begin and size for slice are not DT_INT32."); | |||||
| return NOT_CHANGED; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status SliceKernel::CheckOutputDims(const std::vector<int64_t> &output_dims, const OpDescPtr attr) { | Status SliceKernel::CheckOutputDims(const std::vector<int64_t> &output_dims, const OpDescPtr attr) { | ||||
| // check dim not all less than 0 | // check dim not all less than 0 | ||||
| for (auto dim : output_dims) { | for (auto dim : output_dims) { | ||||
| @@ -28,6 +28,7 @@ class SliceKernel : public Kernel { | |||||
| vector<GeTensorPtr> &v_output) override; | vector<GeTensorPtr> &v_output) override; | ||||
| Status CheckOutputDims(const std::vector<int64_t> &output_dims, const OpDescPtr attr); | Status CheckOutputDims(const std::vector<int64_t> &output_dims, const OpDescPtr attr); | ||||
| Status CheckInput(const ConstGeTensorPtr &x_, const ConstGeTensorPtr &begin, const ConstGeTensorPtr &size); | |||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -20,7 +20,6 @@ | |||||
| #include "graph/attr_value.h" | #include "graph/attr_value.h" | ||||
| #include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
| #include "graph/manager/util/hcom_util.h" | #include "graph/manager/util/hcom_util.h" | ||||
| #include "graph/runtime_inference_context.h" | |||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| #include "graph/types.h" | #include "graph/types.h" | ||||
| #include "hccl/hcom.h" | #include "hccl/hcom.h" | ||||
| @@ -95,8 +94,8 @@ Status HcclNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> do | |||||
| } | } | ||||
| op_info.dataType = iter->second; | op_info.dataType = iter->second; | ||||
| HcclReduceOp op_type = HCCL_REDUCE_SUM; | HcclReduceOp op_type = HCCL_REDUCE_SUM; | ||||
| if (op_desc->GetType() == HCOMALLREDUCE || op_desc->GetType() == HCOMREDUCESCATTER || | |||||
| op_desc->GetType() == HVDCALLBACKALLREDUCE || op_desc->GetType() == HCOMREDUCE) { | |||||
| std::set<std::string> hccl_types = { HCOMALLREDUCE, HCOMREDUCESCATTER, HVDCALLBACKALLREDUCE, HCOMREDUCE }; | |||||
| if (hccl_types.count(op_desc->GetType()) > 0) { | |||||
| GE_CHK_STATUS_RET(HcomOmeUtil::GetHcclOperationType(op_desc, op_type), | GE_CHK_STATUS_RET(HcomOmeUtil::GetHcclOperationType(op_desc, op_type), | ||||
| "[Get][HcclOperationType] failed for %s type:%s", op_desc->GetName().c_str(), | "[Get][HcclOperationType] failed for %s type:%s", op_desc->GetName().c_str(), | ||||
| op_desc->GetType().c_str()); | op_desc->GetType().c_str()); | ||||
| @@ -177,69 +176,15 @@ Status RdmaNodeTask::Init(TaskContext &context) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status RdmaNodeTask::ExtractTensor(TaskContext &context, vector<HcomRemoteAccessAddrInfo> &addr_infos) { | |||||
| RuntimeInferenceContext *ctx = nullptr; | |||||
| GE_CHK_STATUS_RET( | |||||
| RuntimeInferenceContext::GetContext(std::to_string(context.GetExecutionContext()->context_id), &ctx)); | |||||
| ge::Tensor remote_tensor; | |||||
| GE_CHK_STATUS_RET(ctx->GetTensor(remote_index_.first, remote_index_.second, remote_tensor)); | |||||
| auto data = reinterpret_cast<uint64_t *>(remote_tensor.GetData()); | |||||
| if (data == nullptr) { | |||||
| if (kRdmaScatterTypes.count(context.GetNodeItem().NodeType()) > 0) { | |||||
| GELOGD("data is null, no need to do rdma read/write, node=%s", context.GetNodeName()); | |||||
| return SUCCESS; | |||||
| } else { | |||||
| REPORT_INNER_ERROR("E19999", "Tensor data is nullptr. and kRdmaScatterTypes not contain %s", | |||||
| context.GetNodeItem().NodeType().c_str()); | |||||
| GELOGE(FAILED, "[Find][NodeType]Tensor data is nullptr. and kRdmaScatterTypes not contain %s", | |||||
| context.GetNodeItem().NodeType().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| auto dims = remote_tensor.GetTensorDesc().GetShape().GetDims(); | |||||
| if (dims.size() != kVarTableDims && dims.back() != kVarTableRowCnt) { | |||||
| REPORT_INNER_ERROR("E19999", "Variable table shape check failed, number of shape dims:%zu not equal expect:%zu" | |||||
| "and shape dims back:%zu not equal expect:%zu, node:%s(%s)", | |||||
| dims.size(), kVarTableDims, dims.back(), kVarTableRowCnt, | |||||
| context.GetNodeName(), context.GetNodeItem().NodeType().c_str()); | |||||
| GELOGE(PARAM_INVALID, "[Check][Param]Variable table shape check failed," | |||||
| "number of shape dims:%zu not equal expect:%zu and shape dims back:%zu not equal expect:%zu, node:%s(%s)", | |||||
| dims.size(), kVarTableDims, dims.back(), kVarTableRowCnt, | |||||
| context.GetNodeName(), context.GetNodeItem().NodeType().c_str()); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| if (context.GetNodeItem().NodeType() == HCOMREMOTEREAD) { | |||||
| size_t remote_size = 0; | |||||
| for (auto idx = 0; idx < dims.front(); ++idx) { | |||||
| FMK_INT64_MULCHECK(idx, kVarTableRowCnt); | |||||
| auto line_idx = idx * kVarTableRowCnt; | |||||
| remote_size += data[line_idx + kVarTableIdxLen]; | |||||
| } | |||||
| auto allocator = NpuMemoryAllocator::GetAllocator(); | |||||
| GE_CHECK_NOTNULL(allocator); | |||||
| AllocationAttr attr; | |||||
| attr.SetMemType(RDMA_HBM); | |||||
| for (auto i = 0; i < context.NumOutputs(); ++i) { | |||||
| GELOGD("Allocate rdma memory for node %s, size: %zu", context.GetNodeName(), remote_size); | |||||
| auto tensor_buffer = TensorBuffer::Create(allocator, remote_size, &attr); | |||||
| GE_CHK_STATUS_RET(context.SetOutput(i, TensorValue(std::shared_ptr<TensorBuffer>(tensor_buffer.release())))); | |||||
| } | |||||
| } else if (context.GetNodeItem().NodeType() == HCOMREMOTEREFREAD) { | |||||
| AllocationAttr attr; | |||||
| attr.SetMemType(RDMA_HBM); | |||||
| GE_CHK_STATUS_RET(context.AllocateOutputs(&attr)) | |||||
| } | |||||
| TensorValue *tv; | |||||
| Status RdmaNodeTask::SetAddrInfo(TaskContext &context, RuntimeInferenceContext *ctx, uint64_t *data, int64_t row_num, | |||||
| vector<HcomRemoteAccessAddrInfo> &addr_infos) { | |||||
| TensorValue *tv = nullptr; | |||||
| if (kRdmaReadTypes.count(context.GetNodeItem().NodeType()) > 0) { | if (kRdmaReadTypes.count(context.GetNodeItem().NodeType()) > 0) { | ||||
| tv = context.MutableOutput(local_index_); | tv = context.MutableOutput(local_index_); | ||||
| } else { | } else { | ||||
| tv = context.MutableInput(local_index_); | tv = context.MutableInput(local_index_); | ||||
| } | } | ||||
| GE_CHECK_NOTNULL(tv); | GE_CHECK_NOTNULL(tv); | ||||
| auto row_num = dims.front(); | |||||
| addr_infos.resize(row_num); | addr_infos.resize(row_num); | ||||
| if (skip_flag_) { | if (skip_flag_) { | ||||
| int32_t offset_idx = context.GetNodeItem().op_desc->GetInputIndexByName("local_offset"); | int32_t offset_idx = context.GetNodeItem().op_desc->GetInputIndexByName("local_offset"); | ||||
| @@ -294,6 +239,65 @@ Status RdmaNodeTask::ExtractTensor(TaskContext &context, vector<HcomRemoteAccess | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status RdmaNodeTask::ExtractTensor(TaskContext &context, vector<HcomRemoteAccessAddrInfo> &addr_infos) { | |||||
| RuntimeInferenceContext *ctx = nullptr; | |||||
| GE_CHK_STATUS_RET( | |||||
| RuntimeInferenceContext::GetContext(std::to_string(context.GetExecutionContext()->context_id), &ctx)); | |||||
| ge::Tensor remote_tensor; | |||||
| GE_CHK_STATUS_RET(ctx->GetTensor(remote_index_.first, remote_index_.second, remote_tensor)); | |||||
| auto data = reinterpret_cast<uint64_t *>(remote_tensor.GetData()); | |||||
| if (data == nullptr) { | |||||
| if (kRdmaScatterTypes.count(context.GetNodeItem().NodeType()) > 0) { | |||||
| GELOGD("data is null, no need to do rdma read/write, node=%s", context.GetNodeName()); | |||||
| return SUCCESS; | |||||
| } else { | |||||
| REPORT_INNER_ERROR("E19999", "Tensor data is nullptr. and kRdmaScatterTypes not contain %s", | |||||
| context.GetNodeItem().NodeType().c_str()); | |||||
| GELOGE(FAILED, "[Find][NodeType]Tensor data is nullptr. and kRdmaScatterTypes not contain %s", | |||||
| context.GetNodeItem().NodeType().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| auto dims = remote_tensor.GetTensorDesc().GetShape().GetDims(); | |||||
| if (dims.size() != kVarTableDims && dims.back() != kVarTableRowCnt) { | |||||
| REPORT_INNER_ERROR("E19999", "Variable table shape check failed, number of shape dims:%zu not equal expect:%zu" | |||||
| "and shape dims back:%zu not equal expect:%zu, node:%s(%s)", | |||||
| dims.size(), kVarTableDims, dims.back(), kVarTableRowCnt, | |||||
| context.GetNodeName(), context.GetNodeItem().NodeType().c_str()); | |||||
| GELOGE(PARAM_INVALID, "[Check][Param]Variable table shape check failed," | |||||
| "number of shape dims:%zu not equal expect:%zu and shape dims back:%zu not equal expect:%zu, node:%s(%s)", | |||||
| dims.size(), kVarTableDims, dims.back(), kVarTableRowCnt, | |||||
| context.GetNodeName(), context.GetNodeItem().NodeType().c_str()); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| if (context.GetNodeItem().NodeType() == HCOMREMOTEREAD) { | |||||
| size_t remote_size = 0; | |||||
| for (auto idx = 0; idx < dims.front(); ++idx) { | |||||
| FMK_INT64_MULCHECK(idx, kVarTableRowCnt); | |||||
| auto line_idx = idx * kVarTableRowCnt; | |||||
| remote_size += data[line_idx + kVarTableIdxLen]; | |||||
| } | |||||
| auto allocator = NpuMemoryAllocator::GetAllocator(); | |||||
| GE_CHECK_NOTNULL(allocator); | |||||
| AllocationAttr attr; | |||||
| attr.SetMemType(RDMA_HBM); | |||||
| for (auto i = 0; i < context.NumOutputs(); ++i) { | |||||
| GELOGD("Allocate rdma memory for node %s, size: %zu", context.GetNodeName(), remote_size); | |||||
| auto tensor_buffer = TensorBuffer::Create(allocator, remote_size, &attr); | |||||
| GE_CHK_STATUS_RET(context.SetOutput(i, TensorValue(std::shared_ptr<TensorBuffer>(tensor_buffer.release())))); | |||||
| } | |||||
| } else if (context.GetNodeItem().NodeType() == HCOMREMOTEREFREAD) { | |||||
| AllocationAttr attr; | |||||
| attr.SetMemType(RDMA_HBM); | |||||
| GE_CHK_STATUS_RET(context.AllocateOutputs(&attr)) | |||||
| } | |||||
| auto row_num = dims.front(); | |||||
| return SetAddrInfo(context, ctx, data, row_num, addr_infos); | |||||
| } | |||||
| Status RdmaNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> done_callback) { | Status RdmaNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> done_callback) { | ||||
| GELOGI("[%s] RdmaNodeTask::ExecuteAsync in.", context.GetNodeName()); | GELOGI("[%s] RdmaNodeTask::ExecuteAsync in.", context.GetNodeName()); | ||||
| auto HcomExecEnqueueRemoteAccess = | auto HcomExecEnqueueRemoteAccess = | ||||
| @@ -18,6 +18,7 @@ | |||||
| #define HYBRID_HCCL_NODE_EXECUTOR_H_ | #define HYBRID_HCCL_NODE_EXECUTOR_H_ | ||||
| #include "common/opskernel/ge_task_info.h" | #include "common/opskernel/ge_task_info.h" | ||||
| #include "graph/op_desc.h" | #include "graph/op_desc.h" | ||||
| #include "graph/runtime_inference_context.h" | |||||
| #include "hybrid/model/hybrid_model.h" | #include "hybrid/model/hybrid_model.h" | ||||
| #include "hybrid/node_executor/node_executor.h" | #include "hybrid/node_executor/node_executor.h" | ||||
| @@ -53,6 +54,8 @@ class RdmaNodeTask : public NodeTask { | |||||
| Status Init(TaskContext &context) override; | Status Init(TaskContext &context) override; | ||||
| private: | private: | ||||
| Status SetAddrInfo(TaskContext &context, RuntimeInferenceContext *ctx, uint64_t *data, int64_t row_num, | |||||
| vector<HcomRemoteAccessAddrInfo> &addr_infos); | |||||
| Status ExtractTensor(TaskContext &context, vector<HcomRemoteAccessAddrInfo> &addr_infos); | Status ExtractTensor(TaskContext &context, vector<HcomRemoteAccessAddrInfo> &addr_infos); | ||||
| std::pair<int64_t, int64_t> remote_index_; | std::pair<int64_t, int64_t> remote_index_; | ||||
| std::pair<int64_t, int64_t> offset_index_; | std::pair<int64_t, int64_t> offset_index_; | ||||
| @@ -272,8 +272,7 @@ class Impl { | |||||
| graphStatus Init(const Graph &graph, const std::map<std::string, std::string> &options); | graphStatus Init(const Graph &graph, const std::map<std::string, std::string> &options); | ||||
| graphStatus BuildModel(const Graph &graph, const std::map<std::string, std::string> &options, | graphStatus BuildModel(const Graph &graph, const std::map<std::string, std::string> &options, | ||||
| ModelBufferData &ge_models); | ModelBufferData &ge_models); | ||||
| graphStatus InitDomiOmgContext(const string &input_shape, const string &input_format, const string &net_format, | |||||
| bool is_dynamic_input); | |||||
| graphStatus InitDomiOmgContext(const string &input_shape, const string &input_format, bool is_dynamic_input); | |||||
| graphStatus GetInputShapeRange(const string &input_shape_range, | graphStatus GetInputShapeRange(const string &input_shape_range, | ||||
| std::map<string, std::vector<std::pair<int64_t, int64_t>>> &name_shape_range_map, | std::map<string, std::vector<std::pair<int64_t, int64_t>>> &name_shape_range_map, | ||||
| std::vector<std::vector<std::pair<int64_t, int64_t>>> &index_shape_range_map); | std::vector<std::vector<std::pair<int64_t, int64_t>>> &index_shape_range_map); | ||||
| @@ -283,6 +282,7 @@ class Impl { | |||||
| void SetRtSocVersion(); | void SetRtSocVersion(); | ||||
| void UpdateThreadContext(); | void UpdateThreadContext(); | ||||
| void LoadOpsProto(); | void LoadOpsProto(); | ||||
// Returns the value stored in options_ for the given key, or an empty string when the key is absent.
std::string GetParam(const std::string &param);
| public: | public: | ||||
| ge::GeGenerator generator_; | ge::GeGenerator generator_; | ||||
| std::map<std::string, std::string> options_; | std::map<std::string, std::string> options_; | ||||
| @@ -512,6 +512,10 @@ graphStatus Impl::CheckOptions(const std::map<std::string, std::string> &options | |||||
| return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
| } | } | ||||
| std::string Impl::GetParam(const std::string ¶m) { | |||||
| return options_.find(param) == options_.end() ? "" : options_[param]; | |||||
| } | |||||
| graphStatus Impl::Init(const Graph &graph, const std::map<std::string, std::string> &options) { | graphStatus Impl::Init(const Graph &graph, const std::map<std::string, std::string> &options) { | ||||
| // 1. check options | // 1. check options | ||||
| graphStatus ret = CheckOptions(options); | graphStatus ret = CheckOptions(options); | ||||
| @@ -533,20 +537,13 @@ graphStatus Impl::Init(const Graph &graph, const std::map<std::string, std::stri | |||||
| GE_CHK_BOOL_RET_STATUS_NOLOG(ge::CheckLogParamValidAndSetLogLevel(log) == 0, GRAPH_PARAM_INVALID); | GE_CHK_BOOL_RET_STATUS_NOLOG(ge::CheckLogParamValidAndSetLogLevel(log) == 0, GRAPH_PARAM_INVALID); | ||||
| options_[ge::ir_option::LOG_LEVEL] = log; | options_[ge::ir_option::LOG_LEVEL] = log; | ||||
| string input_shape = options_.find("input_shape") == options_.end() ? "" : options_["input_shape"]; | |||||
| string input_format = options_.find("input_format") == options_.end() ? "" : options_["input_format"]; | |||||
| string net_format = options_.find("net_format") == options_.end() ? "" : options_["net_format"]; | |||||
| string dynamic_batch_size = options_.find(ge::ir_option::DYNAMIC_BATCH_SIZE) == options_.end() | |||||
| ? "" | |||||
| : options_[ge::ir_option::DYNAMIC_BATCH_SIZE]; | |||||
| string dynamic_image_size = options_.find(ge::ir_option::DYNAMIC_IMAGE_SIZE) == options_.end() | |||||
| ? "" | |||||
| : options_[ge::ir_option::DYNAMIC_IMAGE_SIZE]; | |||||
| string dynamic_dims = | |||||
| options_.find(ge::ir_option::DYNAMIC_DIMS) == options_.end() ? "" : options_[ge::ir_option::DYNAMIC_DIMS]; | |||||
| string input_shape_range = | |||||
| options_.find(ge::INPUT_SHAPE_RANGE) == options_.end() ? "" : options_[ge::INPUT_SHAPE_RANGE]; | |||||
| string input_shape = GetParam(ge::ir_option::INPUT_SHAPE); | |||||
| string input_format = GetParam(ge::ir_option::INPUT_FORMAT); | |||||
| string dynamic_batch_size = GetParam(ge::ir_option::DYNAMIC_BATCH_SIZE); | |||||
| string dynamic_image_size = GetParam(ge::ir_option::DYNAMIC_IMAGE_SIZE); | |||||
| string dynamic_dims = GetParam(ge::ir_option::DYNAMIC_DIMS); | |||||
| string input_shape_range = GetParam(ge::INPUT_SHAPE_RANGE); | |||||
| auto status = CheckDynamicInputParamValid(dynamic_batch_size, dynamic_image_size, dynamic_dims, input_shape, | auto status = CheckDynamicInputParamValid(dynamic_batch_size, dynamic_image_size, dynamic_dims, input_shape, | ||||
| input_shape_range, input_format, is_dynamic_input_); | input_shape_range, input_format, is_dynamic_input_); | ||||
| if (status != ge::SUCCESS) { | if (status != ge::SUCCESS) { | ||||
| @@ -559,15 +556,12 @@ graphStatus Impl::Init(const Graph &graph, const std::map<std::string, std::stri | |||||
| omg_context_.dynamic_image_size = dynamic_image_size; | omg_context_.dynamic_image_size = dynamic_image_size; | ||||
| omg_context_.dynamic_dims = dynamic_dims; | omg_context_.dynamic_dims = dynamic_dims; | ||||
| // check output_type | // check output_type | ||||
| std::string output_type = options_.find(ge::ir_option::OUTPUT_TYPE) == options_.end() | |||||
| ? "" | |||||
| : options_[ge::ir_option::OUTPUT_TYPE]; | |||||
| std::string output_type = GetParam(ge::ir_option::OUTPUT_TYPE); | |||||
| GE_CHK_BOOL_EXEC(ge::CheckOutputTypeParamValid(output_type) == ge::SUCCESS, | GE_CHK_BOOL_EXEC(ge::CheckOutputTypeParamValid(output_type) == ge::SUCCESS, | ||||
| return ge::GRAPH_PARAM_INVALID, "[Check][OutputType] failed!"); | return ge::GRAPH_PARAM_INVALID, "[Check][OutputType] failed!"); | ||||
| // check insert_op_conf | // check insert_op_conf | ||||
| std::string insert_op_conf = options_.find(ge::ir_option::INSERT_OP_FILE) == options_.end() | |||||
| ? "" | |||||
| : options_[ge::ir_option::INSERT_OP_FILE]; | |||||
| std::string insert_op_conf = GetParam(ge::ir_option::INSERT_OP_FILE); | |||||
| GE_CHK_BOOL_EXEC(ge::CheckInsertOpConfParamValid(std::string(insert_op_conf)) == ge::SUCCESS, | GE_CHK_BOOL_EXEC(ge::CheckInsertOpConfParamValid(std::string(insert_op_conf)) == ge::SUCCESS, | ||||
| return ge::GRAPH_PARAM_INVALID, "[Check][InsertOpConf] failed!"); | return ge::GRAPH_PARAM_INVALID, "[Check][InsertOpConf] failed!"); | ||||
| @@ -592,7 +586,7 @@ graphStatus Impl::Init(const Graph &graph, const std::map<std::string, std::stri | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| // 4.parse and init Context with input shape format and net format info | // 4.parse and init Context with input shape format and net format info | ||||
| return this->InitDomiOmgContext(input_shape, input_format, net_format, is_dynamic_input_); | |||||
| return this->InitDomiOmgContext(input_shape, input_format, is_dynamic_input_); | |||||
| } | } | ||||
| void Impl::SetRtSocVersion() { | void Impl::SetRtSocVersion() { | ||||
| @@ -691,8 +685,7 @@ graphStatus Impl::BuildModel(const Graph &graph, const std::map<std::string, std | |||||
| return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
| } | } | ||||
| graphStatus Impl::InitDomiOmgContext(const string &input_shape, const string &input_format, const string &net_format, | |||||
| bool is_dynamic_input) { | |||||
| graphStatus Impl::InitDomiOmgContext(const string &input_shape, const string &input_format, bool is_dynamic_input) { | |||||
| // Clear omgcontext data first | // Clear omgcontext data first | ||||
| omg_context_.input_dims.clear(); | omg_context_.input_dims.clear(); | ||||
| omg_context_.user_input_dims.clear(); | omg_context_.user_input_dims.clear(); | ||||
| @@ -704,6 +704,7 @@ set(PASS_TEST_FILES | |||||
| "graph/passes/infershape_pass_unittest.cc" | "graph/passes/infershape_pass_unittest.cc" | ||||
| "graph/passes/mark_force_unknown_for_cond_pass_unittest.cc" | "graph/passes/mark_force_unknown_for_cond_pass_unittest.cc" | ||||
| "graph/passes/multi_batch_clone_pass_unittest.cc" | "graph/passes/multi_batch_clone_pass_unittest.cc" | ||||
| "graph/passes/subgraph_const_migration_pass_unittest.cc" | |||||
| "graph/passes/replace_with_empty_const_pass_unittest.cc" | "graph/passes/replace_with_empty_const_pass_unittest.cc" | ||||
| "graph/passes/link_gen_mask_nodes_pass_unittest.cc" | "graph/passes/link_gen_mask_nodes_pass_unittest.cc" | ||||
| "graph/passes/transpose_transdata_pass_unittest.cc" | "graph/passes/transpose_transdata_pass_unittest.cc" | ||||
| @@ -712,7 +713,7 @@ set(PASS_TEST_FILES | |||||
| "graph/passes/mark_node_unknown_shape_pass_unittest.cc" | "graph/passes/mark_node_unknown_shape_pass_unittest.cc" | ||||
| "graph/passes/reshape_recovery_pass_unittest.cc" | "graph/passes/reshape_recovery_pass_unittest.cc" | ||||
| "graph/passes/cast_remove_pass_unittest.cc" | "graph/passes/cast_remove_pass_unittest.cc" | ||||
| "graph/passes/memcpy_addr_async_unittest.cc" | |||||
| "graph/passes/memcpy_addr_async_unittest.cc" | |||||
| "graph/passes/hccl_continuous_pass_unittest.cc" | "graph/passes/hccl_continuous_pass_unittest.cc" | ||||
| "graph/passes/hccl_memcpy_pass_unittest.cc" | "graph/passes/hccl_memcpy_pass_unittest.cc" | ||||
| @@ -838,6 +839,7 @@ set(HYBRID_TEST_FILES | |||||
| "hybrid/node_executor/rts/rts_node_task_unittest.cc" | "hybrid/node_executor/rts/rts_node_task_unittest.cc" | ||||
| "hybrid/node_executor/host_cpu/host_cpu_node_task_unittest.cc" | "hybrid/node_executor/host_cpu/host_cpu_node_task_unittest.cc" | ||||
| "hybrid/node_executor/ge_local/ge_local_node_executor_unittest.cc" | "hybrid/node_executor/ge_local/ge_local_node_executor_unittest.cc" | ||||
| "hybrid/node_executor/hccl/hccl_node_executor_unittest.cc" | |||||
| "hybrid/executor/hybrid_model_async_executor_unittest.cc" | "hybrid/executor/hybrid_model_async_executor_unittest.cc" | ||||
| "hybrid/executor/hybrid_model_pipeline_executor_unittest.cc" | "hybrid/executor/hybrid_model_pipeline_executor_unittest.cc" | ||||
| "hybrid/node_executor/aicore/aicore_task_compiler_unittest.cc" | "hybrid/node_executor/aicore/aicore_task_compiler_unittest.cc" | ||||
| @@ -0,0 +1,125 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <gtest/gtest.h> | |||||
| #include <set> | |||||
| #include <string> | |||||
| #include "framework/omg/omg_inner_types.h" | |||||
| #include "graph/common/local_context.h" | |||||
| #include "graph/passes/subgraph_const_migration_pass.h" | |||||
| #include "inc/pass_manager.h" | |||||
| #include "register/op_registry.h" | |||||
| namespace ge { | |||||
| class UtestSubgraphConstMigrationPass : public testing::Test { | |||||
| protected: | |||||
| void SetUp() {} | |||||
| void TearDown() {} | |||||
| public: | |||||
| NodePtr MakeNode(const ComputeGraphPtr &graph, uint32_t in_num, uint32_t out_num, string name, string type) { | |||||
| GeTensorDesc test_desc(GeShape(), FORMAT_NCHW, DT_FLOAT); | |||||
| auto op_desc = std::make_shared<OpDesc>(name, type); | |||||
| for (auto i = 0; i < in_num; ++i) { | |||||
| op_desc->AddInputDesc(test_desc); | |||||
| } | |||||
| for (auto i = 0; i < out_num; ++i) { | |||||
| op_desc->AddOutputDesc(test_desc); | |||||
| } | |||||
| if (type == "Const") { | |||||
| uint64_t const_value = 101; | |||||
| auto weight = make_shared<GeTensor>(op_desc->GetOutputDesc(0), (uint8_t *)&const_value, sizeof(uint64_t)); | |||||
| AttrUtils::SetTensor(op_desc, ge::ATTR_NAME_WEIGHTS, weight); | |||||
| } | |||||
| return graph->AddNode(op_desc); | |||||
| } | |||||
| void make_original_graph(const ComputeGraphPtr &graph) { | |||||
| auto data = MakeNode(graph, 1, 1, "data", "Data"); | |||||
| { | |||||
| AttrUtils::SetInt(data->GetOpDesc(), ATTR_NAME_INDEX, 0); | |||||
| AttrUtils::SetInt(data->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 1); | |||||
| } | |||||
| auto const1 = MakeNode(graph, 0, 1, "const1", "Const"); | |||||
| { | |||||
| auto data1 = MakeNode(graph, 1, 1, "data1", "Data"); | |||||
| AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 1); | |||||
| AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 2); | |||||
| GraphUtils::AddEdge(data1->GetOutControlAnchor(), const1->GetInControlAnchor()); | |||||
| } | |||||
| auto const2 = MakeNode(graph, 0, 1, "const2", "Const"); | |||||
| { | |||||
| auto data2 = MakeNode(graph, 1, 1, "data2", "Data"); | |||||
| AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_INDEX, 2); | |||||
| AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 3); | |||||
| GraphUtils::AddEdge(data2->GetOutControlAnchor(), const2->GetInControlAnchor()); | |||||
| } | |||||
| auto conv2d_node = MakeNode(graph, 3, 1, "conv1", "Conv2D"); | |||||
| GraphUtils::AddEdge(data->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(const1->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(1)); | |||||
| GraphUtils::AddEdge(const2->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(2)); | |||||
| } | |||||
| void make_multibatch_graph(const ComputeGraphPtr &graph) { | |||||
| auto index = MakeNode(graph, 1, 1, "index", "Data"); | |||||
| auto data = MakeNode(graph, 1, 1, "data", "Data"); | |||||
| auto data1 = MakeNode(graph, 1, 1, "data1", "Data"); | |||||
| auto data2 = MakeNode(graph, 1, 1, "data2", "Data"); | |||||
| AttrUtils::SetInt(data->GetOpDesc(), ATTR_NAME_INDEX, 0); | |||||
| AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 1); | |||||
| AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_INDEX, 2); | |||||
| auto case1 = MakeNode(graph, 4, 1, "case", "Case"); | |||||
| GraphUtils::AddEdge(index->GetOutDataAnchor(0), case1->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(data->GetOutDataAnchor(0), case1->GetInDataAnchor(1)); | |||||
| GraphUtils::AddEdge(data1->GetOutDataAnchor(0), case1->GetInDataAnchor(2)); | |||||
| GraphUtils::AddEdge(data2->GetOutDataAnchor(0), case1->GetInDataAnchor(3)); | |||||
| auto output_node = MakeNode(graph, 1, 0, "output", "NetOutput"); | |||||
| GraphUtils::AddEdge(case1->GetOutDataAnchor(0), output_node->GetInDataAnchor(0)); | |||||
| AttrUtils::SetInt(case1->GetOpDesc(), ATTR_NAME_BATCH_NUM, 2); | |||||
| case1->GetOpDesc()->RegisterSubgraphIrName("branches", kDynamic); | |||||
| ComputeGraphPtr branch = std::make_shared<ComputeGraph>("test_branch"); | |||||
| make_original_graph(branch); | |||||
| for (int i = 0; i < 2; ++i) { | |||||
| std::string name("_ascend_mbatch_batch_" + std::to_string(i)); | |||||
| std::vector<NodePtr> input_nodes; | |||||
| std::vector<NodePtr> output_nodes; | |||||
| ComputeGraphPtr subgraph = GraphUtils::CloneGraph(branch, name, input_nodes, output_nodes); | |||||
| subgraph->SetName(name); | |||||
| subgraph->SetParentNode(case1); | |||||
| subgraph->SetParentGraph(graph); | |||||
| graph->AddSubgraph(subgraph->GetName(), subgraph); | |||||
| case1->GetOpDesc()->AddSubgraphName(name); | |||||
| case1->GetOpDesc()->SetSubgraphInstanceName(i, subgraph->GetName()); | |||||
| } | |||||
| } | |||||
| }; | |||||
| TEST_F(UtestSubgraphConstMigrationPass, subgraph_const_migration) { | |||||
| PassManager pass_manager; | |||||
| pass_manager.AddPass("SubgraphConstMigrationPass", new (std::nothrow) SubgraphConstMigrationPass); | |||||
| ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test_graph"); | |||||
| make_multibatch_graph(graph); | |||||
| EXPECT_EQ(pass_manager.Run(graph), SUCCESS); | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,108 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <gmock/gmock.h> | |||||
| #include <gtest/gtest.h> | |||||
| #include <vector> | |||||
| #define private public | |||||
| #define protected public | |||||
| #include "graph/runtime_inference_context.h" | |||||
| #include "hybrid/executor/subgraph_context.h" | |||||
| #include "hybrid/node_executor/hccl/hccl_node_executor.h" | |||||
| #undef protected | |||||
| #undef private | |||||
| using namespace std; | |||||
| using namespace testing; | |||||
| namespace ge { | |||||
| using namespace hybrid; | |||||
| class UtestHcclNodeExecutor : public testing::Test { | |||||
| protected: | |||||
| void SetUp() {} | |||||
| void TearDown() {} | |||||
| }; | |||||
| static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num) { | |||||
| OpDescPtr op_desc = std::make_shared<OpDesc>(name, type); | |||||
| op_desc->SetStreamId(0); | |||||
| static int32_t index = 0; | |||||
| op_desc->SetId(index++); | |||||
| GeTensorDesc tensor(GeShape(), FORMAT_ND, DT_INT64); | |||||
| TensorUtils::SetSize(tensor, 64); | |||||
| vector<int64_t> input_offset; | |||||
| for (int i = 0; i < in_num; i++) { | |||||
| op_desc->AddInputDesc(tensor); | |||||
| input_offset.emplace_back(i * 64); | |||||
| } | |||||
| op_desc->SetInputOffset(input_offset); | |||||
| vector<int64_t> output_offset; | |||||
| for (int i = 0; i < out_num; i++) { | |||||
| op_desc->AddOutputDesc(tensor); | |||||
| output_offset.emplace_back(in_num * 64 + i * 64); | |||||
| } | |||||
| op_desc->SetOutputOffset(output_offset); | |||||
| return graph.AddNode(op_desc); | |||||
| } | |||||
| TEST_F(UtestHcclNodeExecutor, test_rdmatask_extract_tensor) { | |||||
| ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||||
| NodePtr node = CreateNode(*graph, "hcom", HCOMREMOTEREAD, 0, 0); | |||||
| std::unique_ptr<NodeItem> new_node; | |||||
| ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS); | |||||
| NodeItem *node_item = new_node.get(); | |||||
| node_item->input_start = 0; | |||||
| node_item->output_start = 0; | |||||
| GraphItem graph_item; | |||||
| GraphExecutionContext graph_context; | |||||
| SubgraphContext subgraph_context(&graph_item, &graph_context); | |||||
| ASSERT_EQ(subgraph_context.Init(), SUCCESS); | |||||
| auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | |||||
| ASSERT_NE(node_state, nullptr); | |||||
| RuntimeInferenceContext::CreateContext(std::to_string(graph_context.context_id)); | |||||
| RuntimeInferenceContext *ctx = nullptr; | |||||
| RuntimeInferenceContext::GetContext(std::to_string(graph_context.context_id), &ctx); | |||||
| Shape s({1, 3}); | |||||
| TensorDesc tensor_desc(s); | |||||
| Tensor tensor(tensor_desc); | |||||
| std::vector<uint8_t> data = {1, 2, 3, 4}; | |||||
| tensor.SetData(data); | |||||
| ctx->SetTensor(1, 0, tensor.Clone()); | |||||
| auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||||
| vector<HcomRemoteAccessAddrInfo> addr_infos; | |||||
| shared_ptr<RdmaNodeTask> task = MakeShared<RdmaNodeTask>(); | |||||
| task->remote_index_ = {1, 0}; | |||||
| ASSERT_EQ(task->ExtractTensor(*unique_task_context, addr_infos), PARAM_INVALID); | |||||
| Shape s2({1}); | |||||
| TensorDesc tensor_desc2(s2); | |||||
| Tensor tensor2(tensor_desc2); | |||||
| ctx->SetTensor(1, 0, tensor2.Clone()); | |||||
| task->ExtractTensor(*unique_task_context, addr_infos); | |||||
| ASSERT_EQ(task->ExtractTensor(*unique_task_context, addr_infos), PARAM_INVALID); | |||||
| RuntimeInferenceContext::DestroyContext(std::to_string(graph_context.context_id)); | |||||
| } | |||||
| } // namespace ge | |||||