@@ -961,9 +961,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::GetOpInp
   std::vector<DataType> input_data_type;
   for (size_t i = 0; i < op->GetAllInputsSize(); ++i) {
     GeTensorDescPtr input_tensor_desc = op->MutableInputDesc(i);
-    if (input_tensor_desc == nullptr) {
-      continue;
-    }
+    GE_IF_BOOL_EXEC(input_tensor_desc == nullptr, continue);
     input_format.emplace_back(input_tensor_desc->GetFormat());
     input_shape.emplace_back(input_tensor_desc->GetShape().GetDims());
     input_data_type.emplace_back(input_tensor_desc->GetDataType());
@@ -973,9 +972,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::GetOpInp
   std::vector<DataType> output_data_type;
   for (size_t j = 0; j < op->GetOutputsSize(); ++j) {
     GeTensorDescPtr output_tensor_desc = op->MutableOutputDesc(j);
-    if (output_tensor_desc == nullptr) {
-      continue;
-    }
+    GE_IF_BOOL_EXEC(output_tensor_desc == nullptr, continue);
     output_format.emplace_back(output_tensor_desc->GetFormat());
     output_shape.emplace_back(output_tensor_desc->GetShape().GetDims());
     output_data_type.emplace_back(output_tensor_desc->GetDataType());
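For context on the pattern used throughout this change: the hand-written nullptr guards above are collapsed into GE_IF_BOOL_EXEC. A minimal sketch of the assumed expansion (the authoritative definition lives in GE's common debug-macro headers, so treat this as an approximation):

```cpp
// Assumed shape of GE_IF_BOOL_EXEC: a plain brace block, not a do/while or a lambda.
#define GE_IF_BOOL_EXEC(expr, exec_expr) \
  {                                      \
    if (expr) {                          \
      exec_expr;                         \
    }                                    \
  }
```

Because the body is an ordinary block, passing `continue` as `exec_expr` still targets the enclosing for-loop, which is exactly what the rewritten guards rely on.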
@@ -854,7 +854,7 @@ Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &in
            op_desc->GetName().c_str());
     return PARAM_INVALID;
   }
-  OmgContext &omg_context = (impl_ == nullptr) ? domi::GetContext() : impl_->omg_context_;
+  OmgContext &omg_context = impl_->omg_context_;
   omg_context.is_dynamic_input = ContainsDynamicInpus(*op_desc);
   if (op_desc->HasAttr(ATTR_NAME_UNREGST_OPPATH)) {
@@ -869,11 +869,7 @@ Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &in
   if (!HasShapeRange(inputs) && compile_flag == kFuzzBuildPattern) {
     fuzz_compile_flag = true;
   }
-  if (!AttrUtils::SetBool(op_desc, ATTR_NAME_FUZZ_BUILD, fuzz_compile_flag)) {
-    REPORT_CALL_ERROR("E19999", "set ATTR_NAME_FUZZ_BUILD failed for %s.", op_desc->GetName().c_str());
-    GELOGE(FAILED, "[Set][ATTR_NAME_FUZZ_BUILD] Failed to set attr for %s.", op_desc->GetName().c_str());
-    return FAILED;
-  }
+  (void)AttrUtils::SetBool(op_desc, ATTR_NAME_FUZZ_BUILD, fuzz_compile_flag);
   impl_->omg_context_.fuzz_compile_flag = fuzz_compile_flag;
   // 1. Create ComputeGraph.
@@ -579,11 +579,8 @@ Status GraphMemoryAssigner::ReAssignContinuousMemory(bool is_loop_graph) {
     if (continuous_output) {
       GE_CHK_STATUS_RET(GetNodeMemoryType(node, memory_type, "output"),
                         "[Get][MemType]fail for node:%s", node->GetName().c_str());
-      ret = AssignContinuousOutputMemory(node, memory_type, continuous_type);
-      if (ret != ge::SUCCESS) {
-        GELOGE(ret, "[Assign][Memory:Continuous:Ouput]fail for node:%s", node->GetName().c_str());
-        return ret;
-      }
+      GE_CHK_STATUS_RET(AssignContinuousOutputMemory(node, memory_type, continuous_type),
+                        "[Assign][Memory:Continuous:Output]fail for node:%s", node->GetName().c_str());
     }
   }
   // Assign continuous input memory in `reverse topo order` which stored before
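The explicit status check and log around AssignContinuousOutputMemory is folded into GE_CHK_STATUS_RET. A sketch of the assumed behaviour (an approximation, not the repository's exact definition):

```cpp
// Assumed shape: evaluate once, log the caller-supplied message on failure,
// and return the failing status from the enclosing function.
#define GE_CHK_STATUS_RET(expr, ...)        \
  do {                                      \
    const ge::Status _chk_status = (expr);  \
    if (_chk_status != ge::SUCCESS) {       \
      GELOGE(_chk_status, __VA_ARGS__);     \
      return _chk_status;                   \
    }                                       \
  } while (false)
```

The rewrite also corrects the "Ouput" typo in the log tag.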
@@ -1212,7 +1212,8 @@ Status StreamAllocator::SetActiveStreamsForLoop() {
   for (const auto &node : whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag())) {
     GE_CHECK_NOTNULL(node->GetOpDesc());
     bool is_loop_active = false;
-    if (AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_IS_LOOP_ACTIVE, is_loop_active) && is_loop_active) {
+    (void)AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_IS_LOOP_ACTIVE, is_loop_active);
+    if (is_loop_active) {
       vector<string> activated_label_list;
       NodePtr pre_switch_node = FindSwitchNodeBeforeLoopActiveNode(node);
@@ -1668,42 +1668,23 @@ Status ModelManager::LaunchKernelCheckAicpuOp(std::vector<std::string> &aicpu_op
   };
   GE_MAKE_GUARD(release, callback);
   // malloc sysOpInfoList in SysOpCheckInfo
-  status = rtMalloc(&d_req_op_list, op_nums * sizeof(SysOpInfo), RT_MEMORY_HBM);
-  if (status != RT_ERROR_NONE) {
-    REPORT_CALL_ERROR("E19999", "Call rtMalloc fail, size:%zu, ret = 0x%X", op_nums * sizeof(SysOpInfo), status);
-    GELOGE(RT_FAILED, "[Call][RtMalloc] fail, size:%zu, ret = 0x%X", op_nums * sizeof(SysOpInfo), status);
-    return RT_ERROR_TO_GE_STATUS(status);
-  }
+  GE_CHK_RT_RET(rtMalloc(&d_req_op_list, op_nums * sizeof(SysOpInfo), RT_MEMORY_HBM));
   allocated_mem.push_back(d_req_op_list);
   // malloc sysOpInfoList in SysOpCheckResp
-  status = rtMalloc(&d_res_op_list, op_nums * sizeof(SysOpInfo), RT_MEMORY_HBM);
-  if (status != RT_ERROR_NONE) {
-    REPORT_CALL_ERROR("E19999", "Call rtMalloc fail, size:%zu, ret = 0x%X", op_nums * sizeof(SysOpInfo), status);
-    GELOGE(RT_FAILED, "[Call][RtMalloc] fail, size:%zu, ret = 0x%X", op_nums * sizeof(SysOpInfo), status);
-    return RT_ERROR_TO_GE_STATUS(status);
-  }
+  GE_CHK_RT_RET(rtMalloc(&d_res_op_list, op_nums * sizeof(SysOpInfo), RT_MEMORY_HBM));
   allocated_mem.push_back(d_res_op_list);
   // malloc returnCodeList in SysOpCheckResp
-  status = rtMalloc(&d_ret_code_list, op_nums * sizeof(ReturnCode), RT_MEMORY_HBM);
-  if (status != RT_ERROR_NONE) {
-    REPORT_CALL_ERROR("E19999", "Call rtMalloc fail, size:%zu, ret = 0x%X", op_nums * sizeof(ReturnCode), status);
-    GELOGE(RT_FAILED, "[Call][RtMalloc] fail, size:%zu, ret = 0x%X", op_nums * sizeof(ReturnCode), status);
-    return RT_ERROR_TO_GE_STATUS(status);
-  }
+  GE_CHK_RT_RET(rtMalloc(&d_ret_code_list, op_nums * sizeof(ReturnCode), RT_MEMORY_HBM));
   allocated_mem.push_back(d_ret_code_list);
   for (const auto &op_type : aicpu_optype_list) {
     SysOpInfo op_info;
     // malloc op_type name in SysOpInfo
     void *d_op_type_name = nullptr;
-    status = rtMalloc(&d_op_type_name, op_type.length(), RT_MEMORY_HBM);
-    if (status != RT_ERROR_NONE) {
-      REPORT_CALL_ERROR("E19999", "Call rtMalloc fail, size:%lu, ret = 0x%X", op_type.length(), status);
-      GELOGE(RT_FAILED, "[Call][RtMalloc] fail, size:%lu, ret = 0x%X", op_type.length(), status);
-      return RT_ERROR_TO_GE_STATUS(status);
-    }
+    GE_CHK_RT_RET(rtMalloc(&d_op_type_name, op_type.length(), RT_MEMORY_HBM));
    allocated_mem.push_back(d_op_type_name);
     GE_CHK_RT(rtMemcpy(d_op_type_name, op_type.length(), op_type.c_str(), op_type.length(), RT_MEMCPY_HOST_TO_DEVICE));
     op_info.opType = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(d_op_type_name));
@@ -1716,12 +1697,8 @@ Status ModelManager::LaunchKernelCheckAicpuOp(std::vector<std::string> &aicpu_op
     SysOpInfo op_info;
     // malloc op_type name in SysOpInfo
     void *d_op_type_name = nullptr;
-    status = rtMalloc(&d_op_type_name, op_type.size(), RT_MEMORY_HBM);
-    if (status != RT_ERROR_NONE) {
-      REPORT_CALL_ERROR("E19999", "Call rtMalloc fail, size:%lu, ret = 0x%X", op_type.length(), status);
-      GELOGE(RT_FAILED, "[Call][RtMalloc] fail, size:%lu, ret = 0x%X", op_type.size(), status);
-      return RT_ERROR_TO_GE_STATUS(status);
-    }
+    GE_CHK_RT_RET(rtMalloc(&d_op_type_name, op_type.length(), RT_MEMORY_HBM));
     allocated_mem.push_back(d_op_type_name);
     GE_CHK_RT(rtMemcpy(d_op_type_name, op_type.size(), op_type.c_str(), op_type.size(), RT_MEMCPY_HOST_TO_DEVICE));
     op_info.opType = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(d_op_type_name));
@@ -1745,12 +1722,8 @@ Status ModelManager::LaunchKernelCheckAicpuOp(std::vector<std::string> &aicpu_op
   op_check_info_res.sysOpInfoList = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(d_res_op_list));
   uint32_t args_size = sizeof(SysOpCheckInfo) + sizeof(SysOpCheckResp);
-  status = rtMalloc(&args, args_size, RT_MEMORY_HBM);
-  if (status != RT_ERROR_NONE) {
-    REPORT_CALL_ERROR("E19999", "Call rtMalloc fail, size:%u, ret = 0x%X", args_size, status);
-    GELOGE(RT_FAILED, "[Call][RtMalloc] fail, size:%u, ret = 0x%X", args_size, status);
-    return RT_ERROR_TO_GE_STATUS(status);
-  }
+  GE_CHK_RT_RET(rtMalloc(&args, args_size, RT_MEMORY_HBM));
   allocated_mem.push_back(args);
   GE_CHK_RT(rtMemcpy(args, sizeof(SysOpCheckInfo), reinterpret_cast<void *>(&op_check_info_req), sizeof(SysOpCheckInfo),
                      RT_MEMCPY_HOST_TO_DEVICE));
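All of the rtMalloc error-handling blocks above are replaced with GE_CHK_RT_RET. A hedged sketch of what the macro is assumed to do (names and log text are approximations):

```cpp
// Assumed shape: run the runtime call, report/log on failure, and return
// the rtError_t converted to a GE status from the enclosing function.
#define GE_CHK_RT_RET(expr)                                                    \
  do {                                                                         \
    const rtError_t _rt_ret = (expr);                                          \
    if (_rt_ret != RT_ERROR_NONE) {                                            \
      REPORT_CALL_ERROR("E19999", "Call %s fail, ret = 0x%X", #expr, _rt_ret); \
      GELOGE(RT_FAILED, "[Call][RtApi] fail, ret = 0x%X", _rt_ret);            \
      return RT_ERROR_TO_GE_STATUS(_rt_ret);                                   \
    }                                                                          \
  } while (false)
```

One visible trade-off: the hand-written blocks logged the requested allocation size, which a generic macro presumably does not.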
@@ -3532,9 +3532,8 @@ Status GraphManager::OptimizeSubgraph(const GraphNodePtr &graph_node, ComputeGra
     return ret;
   }
   GE_TIMESTAMP_EVENT_END(SetSubgraph, "OptimizeSubgraph::SetSubGraph");
-  if ((options_.build_mode == BUILD_MODE_TUNING) &&
-      (options_.build_step == BUILD_STEP_BEFORE_UB_MATCH || options_.build_step == BUILD_STEP_AFTER_BUILDER ||
-       options_.build_step == BUILD_STEP_AFTER_BUILDER_SUB)) {
+  std::set<string> build_steps = {BUILD_STEP_BEFORE_UB_MATCH, BUILD_STEP_AFTER_BUILDER, BUILD_STEP_AFTER_BUILDER_SUB};
+  if ((options_.build_mode == BUILD_MODE_TUNING) && (build_steps.count(options_.build_step) > 0)) {
     GE_TIMESTAMP_START(ConvertGraphToFile);
     std::string tuning_path;
     (void) GetContext().GetOption(TUNING_PATH, tuning_path);
@@ -743,12 +743,12 @@ Status GraphOptimize::HandleMemoryRWConflict(ComputeGraphPtr &compute_graph) {
       continue;
     }
     // ignore data / netoutput of subgraph
-    if (node->GetType() == DATA && AttrUtils::HasAttr(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX)) {
-      continue;
-    }
-    if (node->GetType() == NETOUTPUT && AttrUtils::HasAttr(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX)) {
-      continue;
+    if (AttrUtils::HasAttr(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX)) {
+      if (node->GetType() == DATA || node->GetType() == NETOUTPUT) {
+        continue;
+      }
     }
     bool identity_reserved = false;
     AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_CANNOT_BE_DELETED, identity_reserved);
     if (identity_reserved) {
@@ -366,11 +366,8 @@ graphStatus ge::GraphPartitioner::AddPlaceHolderEndInSrcDstGraph(const AnchorPtr
   // link input -> end
   string end_name = kEndType + std::to_string(graph_info_.num_of_pld_end_);
   auto end_op_desc = MakeShared<OpDesc>(end_graph->GetName() + "_" + end_name, END);
-  if (end_op_desc == nullptr) {
-    REPORT_CALL_ERROR("E19999", "New Memory for OpDesc failed.");
-    GELOGE(GRAPH_PARAM_INVALID, "[New][Memory] for OpDesc failed, pld_op_desc is nullptr.");
-    return FAILED;
-  }
+  GE_CHECK_NOTNULL(end_op_desc);
   GE_IF_BOOL_EXEC(!AttrUtils::SetInt(end_op_desc, "peerIndex", graph_info_.num_of_pld_end_),
                   GELOGW("SetInt peerIndex failed");)
   GE_IF_BOOL_EXEC(!AttrUtils::SetStr(end_op_desc, "parentOpType", dst_node->GetType()),
@@ -429,11 +426,8 @@ graphStatus ge::GraphPartitioner::AddPlaceHolderEndInSrcDstGraph(const AnchorPtr
   int64_t node_id = src_node_opdesc->GetId();
   const string pld_name = kPlaceHolderType + std::to_string(graph_info_.num_of_pld_end_);
   auto pld_op_desc = MakeShared<OpDesc>(pld_graph->GetName() + "_" + pld_name, PLACEHOLDER);
-  if (pld_op_desc == nullptr) {
-    REPORT_CALL_ERROR("E19999", "New Memory for OpDesc failed.");
-    GELOGE(GRAPH_PARAM_INVALID, "[New][Memory] for OpDesc failed.");
-    return FAILED;
-  }
+  GE_CHECK_NOTNULL(pld_op_desc);
   GE_IF_BOOL_EXEC(!AttrUtils::SetInt(pld_op_desc, "peerIndex", graph_info_.num_of_pld_end_),
                   GELOGW("SetInt peerIndex failed");)
   GE_IF_BOOL_EXEC(!AttrUtils::SetStr(pld_op_desc, "_peerNodeName", new_end_node->GetName()),
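The MakeShared nullptr checks become GE_CHECK_NOTNULL. A sketch under the assumption that the macro reports, logs, and returns early on a null pointer (the exact status it returns is not shown in this diff):

```cpp
// Assumed shape of GE_CHECK_NOTNULL.
#define GE_CHECK_NOTNULL(val)                                                   \
  do {                                                                          \
    if ((val) == nullptr) {                                                     \
      REPORT_INNER_ERROR("E19999", "Param:%s is nullptr, check invalid", #val); \
      GELOGE(ge::PARAM_INVALID, "[Check][Param:%s] is nullptr.", #val);         \
      return ge::PARAM_INVALID;                                                 \
    }                                                                           \
  } while (false)
```

If the macro returns PARAM_INVALID (or a similar code) rather than the hand-written FAILED, the status propagated from AddPlaceHolderEndInSrcDstGraph changes slightly; callers that only compare against SUCCESS are unaffected.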
@@ -333,11 +333,8 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) {
     during_pass_node_set.nodes_last.clear();
   } while ((!during_pass_node_set.nodes_re_pass.empty() || !nodes.empty()) && ++re_pass_times < kMaxRePassTimes);
-  if (re_pass_times == kMaxRePassTimes) {
-    GELOGW("re_pass_times should not come to %d", kMaxRePassTimes);
-  }
+  GE_IF_BOOL_EXEC(re_pass_times == kMaxRePassTimes, GELOGW("re_pass_times should not come to %d", kMaxRePassTimes));
   GELOGD("All passes runs end");
   return SUCCESS;
 }
 Status GEPass::RunPassesOnSubGraph(const NodePtr &node, const NamesToPass &names_to_passes, bool &has_sub_graph) {
@@ -41,9 +41,7 @@ Status FlowCtrlPass::Run(ComputeGraphPtr compute_graph) {
   bool graph_change = false;
   // 1. Add FP/BP flow ctrl (big cycle)
   for (auto &node : compute_graph->GetDirectNode()) {
-    if (node == nullptr) {
-      continue;
-    }
+    GE_IF_BOOL_EXEC(node == nullptr, continue);
     GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue);
     uint32_t true_stream_id = 0;
     bool is_found = AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_TRUE_BRANCH_STREAM, true_stream_id);
@@ -65,9 +63,7 @@ Status FlowCtrlPass::Run(ComputeGraphPtr compute_graph) {
   // 2. Add special node flow ctrl. eg, IteratorGetNext. (small cycle)
   // NOTE: Small cycle share the variables with big cycle.
   for (auto &node : compute_graph->GetDirectNode()) {
-    if (node == nullptr) {
-      continue;
-    }
+    GE_IF_BOOL_EXEC(node == nullptr, continue);
     GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue);
     bool need_cycle_flag = false;
     bool is_found = AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_STREAM_CYCLE_EVENT_FLAG, need_cycle_flag);
@@ -164,9 +164,10 @@ Status SubgraphConstMigrationPass::ClassifyGraphNodes(const ComputeGraphPtr &gra
       data_nodes[parent_index] = node;
       GELOGD("%s, index: %u, Data: %s", subgraph->GetName().c_str(), parent_index, node->GetName().c_str());
-    } else if ((node->GetType() == CONSTANT) && (node->GetOutDataAnchor(kZeroIndex) != nullptr)) {
+    } else if (node->GetType() == CONSTANT) {
       set<string> peer_name_list;
       const auto &out_anchor = node->GetOutDataAnchor(kZeroIndex);
+      GE_IF_BOOL_EXEC(out_anchor == nullptr, continue);
       for (const auto &in_anchor : out_anchor->GetPeerInDataAnchors()) {
         const auto &peer_node = in_anchor->GetOwnerNode();
         // Trim subgraph node name prefix.
@@ -64,16 +64,19 @@ std::string TransOpBreadthFusionPass::GetNodeId(const int anchor_index, const No
   GE_IF_BOOL_EXEC(node == nullptr || node->GetOpDesc() == nullptr,
                   REPORT_INNER_ERROR("E19999", "Param node or its op_desc is nullptr, check invalid");
                   GELOGE(FAILED, "[Check][Param] Param node or its op_desc is nullptr"); return "");
+  std::set<std::string> trans_shapes = { RESHAPE, EXPANDDIMS, SQUEEZE };
+  std::set<std::string> trans_shape_and_format = { TRANSPOSE, TRANSPOSED, EXPANDDIMS };
   if (node->GetType() == CAST) {
     trans_data_type = true;
-  } else if (node->GetType() == TRANSPOSE || node->GetType() == TRANSPOSED || node->GetType() == EXPANDDIMS) {
+  } else if (trans_shape_and_format.count(node->GetType()) > 0) {
     trans_format = true;
     trans_shape = true;
   } else if (node->GetType() == TRANSDATA) {
     trans_data_type = true;
     trans_format = true;
     trans_shape = true;
-  } else if (node->GetType() == RESHAPE || node->GetType() == EXPANDDIMS || node->GetType() == SQUEEZE) {
+  } else if (trans_shapes.count(node->GetType()) > 0) {
     trans_shape = true;
   } else if (node->GetType() == REFORMAT) {
     trans_format = true;
@@ -71,15 +71,13 @@ Status SliceKernel::Compute(const OpDescPtr attr, const std::vector<ConstGeTenso
     GELOGW("The number of input for slice must be %zu.", kSliceInputSize);
     return NOT_CHANGED;
   }
   ConstGeTensorPtr x_ = input[kSliceInputIndexX];
   ConstGeTensorPtr begin = input[kSliceInputIndexBegin];
   ConstGeTensorPtr size = input[kSliceInputIndexSize];
-  if (x_ == nullptr || begin == nullptr || size == nullptr) {
-    GELOGW("input tensor is nullptr.");
-    return NOT_CHANGED;
+  Status ret = CheckInput(x_, begin, size);
+  if (ret != SUCCESS) {
+    return ret;
   }
   // data type in input_x
   auto data_type = x_->GetTensorDesc().GetDataType();
   // check supported
@@ -92,11 +90,7 @@ Status SliceKernel::Compute(const OpDescPtr attr, const std::vector<ConstGeTenso
   if (!is_success) {
     return NOT_CHANGED;
   }
-  // check data type of begin and size
-  if (begin->GetTensorDesc().GetDataType() != DT_INT32 || size->GetTensorDesc().GetDataType() != DT_INT32) {
-    GELOGW("Data type of begin and size for slice are not DT_INT32.");
-    return NOT_CHANGED;
-  }
   void *data = reinterpret_cast<void *>(const_cast<uint8_t *>(x_->GetData().data()));
   int32_t *begin_data = const_cast<int32_t *>(reinterpret_cast<const int32_t *>(begin->GetData().GetData()));
@@ -145,7 +139,7 @@ Status SliceKernel::Compute(const OpDescPtr attr, const std::vector<ConstGeTenso
     return NOT_CHANGED;
   }
-  Status ret = CheckOutputDims(output_dims, attr);
+  ret = CheckOutputDims(output_dims, attr);
   if (ret != SUCCESS) {
     return ret;
   }
@@ -161,6 +155,19 @@ Status SliceKernel::Compute(const OpDescPtr attr, const std::vector<ConstGeTenso
   return SUCCESS;
 }
+Status SliceKernel::CheckInput(const ConstGeTensorPtr x_, const ConstGeTensorPtr begin, const ConstGeTensorPtr size) {
+  if (x_ == nullptr || begin == nullptr || size == nullptr) {
+    GELOGW("input tensor is nullptr.");
+    return NOT_CHANGED;
+  }
+  // check data type of begin and size
+  if (begin->GetTensorDesc().GetDataType() != DT_INT32 || size->GetTensorDesc().GetDataType() != DT_INT32) {
+    GELOGW("Data type of begin and size for slice are not DT_INT32.");
+    return NOT_CHANGED;
+  }
+  return SUCCESS;
+}
 Status SliceKernel::CheckOutputDims(const std::vector<int64_t> &output_dims, const OpDescPtr attr) {
   // check dim not all less than 0
   for (auto dim : output_dims) {
@@ -28,6 +28,7 @@ class SliceKernel : public Kernel {
                  vector<GeTensorPtr> &v_output) override;
   Status CheckOutputDims(const std::vector<int64_t> &output_dims, const OpDescPtr attr);
+  Status CheckInput(const ConstGeTensorPtr x_, const ConstGeTensorPtr begin, const ConstGeTensorPtr size);
 };
 }  // namespace ge
@@ -95,8 +95,8 @@ Status HcclNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> do
   }
   op_info.dataType = iter->second;
   HcclReduceOp op_type = HCCL_REDUCE_SUM;
-  if (op_desc->GetType() == HCOMALLREDUCE || op_desc->GetType() == HCOMREDUCESCATTER ||
-      op_desc->GetType() == HVDCALLBACKALLREDUCE || op_desc->GetType() == HCOMREDUCE) {
+  std::set<std::string> hccl_types = { HCOMALLREDUCE, HCOMREDUCESCATTER, HVDCALLBACKALLREDUCE, HCOMREDUCE };
+  if (hccl_types.count(op_desc->GetType()) > 0) {
     GE_CHK_STATUS_RET(HcomOmeUtil::GetHcclOperationType(op_desc, op_type),
                       "[Get][HcclOperationType] failed for %s type:%s", op_desc->GetName().c_str(),
                       op_desc->GetType().c_str());
@@ -283,6 +283,7 @@ class Impl {
   void SetRtSocVersion();
   void UpdateThreadContext();
   void LoadOpsProto();
+  std::string GetParam(const std::string &param);
  public:
   ge::GeGenerator generator_;
   std::map<std::string, std::string> options_;
@@ -512,6 +513,10 @@ graphStatus Impl::CheckOptions(const std::map<std::string, std::string> &options
   return GRAPH_SUCCESS;
 }
+std::string Impl::GetParam(const std::string &param) {
+  return options_.find(param) == options_.end() ? "" : options_[param];
+}
 graphStatus Impl::Init(const Graph &graph, const std::map<std::string, std::string> &options) {
   // 1. check options
   graphStatus ret = CheckOptions(options);
@@ -533,20 +538,14 @@ graphStatus Impl::Init(const Graph &graph, const std::map<std::string, std::stri
   GE_CHK_BOOL_RET_STATUS_NOLOG(ge::CheckLogParamValidAndSetLogLevel(log) == 0, GRAPH_PARAM_INVALID);
   options_[ge::ir_option::LOG_LEVEL] = log;
-  string input_shape = options_.find("input_shape") == options_.end() ? "" : options_["input_shape"];
-  string input_format = options_.find("input_format") == options_.end() ? "" : options_["input_format"];
-  string net_format = options_.find("net_format") == options_.end() ? "" : options_["net_format"];
-  string dynamic_batch_size = options_.find(ge::ir_option::DYNAMIC_BATCH_SIZE) == options_.end()
-                                  ? ""
-                                  : options_[ge::ir_option::DYNAMIC_BATCH_SIZE];
-  string dynamic_image_size = options_.find(ge::ir_option::DYNAMIC_IMAGE_SIZE) == options_.end()
-                                  ? ""
-                                  : options_[ge::ir_option::DYNAMIC_IMAGE_SIZE];
-  string dynamic_dims =
-      options_.find(ge::ir_option::DYNAMIC_DIMS) == options_.end() ? "" : options_[ge::ir_option::DYNAMIC_DIMS];
-  string input_shape_range =
-      options_.find(ge::INPUT_SHAPE_RANGE) == options_.end() ? "" : options_[ge::INPUT_SHAPE_RANGE];
+  string input_shape = GetParam("input_shape");
+  string input_format = GetParam("input_format");
+  string net_format = GetParam("net_format");
+  string dynamic_batch_size = GetParam(ge::ir_option::DYNAMIC_BATCH_SIZE);
+  string dynamic_image_size = GetParam(ge::ir_option::DYNAMIC_IMAGE_SIZE);
+  string dynamic_dims = GetParam(ge::ir_option::DYNAMIC_DIMS);
+  string input_shape_range = GetParam(ge::INPUT_SHAPE_RANGE);
   auto status = CheckDynamicInputParamValid(dynamic_batch_size, dynamic_image_size, dynamic_dims, input_shape,
                                             input_shape_range, input_format, is_dynamic_input_);
   if (status != ge::SUCCESS) {
@@ -559,15 +558,12 @@ graphStatus Impl::Init(const Graph &graph, const std::map<std::string, std::stri
   omg_context_.dynamic_image_size = dynamic_image_size;
   omg_context_.dynamic_dims = dynamic_dims;
   // check output_type
-  std::string output_type = options_.find(ge::ir_option::OUTPUT_TYPE) == options_.end()
-                                ? ""
-                                : options_[ge::ir_option::OUTPUT_TYPE];
+  std::string output_type = GetParam(ge::ir_option::OUTPUT_TYPE);
   GE_CHK_BOOL_EXEC(ge::CheckOutputTypeParamValid(output_type) == ge::SUCCESS,
                    return ge::GRAPH_PARAM_INVALID, "[Check][OutputType] failed!");
   // check insert_op_conf
-  std::string insert_op_conf = options_.find(ge::ir_option::INSERT_OP_FILE) == options_.end()
-                                   ? ""
-                                   : options_[ge::ir_option::INSERT_OP_FILE];
+  std::string insert_op_conf = GetParam(ge::ir_option::INSERT_OP_FILE);
   GE_CHK_BOOL_EXEC(ge::CheckInsertOpConfParamValid(std::string(insert_op_conf)) == ge::SUCCESS,
                    return ge::GRAPH_PARAM_INVALID, "[Check][InsertOpConf] failed!");