From bdee8d1e058bd29ec778b90cc8fd7a3da8675e0d Mon Sep 17 00:00:00 2001 From: wjm Date: Sat, 12 Jun 2021 05:43:45 +0800 Subject: [PATCH] fix --- .../format_transfer_fracz_hwcn.cc | 16 ++++----- ge/common/helper/model_cache_helper.cc | 7 ++++ .../node_executor/hccl/hccl_node_executor.cc | 34 +++++++++++-------- 3 files changed, 33 insertions(+), 24 deletions(-) mode change 100755 => 100644 ge/common/formats/format_transfers/format_transfer_fracz_hwcn.cc diff --git a/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.cc b/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.cc old mode 100755 new mode 100644 index abe6263b..ed3a062c --- a/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.cc +++ b/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.cc @@ -17,6 +17,7 @@ #include "common/formats/format_transfers/format_transfer_fracz_hwcn.h" #include + #include #include "common/formats/utils/formats_definitions.h" @@ -35,8 +36,8 @@ Status CheckArgsForFracZToHwcn(const TransArgs &args) { auto dst_shape = args.dst_shape; if (args.src_format != FORMAT_FRACTAL_Z || args.dst_format != FORMAT_HWCN) { std::string error = "Dose not support trans format from " + - FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + - FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); + FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + + FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str()); return ACL_ERROR_GE_FORMAT_INVALID; } @@ -52,15 +53,13 @@ Status CheckArgsForFracZToHwcn(const TransArgs &args) { if (!CheckShapeValid(src_shape, kFracZDimsNum)) { GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Check][Shape]Value is invalid, src shape %s", ShapeToString(src_shape).c_str()); - REPORT_CALL_ERROR("E19999", "Src shape %s check invalid", - ShapeToString(src_shape).c_str()); + REPORT_CALL_ERROR("E19999", "Src shape %s check invalid", ShapeToString(src_shape).c_str()); return ACL_ERROR_GE_SHAPE_INVALID; } if (!CheckShapeValid(dst_shape, kHwcnDimsNum)) { GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Check][Shape]Value is invalid, dst shape %s", ShapeToString(dst_shape).c_str()); - REPORT_CALL_ERROR("E19999", "Dst shape %s check invalid", - ShapeToString(dst_shape).c_str()); + REPORT_CALL_ERROR("E19999", "Dst shape %s check invalid", ShapeToString(dst_shape).c_str()); return ACL_ERROR_GE_SHAPE_INVALID; } int64_t c0 = GetCubeSizeByDataType(args.src_data_type); @@ -71,9 +70,8 @@ Status CheckArgsForFracZToHwcn(const TransArgs &args) { int64_t n0 = Ceil(dst_shape.at(kHwcnN), static_cast(kNiSize)); if (src_shape.at(kFracZHWC1) != dst_shape.at(kHwcnH) * dst_shape.at(kHwcnW) * c1 || src_shape.at(kFracZC0) != c0 || src_shape.at(kFracZNi) != kNiSize || src_shape.at(kFracZN0) != n0) { - std::string error = "Failed to check relationship between src shape" + - FmtToStr(ShapeToString(src_shape)) + " and dst shape" + - FmtToStr(ShapeToString(dst_shape)); + std::string error = "Failed to check relationship between src shape" + FmtToStr(ShapeToString(src_shape)) + + " and dst shape" + FmtToStr(ShapeToString(dst_shape)); GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_SHAPE_INVALID, error.c_str()); return ACL_ERROR_GE_SHAPE_INVALID; } diff --git a/ge/common/helper/model_cache_helper.cc b/ge/common/helper/model_cache_helper.cc index 9cd88ef1..0e6c6329 100755 --- a/ge/common/helper/model_cache_helper.cc +++ b/ge/common/helper/model_cache_helper.cc @@ -1679,6 +1679,13 @@ Status ModelCacheHelper::LoadOmModelFromCache(GeModelPtr &ge_model) const { GELOGW("LoadOmModelFromCache: Load model from file failed. ret = %u", ret); return ret; } + std::function callback = [&]() { + if (model_data.model_data != nullptr) { + delete[] reinterpret_cast(model_data.model_data); + model_data.model_data = nullptr; + } + }; + GE_MAKE_GUARD(release, callback); ModelHelper model_helper; ret = model_helper.LoadModel(model_data); diff --git a/ge/hybrid/node_executor/hccl/hccl_node_executor.cc b/ge/hybrid/node_executor/hccl/hccl_node_executor.cc index 31f2c7a1..6be9849c 100644 --- a/ge/hybrid/node_executor/hccl/hccl_node_executor.cc +++ b/ge/hybrid/node_executor/hccl/hccl_node_executor.cc @@ -15,15 +15,16 @@ */ #include "hybrid/node_executor/hccl/hccl_node_executor.h" + #include "common/ge/plugin_manager.h" #include "common/math/math_util.h" #include "external/graph/attr_value.h" +#include "external/graph/types.h" #include "graph/debug/ge_attr_define.h" #include "graph/manager/util/hcom_util.h" #include "graph/utils/type_utils.h" -#include "external/graph/types.h" -#include "hybrid/executor/hybrid_execution_context.h" #include "hccl/hcom.h" +#include "hybrid/executor/hybrid_execution_context.h" #include "runtime/event.h" namespace ge { @@ -267,14 +268,16 @@ Status RdmaNodeTask::ExtractTensor(TaskContext &context, vector do } Status BuildAllToAllVparams(TaskContext &context, HcomAllToAllVParams ¶ms) { - void **input_addrs[kAllToAllVInputNums] = {¶ms.sendbuf, ¶ms.sendcounts, ¶ms.sdispls, - ¶ms.recvcounts, ¶ms.rdispls}; + void **input_addrs[kAllToAllVInputNums] = {¶ms.sendbuf, ¶ms.sendcounts, ¶ms.sdispls, ¶ms.recvcounts, + ¶ms.rdispls}; for (size_t i = 0; i < kAllToAllVInputNums; ++i) { auto addr = context.MutableInput(i); GE_CHECK_NOTNULL(addr); @@ -383,13 +386,14 @@ Status BuildAllToAllVparams(TaskContext &context, HcomAllToAllVParams ¶ms) { } params.sendtype = iter->second; params.recvtype = iter->second; + params.group = nullptr; return SUCCESS; } Status BuildGatherAllToAllParams(TaskContext &context, HcomGatherAllToAllVParams ¶ms) { - void **input_addrs[kGatherAllToAllVInputNums] = {¶ms.addrInfo, ¶ms.addrInfoCountPerRank, - ¶ms.recvcounts, ¶ms.rdispls}; + void **input_addrs[kGatherAllToAllVInputNums] = {¶ms.addrInfo, ¶ms.addrInfoCountPerRank, ¶ms.recvcounts, + ¶ms.rdispls}; for (size_t i = 0; i < kGatherAllToAllVInputNums; ++i) { auto addr = context.MutableInput(i); GE_CHECK_NOTNULL(addr); @@ -418,8 +422,9 @@ Status BuildGatherAllToAllParams(TaskContext &context, HcomGatherAllToAllVParams params.recvtype = iter->second; int64_t addr_len = 0; - (void) ge::AttrUtils::GetInt(op_desc, "addr_length", addr_len); + (void)ge::AttrUtils::GetInt(op_desc, "addr_length", addr_len); params.addrLength = static_cast(addr_len); + params.group = nullptr; return SUCCESS; } @@ -428,7 +433,7 @@ Status AllToAllNodeTask::ExecuteAsync(TaskContext &context, std::functionGetNodeName()); p_ctx->SetStatus(FAILED); @@ -460,7 +465,6 @@ Status AllToAllNodeTask::ExecuteAsync(TaskContext &context, std::function