Browse Source

cpplint magic num & define macro

tags/v1.2.0
陈劢 3 years ago
parent
commit
a7afa5683b
30 changed files with 302 additions and 196 deletions
  1. +2
    -2
      ge/common/auth/file_saver.cc
  2. +30
    -21
      ge/common/base64.h
  3. +43
    -25
      ge/common/formats/format_transfers/format_transfer_fractal_nz.cc
  4. +52
    -35
      ge/common/formats/format_transfers/format_transfer_fractal_zz.cc
  5. +13
    -12
      ge/common/formats/format_transfers/format_transfer_transpose.cc
  6. +9
    -0
      ge/common/formats/utils/formats_definitions.h
  7. +3
    -1
      ge/common/ge/tbe_plugin_manager.cc
  8. +43
    -40
      ge/common/util.cc
  9. +2
    -1
      ge/ge_runtime/runtime_model.cc
  10. +8
    -5
      ge/graph/load/new_model_manager/model_manager.cc
  11. +13
    -6
      ge/graph/load/new_model_manager/task_info/kernel_task_info.cc
  12. +10
    -8
      ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc
  13. +1
    -1
      ge/graph/load/new_model_manager/ts_mem_mall.h
  14. +6
    -6
      ge/graph/manager/graph_caching_allocator.cc
  15. +9
    -2
      ge/graph/manager/graph_caching_allocator.h
  16. +2
    -2
      ge/graph/manager/graph_var_manager.cc
  17. +1
    -0
      ge/graph/manager/graph_var_manager.h
  18. +13
    -6
      ge/graph/optimize/mem_rw_conflict_optimize.cc
  19. +2
    -1
      ge/graph/passes/data_pass.cc
  20. +2
    -1
      ge/graph/passes/for_pass.cc
  21. +3
    -1
      ge/graph/passes/mark_agnostic_pass.cc
  22. +4
    -2
      ge/graph/passes/merge_pass.cc
  23. +8
    -4
      ge/host_kernels/gather_v2_kernel.cc
  24. +6
    -3
      ge/host_kernels/range_kernel.cc
  25. +3
    -1
      ge/hybrid/common/npu_memory_allocator.cc
  26. +1
    -1
      ge/hybrid/executor/node_done_manager.cc
  27. +1
    -1
      ge/offline/main.cc
  28. +6
    -3
      ge/session/omg.cc
  29. +2
    -1
      ge/single_op/single_op.cc
  30. +4
    -4
      inc/framework/common/fmk_error_codes.h

+ 2
- 2
ge/common/auth/file_saver.cc View File

@@ -54,8 +54,8 @@ Status FileSaver::OpenFile(int32_t &fd, const std::string &file_path) {
Status FileSaver::WriteData(const void *data, uint32_t size, int32_t fd) { Status FileSaver::WriteData(const void *data, uint32_t size, int32_t fd) {
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(size == 0 || data == nullptr, return PARAM_INVALID); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(size == 0 || data == nullptr, return PARAM_INVALID);
mmSsize_t write_count; mmSsize_t write_count;
uint32_t size_2g = ((uint32_t) 0x1 << 31);
uint32_t size_1g = ((uint32_t) 0x1 << 30);
uint32_t size_2g = 2147483648; // 0x1 << 31
uint32_t size_1g = 1073741824; // 0x1 << 30
// Write data // Write data
if (size > size_2g) { if (size > size_2g) {
auto seek = reinterpret_cast<uint8_t *>(const_cast<void *>(data)); auto seek = reinterpret_cast<uint8_t *>(const_cast<void *>(data));


+ 30
- 21
ge/common/base64.h View File

@@ -25,32 +25,38 @@


namespace ge { namespace ge {
namespace { namespace {
const char* kBase64Chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
"abcdefghijklmnopqrstuvwxyz"
"0123456789+/";
const char *kBase64Chars =
"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
"abcdefghijklmnopqrstuvwxyz"
"0123456789+/";
const char kEqualSymbol = '='; const char kEqualSymbol = '=';
const size_t kBase64CharsNum = 64; const size_t kBase64CharsNum = 64;
const size_t kThreeByteOneGroup = 3; const size_t kThreeByteOneGroup = 3;
const size_t kFourByteOneGroup = 4; const size_t kFourByteOneGroup = 4;
}
const size_t kThreeByteOneGroupIndex0 = 0;
const size_t kThreeByteOneGroupIndex1 = 1;
const size_t kThreeByteOneGroupIndex2 = 2;
const size_t kFourByteOneGroupIndex0 = 0;
const size_t kFourByteOneGroupIndex1 = 1;
const size_t kFourByteOneGroupIndex2 = 2;
const size_t kFourByteOneGroupIndex3 = 3;
} // namespace


namespace base64 { namespace base64 {
static inline bool IsBase64Char(const char &c) {
return (isalnum(c) || (c == '+') || (c == '/'));
}
static inline bool IsBase64Char(const char &c) { return (isalnum(c) || (c == '+') || (c == '/')); }


static std::string EncodeToBase64(const std::string &raw_data) { static std::string EncodeToBase64(const std::string &raw_data) {
size_t encode_length = raw_data.size() / kThreeByteOneGroup * kFourByteOneGroup; size_t encode_length = raw_data.size() / kThreeByteOneGroup * kFourByteOneGroup;
encode_length += raw_data.size() % kThreeByteOneGroup == 0 ? 0 : kFourByteOneGroup; encode_length += raw_data.size() % kThreeByteOneGroup == 0 ? 0 : kFourByteOneGroup;
size_t raw_data_index = 0 ;
size_t raw_data_index = 0;
size_t encode_data_index = 0; size_t encode_data_index = 0;
std::string encode_data; std::string encode_data;
encode_data.resize(encode_length); encode_data.resize(encode_length);


for (; raw_data_index + kThreeByteOneGroup <= raw_data.size(); raw_data_index += kThreeByteOneGroup) { for (; raw_data_index + kThreeByteOneGroup <= raw_data.size(); raw_data_index += kThreeByteOneGroup) {
auto char_1 = static_cast<uint8_t>(raw_data[raw_data_index]); auto char_1 = static_cast<uint8_t>(raw_data[raw_data_index]);
auto char_2 = static_cast<uint8_t>(raw_data[raw_data_index + 1]);
auto char_3 = static_cast<uint8_t>(raw_data[raw_data_index + 2]);
auto char_2 = static_cast<uint8_t>(raw_data[raw_data_index + kThreeByteOneGroupIndex1]);
auto char_3 = static_cast<uint8_t>(raw_data[raw_data_index + kThreeByteOneGroupIndex2]);
encode_data[encode_data_index++] = kBase64Chars[char_1 >> 2u]; encode_data[encode_data_index++] = kBase64Chars[char_1 >> 2u];
encode_data[encode_data_index++] = kBase64Chars[((char_1 << 4u) & 0x30) | (char_2 >> 4u)]; encode_data[encode_data_index++] = kBase64Chars[((char_1 << 4u) & 0x30) | (char_2 >> 4u)];
encode_data[encode_data_index++] = kBase64Chars[((char_2 << 2u) & 0x3c) | (char_3 >> 6u)]; encode_data[encode_data_index++] = kBase64Chars[((char_2 << 2u) & 0x3c) | (char_3 >> 6u)];
@@ -80,8 +86,7 @@ static std::string EncodeToBase64(const std::string &raw_data) {
#pragma GCC diagnostic ignored "-Wunused-function" #pragma GCC diagnostic ignored "-Wunused-function"
static Status DecodeFromBase64(const std::string &base64_data, std::string &decode_data) { static Status DecodeFromBase64(const std::string &base64_data, std::string &decode_data) {
if (base64_data.size() % kFourByteOneGroup != 0) { if (base64_data.size() % kFourByteOneGroup != 0) {
GELOGE(PARAM_INVALID, "base64 data size must can be divided by 4, but given data size is %zu",
base64_data.size());
GELOGE(PARAM_INVALID, "base64 data size must can be divided by 4, but given data size is %zu", base64_data.size());
return PARAM_INVALID; return PARAM_INVALID;
} }
decode_data.clear(); decode_data.clear();
@@ -92,10 +97,10 @@ static Status DecodeFromBase64(const std::string &base64_data, std::string &deco
return static_cast<uint8_t>(std::distance(kBase64Chars, char_pos)) & 0xff; return static_cast<uint8_t>(std::distance(kBase64Chars, char_pos)) & 0xff;
}; };


for (std::size_t input_data_index = 0; input_data_index < base64_data_len; input_data_index += 4) {
for (std::size_t input_data_index = 0; input_data_index < base64_data_len; input_data_index += kFourByteOneGroup) {
for (size_t i = 0; i < kFourByteOneGroup; ++i) { for (size_t i = 0; i < kFourByteOneGroup; ++i) {
if (base64_data[input_data_index + i] == kEqualSymbol && if (base64_data[input_data_index + i] == kEqualSymbol &&
input_data_index >= base64_data_len - 4 && i > 1) {
input_data_index >= base64_data_len - kFourByteOneGroup && i > 1) {
byte_4[i] = kBase64CharsNum; byte_4[i] = kBase64CharsNum;
} else if (IsBase64Char(base64_data[input_data_index + i])) { } else if (IsBase64Char(base64_data[input_data_index + i])) {
byte_4[i] = FindCharInBase64Chars(base64_data[input_data_index + i]); byte_4[i] = FindCharInBase64Chars(base64_data[input_data_index + i]);
@@ -104,19 +109,23 @@ static Status DecodeFromBase64(const std::string &base64_data, std::string &deco
return PARAM_INVALID; return PARAM_INVALID;
} }
} }
decode_data += static_cast<char>((byte_4[0] << 2u) + ((byte_4[1] & 0x30) >> 4u));
if (byte_4[2] >= kBase64CharsNum){
decode_data +=
static_cast<char>((byte_4[kFourByteOneGroupIndex0] << 2u) + ((byte_4[kFourByteOneGroupIndex1] & 0x30) >> 4u));
if (byte_4[kFourByteOneGroupIndex2] >= kBase64CharsNum) {
break; break;
} else if (byte_4[3] >= kBase64CharsNum) {
decode_data += static_cast<char>(((byte_4[1] & 0x0f) << 4u) + ((byte_4[2] & 0x3c) >> 2u));
} else if (byte_4[kFourByteOneGroupIndex3] >= kBase64CharsNum) {
decode_data += static_cast<char>(((byte_4[kFourByteOneGroupIndex1] & 0x0f) << 4u) +
((byte_4[kFourByteOneGroupIndex2] & 0x3c) >> 2u));
break; break;
} }
decode_data += static_cast<char>(((byte_4[1] & 0x0f) << 4u) + ((byte_4[2] & 0x3c) >> 2u));
decode_data += static_cast<char>(((byte_4[2] & 0x03) << 6u) + byte_4[3]);
decode_data += static_cast<char>(((byte_4[kFourByteOneGroupIndex1] & 0x0f) << 4u) +
((byte_4[kFourByteOneGroupIndex2] & 0x3c) >> 2u));
decode_data +=
static_cast<char>(((byte_4[kFourByteOneGroupIndex2] & 0x03) << 6u) + byte_4[kFourByteOneGroupIndex3]);
} }
return SUCCESS; return SUCCESS;
} }
#pragma GCC diagnostic pop #pragma GCC diagnostic pop
}
} // namespace base64
} // namespace ge } // namespace ge
#endif // GE_COMMON_BASE64_H_ #endif // GE_COMMON_BASE64_H_

+ 43
- 25
ge/common/formats/format_transfers/format_transfer_fractal_nz.cc View File

@@ -23,12 +23,30 @@
#include "common/formats/utils/formats_trans_utils.h" #include "common/formats/utils/formats_trans_utils.h"
#include "framework/common/debug/ge_log.h" #include "framework/common/debug/ge_log.h"
#include "framework/common/debug/log.h" #include "framework/common/debug/log.h"
#include "framework/common/types.h"
#include "graph/utils/type_utils.h" #include "graph/utils/type_utils.h"


namespace ge { namespace ge {
namespace formats { namespace formats {
namespace { namespace {
const int kDimSize4D = 4; const int kDimSize4D = 4;

const size_t kSingleDim = 1;

const size_t kNdDimIndexN = 0;
const size_t kNdDimIndexH = 1;
const size_t kNdDimIndexW = 2;

const size_t kDimDValueBNdFNz = 2; // dim d-value between Nd and FractalZz

const size_t kNdDimCountBackwardsW = 1;
const size_t kNdDimCountBackwardsWH = 2;

const size_t kFNzDimCountBackwardsW0 = 1;
const size_t kFNzDimCountBackwardsW0H0 = 2;
const size_t kFNzDimCountBackwardsW0H0H1 = 3;
const size_t kFNzDimCountBackwardsW0H0H1W1 = 4;

bool IsDataTypeSupport(DataType data_type) { return GetSizeByDataType(data_type) > 0; } bool IsDataTypeSupport(DataType data_type) { return GetSizeByDataType(data_type) > 0; }


using ShapeVector = std::vector<int64_t>; using ShapeVector = std::vector<int64_t>;
@@ -60,14 +78,14 @@ Status TransShapeToFracNz(const ShapeVector &src_shape, DataType data_type, Shap
auto w0 = GetCubeSizeByDataType(data_type); auto w0 = GetCubeSizeByDataType(data_type);
int64_t h0 = kCubeSize; int64_t h0 = kCubeSize;
switch (src_shape.size()) { switch (src_shape.size()) {
case 1:
dst_shape.push_back(Ceil(src_shape[0], w0));
dst_shape.push_back(1);
case kSingleDim:
dst_shape.push_back(Ceil(src_shape[kNdDimIndexN], w0));
dst_shape.push_back(DIM_DEFAULT_VALUE);
dst_shape.push_back(h0); dst_shape.push_back(h0);
dst_shape.push_back(w0); dst_shape.push_back(w0);
hw_shape.push_back(1);
hw_shape.push_back(1);
hw_shape.push_back(src_shape[0]);
hw_shape.push_back(DIM_DEFAULT_VALUE);
hw_shape.push_back(DIM_DEFAULT_VALUE);
hw_shape.push_back(src_shape[kNdDimIndexN]);
if (!IsShapeValid(dst_shape)) { if (!IsShapeValid(dst_shape)) {
GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str());
return PARAM_INVALID; return PARAM_INVALID;
@@ -76,17 +94,17 @@ Status TransShapeToFracNz(const ShapeVector &src_shape, DataType data_type, Shap
default: default:
auto size = src_shape.size(); auto size = src_shape.size();
int64_t times = 1; int64_t times = 1;
for (size_t i = 0; i != size - 2; i++) {
for (size_t i = 0; i != size - kDimDValueBNdFNz; i++) {
dst_shape.push_back(src_shape[i]); dst_shape.push_back(src_shape[i]);
times *= src_shape[i]; times *= src_shape[i];
} }
dst_shape.push_back(Ceil(src_shape[size - 1], w0));
dst_shape.push_back(Ceil(src_shape[size - 2], h0));
dst_shape.push_back(Ceil(src_shape[size - kNdDimCountBackwardsW], w0));
dst_shape.push_back(Ceil(src_shape[size - kNdDimCountBackwardsWH], h0));
dst_shape.push_back(h0); dst_shape.push_back(h0);
dst_shape.push_back(w0); dst_shape.push_back(w0);
hw_shape.push_back(times); hw_shape.push_back(times);
hw_shape.push_back(src_shape[size - 2]);
hw_shape.push_back(src_shape[size - 1]);
hw_shape.push_back(src_shape[size - kNdDimCountBackwardsWH]);
hw_shape.push_back(src_shape[size - kNdDimCountBackwardsW]);
if (!IsShapeValid(dst_shape)) { if (!IsShapeValid(dst_shape)) {
GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str());
return PARAM_INVALID; return PARAM_INVALID;
@@ -128,16 +146,16 @@ Status TransFormatFromNdToFracNz(const TransArgs &args, TransResult &result, con
} }


// src&dst_shape can be written as times*H*W & times*W1*H1*H0*W0, respectively. dst_shape_size >= kDimNum4D // src&dst_shape can be written as times*H*W & times*W1*H1*H0*W0, respectively. dst_shape_size >= kDimNum4D
auto times = hw_shape.at(0);
auto h = hw_shape.at(1);
auto w = hw_shape.at(2);
auto times = hw_shape.at(kNdDimIndexN);
auto h = hw_shape.at(kNdDimIndexH);
auto w = hw_shape.at(kNdDimIndexW);
auto hw = h * w; auto hw = h * w;


auto shape_size = args.dst_shape.size(); auto shape_size = args.dst_shape.size();
auto w1 = args.dst_shape[shape_size - 4];
auto h1 = args.dst_shape[shape_size - 3];
auto h0 = args.dst_shape[shape_size - 2];
auto w0 = args.dst_shape[shape_size - 1];
auto w1 = args.dst_shape[shape_size - kFNzDimCountBackwardsW0H0H1W1];
auto h1 = args.dst_shape[shape_size - kFNzDimCountBackwardsW0H0H1];
auto h0 = args.dst_shape[shape_size - kFNzDimCountBackwardsW0H0];
auto w0 = args.dst_shape[shape_size - kFNzDimCountBackwardsW0];
auto h1h0 = h1 * h0; auto h1h0 = h1 * h0;
auto h1h0w0 = h1h0 * w0; auto h1h0w0 = h1h0 * w0;
auto w1h1h0w0 = w1 * h1h0w0; auto w1h1h0w0 = w1 * h1h0w0;
@@ -198,16 +216,16 @@ Status TransFormatFromFracNzToNd(const TransArgs &args, TransResult &result, con
return OUT_OF_MEMORY; return OUT_OF_MEMORY;
} }


auto times = dst_hw_shape.at(0);
auto h = dst_hw_shape.at(1);
auto w = dst_hw_shape.at(2);
auto times = dst_hw_shape.at(kNdDimIndexN);
auto h = dst_hw_shape.at(kNdDimIndexH);
auto w = dst_hw_shape.at(kNdDimIndexW);
auto hw = h * w; auto hw = h * w;


auto shape_size = args.src_shape.size(); auto shape_size = args.src_shape.size();
auto w1 = args.src_shape[shape_size - 4];
auto h1 = args.src_shape[shape_size - 3];
auto h0 = args.src_shape[shape_size - 2];
auto w0 = args.src_shape[shape_size - 1];
auto w1 = args.src_shape[shape_size - kFNzDimCountBackwardsW0H0H1W1];
auto h1 = args.src_shape[shape_size - kFNzDimCountBackwardsW0H0H1];
auto h0 = args.src_shape[shape_size - kFNzDimCountBackwardsW0H0];
auto w0 = args.src_shape[shape_size - kFNzDimCountBackwardsW0];
auto h1h0 = h1 * h0; auto h1h0 = h1 * h0;
auto h1h0w0 = h1h0 * w0; auto h1h0w0 = h1h0 * w0;
auto w1h1h0w0 = w1 * h1h0w0; auto w1h1h0w0 = w1 * h1h0w0;


+ 52
- 35
ge/common/formats/format_transfers/format_transfer_fractal_zz.cc View File

@@ -23,12 +23,29 @@
#include "common/formats/utils/formats_trans_utils.h" #include "common/formats/utils/formats_trans_utils.h"
#include "framework/common/debug/ge_log.h" #include "framework/common/debug/ge_log.h"
#include "framework/common/debug/log.h" #include "framework/common/debug/log.h"
#include "framework/common/types.h"
#include "graph/utils/type_utils.h" #include "graph/utils/type_utils.h"


namespace ge { namespace ge {
namespace formats { namespace formats {
namespace { namespace {
const int kDimSize4D = 4; const int kDimSize4D = 4;

const size_t kSingleDim = 1;

const size_t kNdDimIndexN = 0;
const size_t kNdDimIndexH = 1;
const size_t kNdDimIndexW = 2;

const size_t kDimDValueBNdFZz = 2; // dim d-value between Nd and FractalZz

const size_t kNdDimCountBackwardsW = 1;
const size_t kNdDimCountBackwardsWH = 2;

const size_t kFZzDimCountBackwardsW0 = 1;
const size_t kFZzDimCountBackwardsW0H0 = 2;
const size_t kFZzDimCountBackwardsW0H0W1 = 3;
const size_t kFZzDimCountBackwardsW0H0W1H1 = 4;
bool IsDataTypeSupport(DataType d_type) { return GetSizeByDataType(d_type) > 0; } bool IsDataTypeSupport(DataType d_type) { return GetSizeByDataType(d_type) > 0; }


using ShapeVector = std::vector<int64_t>; using ShapeVector = std::vector<int64_t>;
@@ -40,8 +57,8 @@ bool CheckShape(Format format, const ShapeVector &shape) {
case FORMAT_NHWC: case FORMAT_NHWC:
return CheckShapeValid(shape, kDimSize4D); return CheckShapeValid(shape, kDimSize4D);
default: default:
std::string error = "Trans format between " + FmtToStr(TypeUtils::FormatToSerialString(format)) +
" and FORMAT_FRACTAL_ZZ is not supported.";
std::string error = "Trans format between " + FmtToStr(TypeUtils::FormatToSerialString(format)) +
" and FORMAT_FRACTAL_ZZ is not supported.";
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str());
return false; return false;
} }
@@ -60,14 +77,14 @@ Status TransShapeToFracZz(const ShapeVector &src_shape, DataType data_type, Shap
auto w0 = GetCubeSizeByDataType(data_type); auto w0 = GetCubeSizeByDataType(data_type);
auto h0 = GetCubeSizeByDataType(data_type); auto h0 = GetCubeSizeByDataType(data_type);
switch (src_shape.size()) { switch (src_shape.size()) {
case 1:
dst_shape.push_back(1);
dst_shape.push_back(Ceil(src_shape[0], w0));
case kSingleDim:
dst_shape.push_back(DIM_DEFAULT_VALUE);
dst_shape.push_back(Ceil(src_shape[kNdDimIndexN], w0));
dst_shape.push_back(h0); dst_shape.push_back(h0);
dst_shape.push_back(w0); dst_shape.push_back(w0);
hw_shape.push_back(1);
hw_shape.push_back(1);
hw_shape.push_back(src_shape[0]);
hw_shape.push_back(DIM_DEFAULT_VALUE);
hw_shape.push_back(DIM_DEFAULT_VALUE);
hw_shape.push_back(src_shape[kNdDimIndexN]);
if (!IsShapeValid(dst_shape)) { if (!IsShapeValid(dst_shape)) {
GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str());
return PARAM_INVALID; return PARAM_INVALID;
@@ -76,17 +93,17 @@ Status TransShapeToFracZz(const ShapeVector &src_shape, DataType data_type, Shap
default: default:
auto size = src_shape.size(); auto size = src_shape.size();
int64_t times = 1; int64_t times = 1;
for (size_t i = 0; i != size - 2; i++) {
for (size_t i = 0; i != size - kDimDValueBNdFZz; i++) {
dst_shape.push_back(src_shape[i]); dst_shape.push_back(src_shape[i]);
times *= src_shape[i]; times *= src_shape[i];
} }
dst_shape.push_back(Ceil(src_shape[size - 2], h0));
dst_shape.push_back(Ceil(src_shape[size - 1], w0));
dst_shape.push_back(Ceil(src_shape[size - kNdDimCountBackwardsWH], h0));
dst_shape.push_back(Ceil(src_shape[size - kNdDimCountBackwardsW], w0));
dst_shape.push_back(h0); dst_shape.push_back(h0);
dst_shape.push_back(w0); dst_shape.push_back(w0);
hw_shape.push_back(times); hw_shape.push_back(times);
hw_shape.push_back(src_shape[size - 2]);
hw_shape.push_back(src_shape[size - 1]);
hw_shape.push_back(src_shape[size - kNdDimCountBackwardsWH]);
hw_shape.push_back(src_shape[size - kNdDimCountBackwardsW]);
if (!IsShapeValid(dst_shape)) { if (!IsShapeValid(dst_shape)) {
GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str());
return PARAM_INVALID; return PARAM_INVALID;
@@ -127,16 +144,16 @@ Status TransFormatFromNdToFracZz(const TransArgs &args, TransResult &result, con
return OUT_OF_MEMORY; return OUT_OF_MEMORY;
} }
// The src&dst_shape can be written as times*H*W & times*H1*W1*H0*W0, respectively. dst_shape_size >= kDimNum4D // The src&dst_shape can be written as times*H*W & times*H1*W1*H0*W0, respectively. dst_shape_size >= kDimNum4D
auto times = hw_shape.at(0);
auto h = hw_shape.at(1);
auto w = hw_shape.at(2);
auto times = hw_shape.at(kNdDimIndexN);
auto h = hw_shape.at(kNdDimIndexH);
auto w = hw_shape.at(kNdDimIndexW);
auto hw = h * w; auto hw = h * w;


auto shape_size = args.dst_shape.size(); auto shape_size = args.dst_shape.size();
auto h1 = args.dst_shape[shape_size - 4];
auto w1 = args.dst_shape[shape_size - 3];
auto h0 = args.dst_shape[shape_size - 2];
auto w0 = args.dst_shape[shape_size - 1];
auto h1 = args.dst_shape[shape_size - kFZzDimCountBackwardsW0H0W1H1];
auto w1 = args.dst_shape[shape_size - kFZzDimCountBackwardsW0H0W1];
auto h0 = args.dst_shape[shape_size - kFZzDimCountBackwardsW0H0];
auto w0 = args.dst_shape[shape_size - kFZzDimCountBackwardsW0];
auto h0w0 = h0 * w0; auto h0w0 = h0 * w0;
auto w1h0w0 = w1 * h0w0; auto w1h0w0 = w1 * h0w0;
auto h1w1h0w0 = h1 * w1h0w0; auto h1w1h0w0 = h1 * w1h0w0;
@@ -155,8 +172,8 @@ Status TransFormatFromNdToFracZz(const TransArgs &args, TransResult &result, con
auto src_offset = (src_h_head + w1_idx * w0) * size; auto src_offset = (src_h_head + w1_idx * w0) * size;
auto dst_offset = (h0_head + w1_idx * h0w0) * size; auto dst_offset = (h0_head + w1_idx * h0w0) * size;
auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN)
? dst_size - dst_offset
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN);
? dst_size - dst_offset
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN);
auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset,
static_cast<size_t>(size * w0)); static_cast<size_t>(size * w0));
if (ret != EOK) { if (ret != EOK) {
@@ -171,8 +188,8 @@ Status TransFormatFromNdToFracZz(const TransArgs &args, TransResult &result, con
auto src_offset = (src_h_head + src_w_idx) * size; auto src_offset = (src_h_head + src_w_idx) * size;
auto dst_offset = (w0_head + w0_idx) * size; auto dst_offset = (w0_head + w0_idx) * size;
auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN)
? dst_size - dst_offset
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN);
? dst_size - dst_offset
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN);
auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset,
static_cast<size_t>(size)); static_cast<size_t>(size));
if (ret != EOK) { if (ret != EOK) {
@@ -205,16 +222,16 @@ Status TransFormatFromFracZzToNd(const TransArgs &args, TransResult &result, con
} }


// The src&dst_shape can be written as times*H*W & times*H1*W1*H0*W0, respectively. dst_shape_size >= kDimNum4D // The src&dst_shape can be written as times*H*W & times*H1*W1*H0*W0, respectively. dst_shape_size >= kDimNum4D
auto times = dst_hw_shape.at(0);
auto h = dst_hw_shape.at(1);
auto w = dst_hw_shape.at(2);
auto times = dst_hw_shape.at(kNdDimIndexN);
auto h = dst_hw_shape.at(kNdDimIndexH);
auto w = dst_hw_shape.at(kNdDimIndexW);
auto hw = h * w; auto hw = h * w;


auto shape_size = args.src_shape.size(); auto shape_size = args.src_shape.size();
auto h1 = args.src_shape[shape_size - 4];
auto w1 = args.src_shape[shape_size - 3];
auto h0 = args.src_shape[shape_size - 2];
auto w0 = args.src_shape[shape_size - 1];
auto h1 = args.src_shape[shape_size - kFZzDimCountBackwardsW0H0W1H1];
auto w1 = args.src_shape[shape_size - kFZzDimCountBackwardsW0H0W1];
auto h0 = args.src_shape[shape_size - kFZzDimCountBackwardsW0H0];
auto w0 = args.src_shape[shape_size - kFZzDimCountBackwardsW0];
auto h0w0 = h0 * w0; auto h0w0 = h0 * w0;
auto w1h0w0 = w1 * h0w0; auto w1h0w0 = w1 * h0w0;
auto h1w1h0w0 = h1 * w1h0w0; auto h1w1h0w0 = h1 * w1h0w0;
@@ -233,8 +250,8 @@ Status TransFormatFromFracZzToNd(const TransArgs &args, TransResult &result, con
auto src_offset = (h0_head + w1_idx * h0w0) * size; auto src_offset = (h0_head + w1_idx * h0w0) * size;
auto dst_offset = (dst_h_head + w1_idx * w0) * size; auto dst_offset = (dst_h_head + w1_idx * w0) * size;
auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN)
? dst_size - dst_offset
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN);
? dst_size - dst_offset
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN);
auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset,
static_cast<size_t>(size * w0)); static_cast<size_t>(size * w0));
if (ret != EOK) { if (ret != EOK) {
@@ -249,8 +266,8 @@ Status TransFormatFromFracZzToNd(const TransArgs &args, TransResult &result, con
auto dst_w_idx = w1_head + w0_idx; auto dst_w_idx = w1_head + w0_idx;
auto dst_offset = (dst_h_head + dst_w_idx) * size; auto dst_offset = (dst_h_head + dst_w_idx) * size;
auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN)
? dst_size - dst_offset
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN);
? dst_size - dst_offset
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN);
auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset,
static_cast<size_t>(size)); static_cast<size_t>(size));
if (ret != EOK) { if (ret != EOK) {


+ 13
- 12
ge/common/formats/format_transfers/format_transfer_transpose.cc View File

@@ -19,6 +19,7 @@
#include <securec.h> #include <securec.h>
#include <memory> #include <memory>


#include "common/formats/utils/formats_definitions.h"
#include "common/formats/utils/formats_trans_utils.h" #include "common/formats/utils/formats_trans_utils.h"
#include "framework/common/debug/ge_log.h" #include "framework/common/debug/ge_log.h"
#include "framework/common/debug/log.h" #include "framework/common/debug/log.h"
@@ -29,21 +30,21 @@ namespace formats {
namespace { namespace {
std::map<Format, std::map<Format, std::vector<int64_t>>> perm_args{ std::map<Format, std::map<Format, std::vector<int64_t>>> perm_args{
{FORMAT_NCHW, {FORMAT_NCHW,
{{FORMAT_NHWC, std::vector<int64_t>({0, 2, 3, 1})},
{FORMAT_HWCN, std::vector<int64_t>({2, 3, 1, 0})},
{FORMAT_CHWN, std::vector<int64_t>({1, 2, 3, 0})}}},
{{FORMAT_NHWC, std::vector<int64_t>({kNchwN, kNchwH, kNchwW, kNchwC})},
{FORMAT_HWCN, std::vector<int64_t>({kNchwH, kNchwW, kNchwC, kNchwN})},
{FORMAT_CHWN, std::vector<int64_t>({kNchwC, kNchwH, kNchwW, kNchwN})}}},
{FORMAT_NHWC, {FORMAT_NHWC,
{{FORMAT_NCHW, std::vector<int64_t>({0, 3, 1, 2})},
{FORMAT_CHWN, std::vector<int64_t>({3, 1, 2, 0})},
{FORMAT_HWCN, std::vector<int64_t>({1, 2, 3, 0})}}},
{{FORMAT_NCHW, std::vector<int64_t>({kNhwcN, kNhwcC, kNhwcH, kNhwcW})},
{FORMAT_CHWN, std::vector<int64_t>({kNhwcC, kNhwcH, kNhwcW, kNhwcN})},
{FORMAT_HWCN, std::vector<int64_t>({kNhwcH, kNhwcW, kNhwcC, kNhwcN})}}},
{FORMAT_HWCN, {FORMAT_HWCN,
{{FORMAT_NCHW, std::vector<int64_t>({3, 2, 0, 1})},
{FORMAT_NHWC, std::vector<int64_t>({3, 0, 1, 2})},
{FORMAT_CHWN, std::vector<int64_t>({2, 0, 1, 3})}}},
{{FORMAT_NCHW, std::vector<int64_t>({kHwcnN, kHwcnC, kHwcnH, kHwcnW})},
{FORMAT_NHWC, std::vector<int64_t>({kHwcnN, kHwcnH, kHwcnW, kHwcnC})},
{FORMAT_CHWN, std::vector<int64_t>({kHwcnC, kHwcnH, kHwcnW, kHwcnN})}}},
{FORMAT_CHWN, {FORMAT_CHWN,
{{FORMAT_NCHW, std::vector<int64_t>({3, 0, 1, 2})},
{FORMAT_NHWC, std::vector<int64_t>({3, 1, 2, 0})},
{FORMAT_HWCN, std::vector<int64_t>({1, 2, 0, 3})}}},
{{FORMAT_NCHW, std::vector<int64_t>({kChwnN, kChwnC, kChwnH, kChwnW})},
{FORMAT_NHWC, std::vector<int64_t>({kChwnN, kChwnH, kChwnW, kChwnC})},
{FORMAT_HWCN, std::vector<int64_t>({kChwnH, kChwnW, kChwnC, kChwnN})}}},
}; };


bool IsShapeArgValid(const std::vector<int64_t> &src_shape, const std::vector<int64_t> &perm_arg) { bool IsShapeArgValid(const std::vector<int64_t> &src_shape, const std::vector<int64_t> &perm_arg) {


+ 9
- 0
ge/common/formats/utils/formats_definitions.h View File

@@ -23,6 +23,7 @@ static const int kCubeSize = 16;
static const int kNiSize = 16; static const int kNiSize = 16;
static const int64_t kShapeItemNumMAX = 1024UL * 1024UL * 1024UL * 1024UL; static const int64_t kShapeItemNumMAX = 1024UL * 1024UL * 1024UL * 1024UL;



enum NchwDimIndex { enum NchwDimIndex {
kNchwN, kNchwN,
kNchwC, kNchwC,
@@ -47,6 +48,14 @@ enum HwcnDimIndex {
kHwcnDimsNum kHwcnDimsNum
}; };


enum ChwnDimIndex {
kChwnC,
kChwnH,
kChwnW,
kChwnN,
kChwnDimsNum
};

enum Nc1hwc0DimIndex { enum Nc1hwc0DimIndex {
kNc1hwc0N, kNc1hwc0N,
kNc1hwc0C1, kNc1hwc0C1,


+ 3
- 1
ge/common/ge/tbe_plugin_manager.cc View File

@@ -37,6 +37,8 @@
#include "graph/utils/type_utils.h" #include "graph/utils/type_utils.h"


namespace ge { namespace ge {
const int kBaseInt = 10;

std::map<string, string> TBEPluginManager::options_ = {}; std::map<string, string> TBEPluginManager::options_ = {};


// Get Singleton Instance // Get Singleton Instance
@@ -155,7 +157,7 @@ void TBEPluginManager::GetCustomOpPath(std::string &customop_path) {
domi::FrameworkType type = domi::TENSORFLOW; domi::FrameworkType type = domi::TENSORFLOW;
auto it = options_.find(FRAMEWORK_TYPE); auto it = options_.find(FRAMEWORK_TYPE);
if (it != options_.end()) { if (it != options_.end()) {
type = static_cast<domi::FrameworkType>(std::strtol(it->second.c_str(), nullptr, 10));
type = static_cast<domi::FrameworkType>(std::strtol(it->second.c_str(), nullptr, kBaseInt));
} }
fmk_type = ge::TypeUtils::FmkTypeToSerialString(type); fmk_type = ge::TypeUtils::FmkTypeToSerialString(type);
GELOGI("Framework type is %s.", fmk_type.c_str()); GELOGI("Framework type is %s.", fmk_type.c_str());


+ 43
- 40
ge/common/util.cc View File

@@ -51,14 +51,15 @@ namespace {
* If such an exception is encountered during operation, * If such an exception is encountered during operation,
* the proto file can be divided into several small files or the limit value can be increased. * the proto file can be divided into several small files or the limit value can be increased.
*/ */
const int kProtoReadBytesLimit = INT_MAX; // Max size of 2 GB minus 1 byte.
const int kWarningThreshold = 536870912 * 2; // 536870912 represent 512M
const int kFileSizeOutLimitedOrOpenFailed = -1;
const int kProtoReadBytesLimit = INT_MAX; // Max size of 2 GB minus 1 byte.
const int kWarningThreshold = 1073741824; // 536870912 * 2 536870912 represent 512M


/// The maximum length of the file. /// The maximum length of the file.
const uint32_t kMaxFileSizeLimit = UINT32_MAX; // 4G for now
const uint32_t kMaxFileSizeLimit = UINT32_MAX; // 4G for now
const int kMaxBuffSize = 256; const int kMaxBuffSize = 256;
const char *const kPathValidReason = "The path can only contain 'a-z' 'A-Z' '0-9' '-' '.' '_' and chinese character"; const char *const kPathValidReason = "The path can only contain 'a-z' 'A-Z' '0-9' '-' '.' '_' and chinese character";
constexpr uint32_t kMaxConfigFileByte = 10 * 1024 * 1024;
constexpr uint32_t kMaxConfigFileByte = 10485760; // 10 * 1024 * 1024
} // namespace } // namespace


namespace ge { namespace ge {
@@ -76,7 +77,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromBinaryFile(co
std::string real_path = RealPath(file); std::string real_path = RealPath(file);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), return false, "pb file path '%s' not valid", file); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), return false, "pb file path '%s' not valid", file);


GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetFileLength(real_path) == -1, return false, "file size not valid.");
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetFileLength(real_path) == kFileSizeOutLimitedOrOpenFailed, return false,
"file size not valid.");


std::ifstream fs(real_path, std::ifstream::in | std::ifstream::binary); std::ifstream fs(real_path, std::ifstream::in | std::ifstream::binary);
if (!fs.is_open()) { if (!fs.is_open()) {
@@ -118,20 +120,20 @@ long GetFileLength(const std::string &input_file) {
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), return -1, "input_file path '%s' not valid", input_file.c_str()); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), return -1, "input_file path '%s' not valid", input_file.c_str());
unsigned long long file_length = 0; unsigned long long file_length = 0;
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
mmGetFileSize(input_file.c_str(), &file_length) != EN_OK,
ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"file", "errmsg"}, {input_file, strerror(errno)});
return -1, "Open file[%s] failed. %s", input_file.c_str(), strerror(errno));
mmGetFileSize(input_file.c_str(), &file_length) != EN_OK,
ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"file", "errmsg"}, {input_file, strerror(errno)});
return kFileSizeOutLimitedOrOpenFailed, "Open file[%s] failed. %s", input_file.c_str(), strerror(errno));


GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file_length == 0), GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file_length == 0),
ErrorManager::GetInstance().ATCReportErrMessage("E19015", {"filepath"}, {input_file}); ErrorManager::GetInstance().ATCReportErrMessage("E19015", {"filepath"}, {input_file});
return -1, "File[%s] size is 0, not valid.", input_file.c_str()); return -1, "File[%s] size is 0, not valid.", input_file.c_str());


GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(file_length > kMaxFileSizeLimit,
ErrorManager::GetInstance().ATCReportErrMessage(
"E19016", {"filepath", "filesize", "maxlen"},
{input_file, std::to_string(file_length), std::to_string(kMaxFileSizeLimit)});
return -1, "File[%s] size %lld is out of limit: %d.", input_file.c_str(), file_length,
kMaxFileSizeLimit);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
file_length > kMaxFileSizeLimit, ErrorManager::GetInstance().ATCReportErrMessage(
"E19016", {"filepath", "filesize", "maxlen"},
{input_file, std::to_string(file_length), std::to_string(kMaxFileSizeLimit)});
return kFileSizeOutLimitedOrOpenFailed, "File[%s] size %lld is out of limit: %d.", input_file.c_str(), file_length,
kMaxFileSizeLimit);
return static_cast<long>(file_length); return static_cast<long>(file_length);
} }


@@ -187,7 +189,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadBytesFromBinaryFile(co
std::streamsize size = file.tellg(); std::streamsize size = file.tellg();


GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((size <= 0), file.close(); return false, "file length <= 0, not valid."); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((size <= 0), file.close(); return false, "file length <= 0, not valid.");
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(size > static_cast<int64_t >(kMaxFileSizeLimit), file.close();
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(size > static_cast<int64_t>(kMaxFileSizeLimit), file.close();
return false, "file size %ld is out of limit: %d.", size, kMaxFileSizeLimit); return false, "file size %ld is out of limit: %d.", size, kMaxFileSizeLimit);


file.seekg(0, std::ios::beg); // [no need to check value] file.seekg(0, std::ios::beg); // [no need to check value]
@@ -210,8 +212,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int CreateDirectory(const std::
GE_CHK_BOOL_EXEC(!directory_path.empty(), return -1, "directory path is empty."); GE_CHK_BOOL_EXEC(!directory_path.empty(), return -1, "directory path is empty.");
auto dir_path_len = directory_path.length(); auto dir_path_len = directory_path.length();
if (dir_path_len >= MMPA_MAX_PATH) { if (dir_path_len >= MMPA_MAX_PATH) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E19002", {"filepath", "size"}, {directory_path, std::to_string(MMPA_MAX_PATH)});
ErrorManager::GetInstance().ATCReportErrMessage("E19002", {"filepath", "size"},
{directory_path, std::to_string(MMPA_MAX_PATH)});
GELOGW("Path[%s] len is too long, it must be less than %d", directory_path.c_str(), MMPA_MAX_PATH); GELOGW("Path[%s] len is too long, it must be less than %d", directory_path.c_str(), MMPA_MAX_PATH);
return -1; return -1;
} }
@@ -224,8 +226,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int CreateDirectory(const std::
if (ret != 0) { if (ret != 0) {
if (errno != EEXIST) { if (errno != EEXIST) {
ErrorManager::GetInstance().ATCReportErrMessage("E19006", {"path"}, {directory_path}); ErrorManager::GetInstance().ATCReportErrMessage("E19006", {"path"}, {directory_path});
GELOGW("Can not create directory %s. Make sure the directory exists and writable.",
directory_path.c_str());
GELOGW("Can not create directory %s. Make sure the directory exists and writable.", directory_path.c_str());
return ret; return ret;
} }
} }
@@ -265,7 +266,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromText(const ch


std::string real_path = RealPath(file); std::string real_path = RealPath(file);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), ErrorManager::GetInstance().ATCReportErrMessage( GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), ErrorManager::GetInstance().ATCReportErrMessage(
"E19000", {"path", "errmsg"}, {file, strerror(errno)});
"E19000", {"path", "errmsg"}, {file, strerror(errno)});
return false, "Path[%s]'s realpath is empty, errmsg[%s]", file, strerror(errno)); return false, "Path[%s]'s realpath is empty, errmsg[%s]", file, strerror(errno));


GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetFileLength(real_path) == -1, return false, "file size not valid."); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetFileLength(real_path) == -1, return false, "file size not valid.");
@@ -301,13 +302,13 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromMem(const cha
google::protobuf::io::IstreamInputStream input(&fs); google::protobuf::io::IstreamInputStream input(&fs);
bool ret = google::protobuf::TextFormat::Parse(&input, message); bool ret = google::protobuf::TextFormat::Parse(&input, message);
GE_IF_BOOL_EXEC( GE_IF_BOOL_EXEC(
!ret, GELOGE(ret, "Call [google::protobuf::TextFormat::Parse] func ret fail, please check your text file."));
!ret, GELOGE(ret, "Call [google::protobuf::TextFormat::Parse] func ret fail, please check your text file."));


return ret; return ret;
} }


FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint64_t GetCurrentTimestamp() { FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint64_t GetCurrentTimestamp() {
mmTimeval tv {};
mmTimeval tv{};
int ret = mmGetTimeOfDay(&tv, nullptr); int ret = mmGetTimeOfDay(&tv, nullptr);
GE_LOGE_IF(ret != EN_OK, "Func gettimeofday may failed: ret=%d", ret); GE_LOGE_IF(ret != EN_OK, "Func gettimeofday may failed: ret=%d", ret);
auto total_use_time = tv.tv_usec + tv.tv_sec * 1000000; // 1000000: seconds to microseconds auto total_use_time = tv.tv_usec + tv.tv_sec * 1000000; // 1000000: seconds to microseconds
@@ -315,7 +316,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint64_t GetCurrentTimestamp()
} }


FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint32_t GetCurrentSecondTimestap() { FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint32_t GetCurrentSecondTimestap() {
mmTimeval tv {};
mmTimeval tv{};
int ret = mmGetTimeOfDay(&tv, nullptr); int ret = mmGetTimeOfDay(&tv, nullptr);
GE_LOGE_IF(ret != EN_OK, "Func gettimeofday may failed: ret=%d", ret); GE_LOGE_IF(ret != EN_OK, "Func gettimeofday may failed: ret=%d", ret);
auto total_use_time = tv.tv_sec; // seconds auto total_use_time = tv.tv_sec; // seconds
@@ -350,8 +351,9 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInt64MulOverflow(int6
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string RealPath(const char *path) { FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string RealPath(const char *path) {
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(path == nullptr, return "", "path pointer is NULL."); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(path == nullptr, return "", "path pointer is NULL.");
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(strlen(path) >= MMPA_MAX_PATH, GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(strlen(path) >= MMPA_MAX_PATH,
ErrorManager::GetInstance().ATCReportErrMessage("E19002", {"filepath", "size"}, {path, std::to_string(MMPA_MAX_PATH)});
return "", "Path[%s] len is too long, it must be less than %d", path, MMPA_MAX_PATH);
ErrorManager::GetInstance().ATCReportErrMessage("E19002", {"filepath", "size"},
{path, std::to_string(MMPA_MAX_PATH)});
return "", "Path[%s] len is too long, it must be less than %d", path, MMPA_MAX_PATH);


// Nullptr is returned when the path does not exist or there is no permission // Nullptr is returned when the path does not exist or there is no permission
// Return absolute path when path is accessible // Return absolute path when path is accessible
@@ -385,16 +387,16 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInputPathValid(const
// Path section: Support upper and lower case letters, numbers dots(.) chinese and underscores // Path section: Support upper and lower case letters, numbers dots(.) chinese and underscores
// File name section: Support upper and lower case letters, numbers, underscores chinese and dots(.) // File name section: Support upper and lower case letters, numbers, underscores chinese and dots(.)
#ifdef __GNUC__ #ifdef __GNUC__
std::string mode = "^[\u4e00-\u9fa5A-Za-z0-9./_-]+$";
std::string mode = "^[\u4e00-\u9fa5A-Za-z0-9./_-]+$";
#else #else
std::string mode = "^[a-zA-Z]:([\\\\/][^\\s\\\\/:*?<>\"|][^\\\\/:*?<>\"|]*)*([/\\\\][^\\s\\\\/:*?<>\"|])?$";
std::string mode = "^[a-zA-Z]:([\\\\/][^\\s\\\\/:*?<>\"|][^\\\\/:*?<>\"|]*)*([/\\\\][^\\s\\\\/:*?<>\"|])?$";
#endif #endif


GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
!ValidateStr(real_path, mode),
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
{atc_param, real_path, kPathValidReason});
return false, "Invalid value for %s[%s], %s.", atc_param.c_str(), real_path.c_str(), kPathValidReason);
!ValidateStr(real_path, mode),
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
{atc_param, real_path, kPathValidReason});
return false, "Invalid value for %s[%s], %s.", atc_param.c_str(), real_path.c_str(), kPathValidReason);


// The absolute path points to a file that is not readable // The absolute path points to a file that is not readable
if (mmAccess2(real_path.c_str(), M_R_OK) != EN_OK) { if (mmAccess2(real_path.c_str(), M_R_OK) != EN_OK) {
@@ -416,24 +418,25 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckOutputPathValid(const
} }


GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(strlen(file_path.c_str()) >= MMPA_MAX_PATH, GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(strlen(file_path.c_str()) >= MMPA_MAX_PATH,
ErrorManager::GetInstance().ATCReportErrMessage(
"E19002", {"filepath", "size"}, {file_path, std::to_string(MMPA_MAX_PATH)});
return "", "Path[%s] len is too long, it must be less than %d", file_path.c_str(), MMPA_MAX_PATH);
ErrorManager::GetInstance().ATCReportErrMessage(
"E19002", {"filepath", "size"}, {file_path, std::to_string(MMPA_MAX_PATH)});
return "", "Path[%s] len is too long, it must be less than %d", file_path.c_str(),
MMPA_MAX_PATH);


// A regular matching expression to verify the validity of the input file path // A regular matching expression to verify the validity of the input file path
// Path section: Support upper and lower case letters, numbers dots(.) chinese and underscores // Path section: Support upper and lower case letters, numbers dots(.) chinese and underscores
// File name section: Support upper and lower case letters, numbers, underscores chinese and dots(.) // File name section: Support upper and lower case letters, numbers, underscores chinese and dots(.)
#ifdef __GNUC__ #ifdef __GNUC__
std::string mode = "^[\u4e00-\u9fa5A-Za-z0-9./_-]+$";
std::string mode = "^[\u4e00-\u9fa5A-Za-z0-9./_-]+$";
#else #else
std::string mode = "^[a-zA-Z]:([\\\\/][^\\s\\\\/:*?<>\"|][^\\\\/:*?<>\"|]*)*([/\\\\][^\\s\\\\/:*?<>\"|])?$";
std::string mode = "^[a-zA-Z]:([\\\\/][^\\s\\\\/:*?<>\"|][^\\\\/:*?<>\"|]*)*([/\\\\][^\\s\\\\/:*?<>\"|])?$";
#endif #endif


GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
!ValidateStr(file_path, mode),
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
{atc_param, file_path, kPathValidReason});
return false, "Invalid value for %s[%s], %s.", atc_param.c_str(), file_path.c_str(), kPathValidReason);
!ValidateStr(file_path, mode),
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
{atc_param, file_path, kPathValidReason});
return false, "Invalid value for %s[%s], %s.", atc_param.c_str(), file_path.c_str(), kPathValidReason);


std::string real_path = RealPath(file_path.c_str()); std::string real_path = RealPath(file_path.c_str());
// Can get absolute path (file exists) // Can get absolute path (file exists)


+ 2
- 1
ge/ge_runtime/runtime_model.cc View File

@@ -28,6 +28,7 @@


namespace ge { namespace ge {
namespace model_runner { namespace model_runner {
const int kOffsetUnit = 8;
RuntimeModel::~RuntimeModel() { RuntimeModel::~RuntimeModel() {
GELOGI("RuntimeModel destructor start"); GELOGI("RuntimeModel destructor start");


@@ -495,7 +496,7 @@ bool RuntimeModel::InitConstantInfo(std::shared_ptr<DavinciModel> &davinci_model
return false; return false;
} }
uint64_t *buff = reinterpret_cast<uint64_t *>(const_cast<char *>(constant->weight_data.data())); uint64_t *buff = reinterpret_cast<uint64_t *>(const_cast<char *>(constant->weight_data.data()));
int64_t offset = elem_num * 8;
int64_t offset = elem_num * kOffsetUnit;
uintptr_t hbm_raw_data_base_addr = reinterpret_cast<uintptr_t>(constant->output_addrs[0]) + offset; uintptr_t hbm_raw_data_base_addr = reinterpret_cast<uintptr_t>(constant->output_addrs[0]) + offset;
for (int64_t i = elem_num - 1; i >= 0; --i) { for (int64_t i = elem_num - 1; i >= 0; --i) {
buff[i] = hbm_raw_data_base_addr + (buff[i] - buff[0]); buff[i] = hbm_raw_data_base_addr + (buff[i] - buff[0]);


+ 8
- 5
ge/graph/load/new_model_manager/model_manager.cc View File

@@ -50,6 +50,9 @@ const std::string kCmdTypeProfModelSubscribe = "prof_model_subscribe";
const std::string kCmdTypeProfModelUnsubscribe = "prof_model_cancel_subscribe"; const std::string kCmdTypeProfModelUnsubscribe = "prof_model_cancel_subscribe";
const char *const kBatchLoadBuf = "batchLoadsoFrombuf"; const char *const kBatchLoadBuf = "batchLoadsoFrombuf";
const char *const kDeleteCustOp = "deleteCustOp"; const char *const kDeleteCustOp = "deleteCustOp";
const int kTimeSpecNano = 1000000000;
const int kTimeSpecMiro = 1000000;
const int kSessionMaxBias = 100;
struct CustAicpuSoBuf { struct CustAicpuSoBuf {
uint64_t kernelSoBuf; uint64_t kernelSoBuf;
uint32_t kernelSoBufLen; uint32_t kernelSoBufLen;
@@ -337,7 +340,7 @@ Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptr<ge::Ge


GELOGI("Parse model %u success.", model_id); GELOGI("Parse model %u success.", model_id);


davinci_model->SetProfileTime(MODEL_LOAD_START, (timespec.tv_sec * 1000 * 1000 * 1000 +
davinci_model->SetProfileTime(MODEL_LOAD_START, (timespec.tv_sec * kTimeSpecNano +
timespec.tv_nsec)); // 1000 ^ 3 converts second to nanosecond timespec.tv_nsec)); // 1000 ^ 3 converts second to nanosecond
davinci_model->SetProfileTime(MODEL_LOAD_END); davinci_model->SetProfileTime(MODEL_LOAD_END);
} while (0); } while (0);
@@ -1041,12 +1044,12 @@ Status ModelManager::GenSessionId(uint64_t &session_id) {
GELOGE(INTERNAL_ERROR, "Failed to get current time."); GELOGE(INTERNAL_ERROR, "Failed to get current time.");
return INTERNAL_ERROR; return INTERNAL_ERROR;
} }
session_id = static_cast<uint64_t>(tv.tv_sec * 1000000 + tv.tv_usec); // 1000000us
session_id = static_cast<uint64_t>(tv.tv_sec * kTimeSpecMiro + tv.tv_usec); // 1000000us


session_id_bias_++; session_id_bias_++;
// max bais 100. // max bais 100.
session_id_bias_ = session_id_bias_ % 100;
session_id = session_id * 100 + session_id_bias_;
session_id_bias_ = session_id_bias_ % kSessionMaxBias;
session_id = session_id * kSessionMaxBias + session_id_bias_;


GELOGD("Generate new session id: %lu.", session_id); GELOGD("Generate new session id: %lu.", session_id);
return SUCCESS; return SUCCESS;
@@ -1117,7 +1120,7 @@ Status ModelManager::LoadModelOffline(uint32_t &model_id, const ModelData &model


GELOGI("Parse model %u success.", model_id); GELOGI("Parse model %u success.", model_id);


davinci_model->SetProfileTime(MODEL_LOAD_START, (timespec.tv_sec * 1000 * 1000 * 1000 +
davinci_model->SetProfileTime(MODEL_LOAD_START, (timespec.tv_sec * kTimeSpecNano +
timespec.tv_nsec)); // 1000 ^ 3 converts second to nanosecond timespec.tv_nsec)); // 1000 ^ 3 converts second to nanosecond
davinci_model->SetProfileTime(MODEL_LOAD_END); davinci_model->SetProfileTime(MODEL_LOAD_END);




+ 13
- 6
ge/graph/load/new_model_manager/task_info/kernel_task_info.cc View File

@@ -43,6 +43,13 @@ const char *kIsLastNode = "is_last_node";
const char *kIsFirstNode = "is_first_node"; const char *kIsFirstNode = "is_first_node";
const int64_t kCloseSkt = 100; const int64_t kCloseSkt = 100;
const uint32_t kAddrLen = sizeof(void *); const uint32_t kAddrLen = sizeof(void *);
const int kBaseInt = 10;
const int kStrtolFail = 0;
const int kArgsInputDesc = 0;
const int kArgsInputAddr = 1;
const int kArgsOutputDesc = 2;
const int kArgsOutputAddr = 3;
const int kArgsAttrHandle = 4;
} // namespace } // namespace


namespace ge { namespace ge {
@@ -371,7 +378,7 @@ Status KernelTaskInfo::Distribute() {
rtError_t rt_ret = RT_ERROR_NONE; rtError_t rt_ret = RT_ERROR_NONE;
char skt_enable_env[MMPA_MAX_PATH] = { 0x00 }; char skt_enable_env[MMPA_MAX_PATH] = { 0x00 };
INT32 res = mmGetEnv("SKT_ENABLE", skt_enable_env, MMPA_MAX_PATH); INT32 res = mmGetEnv("SKT_ENABLE", skt_enable_env, MMPA_MAX_PATH);
int64_t env_flag = (res == EN_OK) ? strtol(skt_enable_env, nullptr, 10) : 0;
int64_t env_flag = (res == EN_OK) ? strtol(skt_enable_env, nullptr, kBaseInt) : kStrtolFail;
bool call_skt = ((env_flag != 0) || is_l1_fusion_enable_); bool call_skt = ((env_flag != 0) || is_l1_fusion_enable_);
if (kernel_type_ == ccKernelType::AI_CPU || kernel_type_ == ccKernelType::CUST_AI_CPU) { if (kernel_type_ == ccKernelType::AI_CPU || kernel_type_ == ccKernelType::CUST_AI_CPU) {
GELOGI("distribute task info kernel_type %d, flag %d", kernel_type_, dump_flag_); GELOGI("distribute task info kernel_type %d, flag %d", kernel_type_, dump_flag_);
@@ -749,15 +756,15 @@ Status KernelTaskInfo::InitAICPUCustomTask(uint32_t op_index, const domi::Kernel
return FAILED; return FAILED;
} }
} }
*(reinterpret_cast<uint64_t *>(args + ctx_.argsOffset[0])) =
*(reinterpret_cast<uint64_t *>(args + ctx_.argsOffset[kArgsInputDesc])) =
static_cast<uint64_t>(reinterpret_cast<uintptr_t>(custom_info_.input_descs)); // arg 0 static_cast<uint64_t>(reinterpret_cast<uintptr_t>(custom_info_.input_descs)); // arg 0
*(reinterpret_cast<uint64_t *>(args + ctx_.argsOffset[1])) =
*(reinterpret_cast<uint64_t *>(args + ctx_.argsOffset[kArgsInputAddr])) =
static_cast<uint64_t>(reinterpret_cast<uintptr_t>(custom_info_.input_addrs)); // arg 1 static_cast<uint64_t>(reinterpret_cast<uintptr_t>(custom_info_.input_addrs)); // arg 1
*(reinterpret_cast<uint64_t *>(args + ctx_.argsOffset[2])) =
*(reinterpret_cast<uint64_t *>(args + ctx_.argsOffset[kArgsOutputDesc])) =
static_cast<uint64_t>(reinterpret_cast<uintptr_t>(custom_info_.output_descs)); // arg 2 static_cast<uint64_t>(reinterpret_cast<uintptr_t>(custom_info_.output_descs)); // arg 2
*(reinterpret_cast<uint64_t *>(args + ctx_.argsOffset[3])) =
*(reinterpret_cast<uint64_t *>(args + ctx_.argsOffset[kArgsOutputAddr])) =
static_cast<uint64_t>(reinterpret_cast<uintptr_t>(custom_info_.output_addrs)); // arg 3 static_cast<uint64_t>(reinterpret_cast<uintptr_t>(custom_info_.output_addrs)); // arg 3
*(reinterpret_cast<uint64_t *>(args + ctx_.argsOffset[4])) =
*(reinterpret_cast<uint64_t *>(args + ctx_.argsOffset[kArgsAttrHandle])) =
static_cast<uint64_t>(reinterpret_cast<uintptr_t>(custom_info_.attr_handle)); // arg 4 static_cast<uint64_t>(reinterpret_cast<uintptr_t>(custom_info_.attr_handle)); // arg 4


rt_ret = rtMalloc(&args_, args_size_, RT_MEMORY_HBM); rt_ret = rtMalloc(&args_, args_size_, RT_MEMORY_HBM);


+ 10
- 8
ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc View File

@@ -19,6 +19,8 @@


namespace ge { namespace ge {
namespace skt { namespace skt {
const size_t kFusedKernelMinimumSize = 2;
const size_t kFusedKernelSizeUnit = 2;
SuperKernelFactory &SuperKernelFactory::GetInstance() { SuperKernelFactory &SuperKernelFactory::GetInstance() {
static SuperKernelFactory factory; static SuperKernelFactory factory;
return factory; return factory;
@@ -79,17 +81,17 @@ Status SuperKernelFactory::FuseKernels(const std::vector<void *> &stub_func_list
return FAILED; return FAILED;
} }


if (super_kernel_size < 2) {
if (super_kernel_size < kFusedKernelMinimumSize) {
GELOGW( GELOGW(
"SKT: the number of kernels being fused must be greater than or " "SKT: the number of kernels being fused must be greater than or "
"equal to 2"); "equal to 2");
return FAILED; return FAILED;
} }
GELOGI("SKT: superkernel start fuse, superkernel size %zu.", stub_func_list.size()); GELOGI("SKT: superkernel start fuse, superkernel size %zu.", stub_func_list.size());
const size_t nav_table_len = 2 * stub_func_list.size();
std::unique_ptr<uint64_t[]> nav_table(new (std::nothrow) uint64_t[nav_table_len]);
const size_t nav_table_len = kFusedKernelSizeUnit * stub_func_list.size();
std::unique_ptr<uint64_t[]> nav_table(new(std::nothrow) uint64_t[nav_table_len]);
GE_CHECK_NOTNULL(nav_table); GE_CHECK_NOTNULL(nav_table);
uint64_t nav_table_size = 2 * stub_func_list.size() * sizeof(int64_t);
uint64_t nav_table_size = kFusedKernelSizeUnit * stub_func_list.size() * sizeof(int64_t);


rtError_t rt_ret; rtError_t rt_ret;
void *hbm_nav_table_addr = nullptr; void *hbm_nav_table_addr = nullptr;
@@ -101,10 +103,10 @@ Status SuperKernelFactory::FuseKernels(const std::vector<void *> &stub_func_list
GELOGD("SKT: fuseKernels subFunc %p, device func address %p", stub_func_list[i], sub_device_func); GELOGD("SKT: fuseKernels subFunc %p, device func address %p", stub_func_list[i], sub_device_func);
// store two uint64_t address // store two uint64_t address
// address divided by 4 because of 32bits encoding, call offset will *4 when calculating // address divided by 4 because of 32bits encoding, call offset will *4 when calculating
nav_table[i * 2] = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(sub_device_func)) / 4;
GELOGD("SKT: CALL offet %lu", nav_table[i * 2]);
nav_table[i * 2 + 1] = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(args_addr_list[i]));
GELOGD("SKT: fuseKernels args base address %lu", nav_table[i * 2 + 1]);
nav_table[i * kFusedKernelSizeUnit] = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(sub_device_func)) / 4;
GELOGD("SKT: CALL offet %lu", nav_table[i * kFusedKernelSizeUnit]);
nav_table[i * kFusedKernelSizeUnit + 1] = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(args_addr_list[i]));
GELOGD("SKT: fuseKernels args base address %lu", nav_table[i * kFusedKernelSizeUnit + 1]);
} }
rt_ret = rtMalloc(reinterpret_cast<void **>(&hbm_nav_table_addr), nav_table_size, RT_MEMORY_HBM); rt_ret = rtMalloc(reinterpret_cast<void **>(&hbm_nav_table_addr), nav_table_size, RT_MEMORY_HBM);
GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "rtMalloc failed. error: 0x%X", rt_ret); GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "rtMalloc failed. error: 0x%X", rt_ret);


+ 1
- 1
ge/graph/load/new_model_manager/ts_mem_mall.h View File

@@ -25,7 +25,7 @@
#include "framework/common/debug/ge_log.h" #include "framework/common/debug/ge_log.h"


namespace { namespace {
constexpr uint32_t kMaxTsMemBlock = 2 * 1024 * 1024; // Max block 2M
constexpr uint32_t kMaxTsMemBlock = 2097152; // Max block 2M 2 * 1024 * 1024
constexpr uint32_t kTsMemAligment = 64; // Malloc for 64 bits align constexpr uint32_t kTsMemAligment = 64; // Malloc for 64 bits align
constexpr uint32_t kTsMemAlignMask = kTsMemAligment - 1; constexpr uint32_t kTsMemAlignMask = kTsMemAligment - 1;
} }


+ 6
- 6
ge/graph/manager/graph_caching_allocator.cc View File

@@ -25,13 +25,13 @@


namespace ge { namespace ge {
const size_t bin_ranges[kNumBins] = {kRoundBlockSize * kKByteSize, const size_t bin_ranges[kNumBins] = {kRoundBlockSize * kKByteSize,
8 * kMByteSize,
32 * kMByteSize,
128 * kMByteSize,
kBinSizeUnit8 * kMByteSize,
kBinSizeUnit32 * kMByteSize,
kBinSizeUnit128 * kMByteSize,
kGByteSize, kGByteSize,
4 * kGByteSize,
16 * kGByteSize,
26 * kGByteSize};
kBinSizeUnit4 * kGByteSize,
kBinSizeUnit16 * kGByteSize,
kBinSizeUnit26 * kGByteSize};


static bool BlockComparator(const Block *left, const Block *right) { static bool BlockComparator(const Block *left, const Block *right) {
if (left->size != right->size) { if (left->size != right->size) {


+ 9
- 2
ge/graph/manager/graph_caching_allocator.h View File

@@ -34,10 +34,17 @@


namespace ge { namespace ge {
constexpr size_t kRoundBlockSize = 512; // all block sizes are rounded to at least 512 bytes constexpr size_t kRoundBlockSize = 512; // all block sizes are rounded to at least 512 bytes
constexpr size_t kBinSizeUnit4 = 4;
constexpr size_t kBinSizeUnit8 = 8;
constexpr size_t kBinSizeUnit16 = 16;
constexpr size_t kBinSizeUnit26 = 26;
constexpr size_t kBinSizeUnit32 = 32;
constexpr size_t kBinSizeUnit128 = 128;

constexpr double kSplitThreshold = 0.75; // split when malloc size <= small block size * kSpliThreshold constexpr double kSplitThreshold = 0.75; // split when malloc size <= small block size * kSpliThreshold
constexpr size_t kKByteSize = 1024; constexpr size_t kKByteSize = 1024;
constexpr size_t kMByteSize = 1024 * 1024;
constexpr size_t kGByteSize = 1024 * 1024 * 1024;
constexpr size_t kMByteSize = 1048576; // 1024 * 1024
constexpr size_t kGByteSize = 1073741824; // 1024 * 1024 * 1024


static const uint32_t kNumBins = 8; static const uint32_t kNumBins = 8;




+ 2
- 2
ge/graph/manager/graph_var_manager.cc View File

@@ -280,9 +280,9 @@ Status MemResource::AssignVarMem(const std::string &var_name, uint64_t size, uin
return PARAM_INVALID; return PARAM_INVALID;
} }
uint64_t free_size = total_size_ - var_mem_size_; uint64_t free_size = total_size_ - var_mem_size_;
if (free_size < (size + kSessionMemAlignSize * 2)) {
if (free_size < (size + kSessionMemAlignSize * kSessionMemAlignUnit)) {
GELOGE(PARAM_INVALID, "Out of memory : current var size[%lu] exceeds total var size[%lu]", GELOGE(PARAM_INVALID, "Out of memory : current var size[%lu] exceeds total var size[%lu]",
size + kSessionMemAlignSize * 2 + var_mem_size_, total_size_);
size + kSessionMemAlignSize * kSessionMemAlignUnit + var_mem_size_, total_size_);
return PARAM_INVALID; return PARAM_INVALID;
} }




+ 1
- 0
ge/graph/manager/graph_var_manager.h View File

@@ -42,6 +42,7 @@ const size_t kGraphMemoryBuffer = 4UL * 1024UL * 1024UL * 1024UL;
const size_t kMaxMemorySize = 256UL * 1024UL * 1024UL * 1024UL; const size_t kMaxMemorySize = 256UL * 1024UL * 1024UL * 1024UL;
const char kEnvGeuseStaticMemory[] = "GE_USE_STATIC_MEMORY"; const char kEnvGeuseStaticMemory[] = "GE_USE_STATIC_MEMORY";
const uint64_t kSessionMemAlignSize = 512; const uint64_t kSessionMemAlignSize = 512;
const size_t kSessionMemAlignUnit = 2;


enum MemStatus { enum MemStatus {
NORMAL = 0, NORMAL = 0,


+ 13
- 6
ge/graph/optimize/mem_rw_conflict_optimize.cc View File

@@ -26,6 +26,13 @@
namespace { namespace {
using namespace ge; using namespace ge;
const int kIdentityAnchorIndex = 0; const int kIdentityAnchorIndex = 0;
const size_t kSerialStringVecSize = 4;

const int kCaseReadOnly = 0;
const int kCaseScopeWriteable = 2;
const int kCaseWriteable = 3;
const int kCaseInvalidRWType = 5;

// rw type of input. // rw type of input.
enum class InputRWType { enum class InputRWType {
kReadOnly, // Normal op input only read kReadOnly, // Normal op input only read
@@ -55,7 +62,7 @@ thread_local map<string, NodeInputOutputRWType> node_rwtype_map_;
/// @return rw_type_name /// @return rw_type_name
/// ///
static std::string InputRWTypeToSerialString(InputRWType rw_type) { static std::string InputRWTypeToSerialString(InputRWType rw_type) {
const static char *names[4] = {"ReadOnly", "Writeable", "ScopeWriteable", "InvalidRWType"};
const static char *names[kSerialStringVecSize] = {"ReadOnly", "Writeable", "ScopeWriteable", "InvalidRWType"};
return names[static_cast<int>(rw_type)]; return names[static_cast<int>(rw_type)];
} }


@@ -65,7 +72,7 @@ static std::string InputRWTypeToSerialString(InputRWType rw_type) {
/// @return rw_type_name /// @return rw_type_name
/// ///
static std::string OutputRWTypeToSerialString(OutputRWType rw_type) { static std::string OutputRWTypeToSerialString(OutputRWType rw_type) {
const static char *names[4] = {"ReadOnly", "SoftRead", "Writeable", "InvalidRWType"};
const static char *names[kSerialStringVecSize] = {"ReadOnly", "SoftRead", "Writeable", "InvalidRWType"};
return names[static_cast<int>(rw_type)]; return names[static_cast<int>(rw_type)];
} }


@@ -118,13 +125,13 @@ InputRWType GetInputRwTypeInConflict(const std::set<int> &rw_type_set) {
} }


switch (total_rw_type) { switch (total_rw_type) {
case 0:
case kCaseReadOnly:
return InputRWType::kReadOnly; // all input rw type is readonly return InputRWType::kReadOnly; // all input rw type is readonly
case 2:
case kCaseScopeWriteable:
return InputRWType::kScopeWriteable; // readonly 2 scope_writeable return InputRWType::kScopeWriteable; // readonly 2 scope_writeable
case 3:
case kCaseWriteable:
return InputRWType::kWriteable; // all input rw type is writeable or readonly 2 writeable return InputRWType::kWriteable; // all input rw type is writeable or readonly 2 writeable
case 5:
case kCaseInvalidRWType:
return InputRWType::kInvalidRWType; // writeable 2 scope_writeable return InputRWType::kInvalidRWType; // writeable 2 scope_writeable
default: default:
return InputRWType::kInvalidRWType; return InputRWType::kInvalidRWType;


+ 2
- 1
ge/graph/passes/data_pass.cc View File

@@ -21,6 +21,7 @@


namespace ge { namespace ge {
namespace { namespace {
const int kDataIndexOffset = 2;
Status MappingSubgraphInput(const ComputeGraphPtr &graph, const std::function<int(int data_index)> &input) { Status MappingSubgraphInput(const ComputeGraphPtr &graph, const std::function<int(int data_index)> &input) {
for (const auto &node : graph->GetDirectNode()) { for (const auto &node : graph->GetDirectNode()) {
if (node->GetType() != DATA) { if (node->GetType() != DATA) {
@@ -111,7 +112,7 @@ Status ParseSubgraphPostFnWhile(const string &subgraph_name, const ComputeGraphP


Status ParseSubgraphPostFnFor(const string &subgraph_name, const ComputeGraphPtr &graph) { Status ParseSubgraphPostFnFor(const string &subgraph_name, const ComputeGraphPtr &graph) {
return MappingSubgraphIndex(graph, return MappingSubgraphIndex(graph,
[](int data_index) { return (data_index == 0) ? 0 : data_index + 2; },
[](int data_index) { return (data_index == 0) ? 0 : data_index + kDataIndexOffset; },
[](int retval_index) { return retval_index; }); [](int retval_index) { return retval_index; });
} }




+ 2
- 1
ge/graph/passes/for_pass.cc View File

@@ -37,6 +37,7 @@ namespace {
const uint32_t kSubgraphLoopVarInputIndex = 0; const uint32_t kSubgraphLoopVarInputIndex = 0;
const uint32_t kSubgraphInputIndex = 1; const uint32_t kSubgraphInputIndex = 1;
const uint32_t kWhileOutputIndex = 5; const uint32_t kWhileOutputIndex = 5;
const size_t kIDiffValue = 2;
const std::string kAbs = "Abs"; const std::string kAbs = "Abs";
} }


@@ -694,7 +695,7 @@ Status ForPass::UpdateForBodyInputMapping(const WhileInfo &while_info) {
} else if ((i == FOR_LIMIT_INPUT) || (i == FOR_DELTA_INPUT)) { } else if ((i == FOR_LIMIT_INPUT) || (i == FOR_DELTA_INPUT)) {
continue; continue;
} else { } else {
input_mapping[i] = i - 2;
input_mapping[i] = i - kIDiffValue;
} }
} }
for_body->UpdateInputMapping(input_mapping); for_body->UpdateInputMapping(input_mapping);


+ 3
- 1
ge/graph/passes/mark_agnostic_pass.cc View File

@@ -19,6 +19,8 @@
#include "graph/utils/tensor_utils.h" #include "graph/utils/tensor_utils.h"


namespace ge { namespace ge {
const size_t kTwoInputNodesSize = 2;

Status MarkAgnosticPass::Run(ComputeGraphPtr graph) { Status MarkAgnosticPass::Run(ComputeGraphPtr graph) {
for (const auto &node : graph->GetDirectNode()) { for (const auto &node : graph->GetDirectNode()) {
auto node_type = NodeUtils::GetNodeType(*node); auto node_type = NodeUtils::GetNodeType(*node);
@@ -52,7 +54,7 @@ Status MarkAgnosticPass::Run(ComputeGraphPtr graph) {
/// Enter-----------+ /// Enter-----------+
/// +-> Merge /// +-> Merge
/// NextIteration---+ /// NextIteration---+
if (input_nodes.size() == 2) {
if (input_nodes.size() == kTwoInputNodesSize) {
if (input_nodes.at(0)->GetType() == ENTER && input_nodes.at(1)->GetType() == NEXTITERATION) { if (input_nodes.at(0)->GetType() == ENTER && input_nodes.at(1)->GetType() == NEXTITERATION) {
continue; continue;
} }


+ 4
- 2
ge/graph/passes/merge_pass.cc View File

@@ -29,6 +29,8 @@


namespace ge { namespace ge {
const int kValueIndexOutputIndex = 1; const int kValueIndexOutputIndex = 1;
const size_t kCaseNoInput = 0;
const size_t kCaseOneInput = 1;


Status MergePass::Run(NodePtr &node) { Status MergePass::Run(NodePtr &node) {
GELOGD("MergePass running"); GELOGD("MergePass running");
@@ -50,7 +52,7 @@ Status MergePass::Run(NodePtr &node) {


const auto &in_data_nodes = node->GetInDataNodes(); const auto &in_data_nodes = node->GetInDataNodes();
switch (in_data_nodes.size()) { switch (in_data_nodes.size()) {
case 0: {
case kCaseNoInput: {
/// Case A: input_count = 0, the output of merge node is inactive as well /// Case A: input_count = 0, the output of merge node is inactive as well
/// In which case the output branch can be removed /// In which case the output branch can be removed
/// until another merge node is met /// until another merge node is met
@@ -65,7 +67,7 @@ Status MergePass::Run(NodePtr &node) {
} }
return ret; return ret;
} }
case 1: { // Case B: input_count = 1, the merge node can be optimized out
case kCaseOneInput: { // Case B: input_count = 1, the merge node can be optimized out
std::vector<int> merge_io_map = {PassUtils::GetUniqueInDataAnchorIndex(node), -1}; std::vector<int> merge_io_map = {PassUtils::GetUniqueInDataAnchorIndex(node), -1};
if (merge_io_map[0] != -1 && IsNeedChangeIndexToConstant(node)) { if (merge_io_map[0] != -1 && IsNeedChangeIndexToConstant(node)) {
int index = merge_io_map[0]; int index = merge_io_map[0];


+ 8
- 4
ge/host_kernels/gather_v2_kernel.cc View File

@@ -40,6 +40,10 @@ const size_t kGatherV2InpotNum = 3;
const size_t kMaxIndicatesDims = 1; // only support scalar and 1 dims indicates_ const size_t kMaxIndicatesDims = 1; // only support scalar and 1 dims indicates_
const std::set<DataType> supported_type = {DT_FLOAT16, DT_DOUBLE, DT_INT8, DT_INT16, DT_INT16, DT_INT32, const std::set<DataType> supported_type = {DT_FLOAT16, DT_DOUBLE, DT_INT8, DT_INT16, DT_INT16, DT_INT32,
DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, DT_UINT64}; DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, DT_UINT64};
const int64_t DIM_AXIS_0 = 0;
const int64_t DIM_AXIS_1 = 1;
const int64_t DIM_AXIS_2 = 2;
const int64_t DIM_AXIS_3 = 3;
} // namespace } // namespace
template <typename T> template <typename T>
Status GatherV2Kernel::ProcessAxis0(ConstGeTensorPtr tensor_x, GeTensorPtr output) { Status GatherV2Kernel::ProcessAxis0(ConstGeTensorPtr tensor_x, GeTensorPtr output) {
@@ -191,16 +195,16 @@ Status GatherV2Kernel::GenData(const int64_t data_num, ConstGeTensorPtr tensor_x


Status ret = SUCCESS; Status ret = SUCCESS;
switch (axis) { switch (axis) {
case 0:
case DIM_AXIS_0:
ret = ProcessAxis0<T>(tensor_x, output); ret = ProcessAxis0<T>(tensor_x, output);
break; break;
case 1:
case DIM_AXIS_1:
ret = ProcessAxis1<T>(tensor_x, output); ret = ProcessAxis1<T>(tensor_x, output);
break; break;
case 2:
case DIM_AXIS_2:
ret = ProcessAxis2<T>(tensor_x, output); ret = ProcessAxis2<T>(tensor_x, output);
break; break;
case 3:
case DIM_AXIS_3:
ret = ProcessAxis3<T>(tensor_x, output); ret = ProcessAxis3<T>(tensor_x, output);
break; break;
default: default:


+ 6
- 3
ge/host_kernels/range_kernel.cc View File

@@ -32,6 +32,9 @@ namespace ge {
namespace { namespace {
constexpr size_t kRangeInputNum = 3; constexpr size_t kRangeInputNum = 3;
constexpr uint32_t kRangeDimNum = 0; constexpr uint32_t kRangeDimNum = 0;
constexpr size_t kStartIndex = 0;
constexpr size_t kLimitIndex = 1;
constexpr size_t kDeltaIndex = 2;
const std::set<DataType> kRangeSupportedType = {DT_INT32, DT_FLOAT}; const std::set<DataType> kRangeSupportedType = {DT_INT32, DT_FLOAT};
} // namespace } // namespace


@@ -53,9 +56,9 @@ Status RangeKernel::Compute(const OpDescPtr op_desc_ptr, const std::vector<Const
return MEMALLOC_FAILED; return MEMALLOC_FAILED;
} }


ConstGeTensorPtr start = input.at(0);
ConstGeTensorPtr limit = input.at(1);
ConstGeTensorPtr delta = input.at(2);
ConstGeTensorPtr start = input.at(kStartIndex);
ConstGeTensorPtr limit = input.at(kLimitIndex);
ConstGeTensorPtr delta = input.at(kDeltaIndex);
DataType data_type = delta->GetTensorDesc().GetDataType(); DataType data_type = delta->GetTensorDesc().GetDataType();
if (data_type == DT_FLOAT) { if (data_type == DT_FLOAT) {
if (GetRange(*reinterpret_cast<const float *>(start->GetData().data()), if (GetRange(*reinterpret_cast<const float *>(start->GetData().data()),


+ 3
- 1
ge/hybrid/common/npu_memory_allocator.cc View File

@@ -23,6 +23,8 @@


namespace ge { namespace ge {
namespace hybrid { namespace hybrid {
const size_t kPaddingUnit = 2;

size_t kMaxHbmMemorySize = 1024UL * 1024UL * 1024UL * 1024UL; // 1024G size_t kMaxHbmMemorySize = 1024UL * 1024UL * 1024UL * 1024UL; // 1024G


std::map<uint32_t, std::unique_ptr<NpuMemoryAllocator>> NpuMemoryAllocator::allocators_; std::map<uint32_t, std::unique_ptr<NpuMemoryAllocator>> NpuMemoryAllocator::allocators_;
@@ -77,7 +79,7 @@ void *NpuMemoryAllocator::Allocate(std::size_t size, AllocationAttr *attr) {
} }
} }
// padding up to multiple of padding, and add extra padding // padding up to multiple of padding, and add extra padding
allocate_size = (size + 2 * padding - 1) / padding * padding;
allocate_size = (size + kPaddingUnit * padding - 1) / padding * padding;
GELOGD("Padding size %ld by %d. final size = %zu.", size, padding, allocate_size); GELOGD("Padding size %ld by %d. final size = %zu.", size, padding, allocate_size);
buffer = MemManager::Instance() buffer = MemManager::Instance()
.CachingInstance(RT_MEMORY_HBM) .CachingInstance(RT_MEMORY_HBM)


+ 1
- 1
ge/hybrid/executor/node_done_manager.cc View File

@@ -21,7 +21,7 @@
namespace ge { namespace ge {
namespace hybrid { namespace hybrid {
namespace { namespace {
constexpr int kDefaultWaitTimeoutInSec = 60 * 10;
constexpr int kDefaultWaitTimeoutInSec = 600;
} }
bool NodeDoneManager::Cond::Await() { bool NodeDoneManager::Cond::Await() {
std::unique_lock<std::mutex> lk(cond_mu_); std::unique_lock<std::mutex> lk(cond_mu_);


+ 1
- 1
ge/offline/main.cc View File

@@ -68,7 +68,7 @@ const char *const kModeSupport = "only support 0(model to framework model), "
const char *const kModelToJsonSupport = "only support 0(Caffe) 3(TensorFlow) 5(Onnx)"; const char *const kModelToJsonSupport = "only support 0(Caffe) 3(TensorFlow) 5(Onnx)";


// limit available mem size 2G // limit available mem size 2G
const long kMinAvailableMem = 2 * 1024 * 1024;
const long kMinAvailableMem = 2097152; // 2 * 1024 * 1024


DEFINE_string(model, "", "The model file."); DEFINE_string(model, "", "The model file.");
DEFINE_string(output, "", "The output file path&name."); DEFINE_string(output, "", "The output file path&name.");


+ 6
- 3
ge/session/omg.cc View File

@@ -68,6 +68,9 @@ const std::string kScopeIdAttr = "fusion_scope";
const char *const kOutputTypeSample = "correct sample is \"opname:index:dtype\""; const char *const kOutputTypeSample = "correct sample is \"opname:index:dtype\"";
const char *const kOutputTypeSupport = "only support FP32, FP16, UINT8"; const char *const kOutputTypeSupport = "only support FP32, FP16, UINT8";
const char *const kOutputTypeError = "The multiple out nodes set in output_type must be found in out_nodes."; const char *const kOutputTypeError = "The multiple out nodes set in output_type must be found in out_nodes.";
const size_t kNodeNameIndex = 0;
const size_t kIndexStrIndex = 1;
const size_t kDTValueIndex = 2;
} // namespace } // namespace


// When the model is converted to a JSON file, the following operator attributes in the blacklist will be ignored // When the model is converted to a JSON file, the following operator attributes in the blacklist will be ignored
@@ -381,14 +384,14 @@ Status ParseOutputType(const std::string &output_type, std::map<std::string, vec
return domi::FAILED; return domi::FAILED;
} }
ge::DataType tmp_dt; ge::DataType tmp_dt;
std::string node_name = StringUtils::Trim(node_index_type_v[0]);
std::string index_str = StringUtils::Trim(node_index_type_v[1]);
std::string node_name = StringUtils::Trim(node_index_type_v[kNodeNameIndex]);
std::string index_str = StringUtils::Trim(node_index_type_v[kIndexStrIndex]);
int32_t index; int32_t index;
if (StringToInt(index_str, index) != SUCCESS) { if (StringToInt(index_str, index) != SUCCESS) {
GELOGE(PARAM_INVALID, "This str must be digit string, while the actual input is %s.", index_str.c_str()); GELOGE(PARAM_INVALID, "This str must be digit string, while the actual input is %s.", index_str.c_str());
return domi::FAILED; return domi::FAILED;
} }
std::string dt_value = StringUtils::Trim(node_index_type_v[2]);
std::string dt_value = StringUtils::Trim(node_index_type_v[kDTValueIndex]);
auto it = output_type_str_to_datatype.find(dt_value); auto it = output_type_str_to_datatype.find(dt_value);
if (it == output_type_str_to_datatype.end()) { if (it == output_type_str_to_datatype.end()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},


+ 2
- 1
ge/single_op/single_op.cc View File

@@ -30,9 +30,10 @@
namespace ge { namespace ge {
namespace { namespace {
const size_t kDataMemAlignSize = 32; const size_t kDataMemAlignSize = 32;
const size_t kDataMemAlignUnit = 2;


size_t GetAlignedSize(size_t size) { size_t GetAlignedSize(size_t size) {
size_t aligned_size = (size + 2 * kDataMemAlignSize - 1) / kDataMemAlignSize * kDataMemAlignSize;
size_t aligned_size = (size + kDataMemAlignUnit * kDataMemAlignSize - 1) / kDataMemAlignSize * kDataMemAlignSize;
return aligned_size; return aligned_size;
} }




+ 4
- 4
inc/framework/common/fmk_error_codes.h View File

@@ -23,10 +23,6 @@
#include "framework/common/fmk_types.h" #include "framework/common/fmk_types.h"
#include "register/register_error_codes.h" #include "register/register_error_codes.h"


#define MODID_OMG 1 // OMG module ID
#define MODID_OME 2 // OME module ID
#define MODID_CALIBRATION 3 // Calibration module ID

// Each module uses the following four macros to define error codes: // Each module uses the following four macros to define error codes:
#define DECLARE_ERRORNO_OMG(name, value) DECLARE_ERRORNO(SYSID_FWK, MODID_OMG, name, value) #define DECLARE_ERRORNO_OMG(name, value) DECLARE_ERRORNO(SYSID_FWK, MODID_OMG, name, value)
#define DECLARE_ERRORNO_OME(name, value) DECLARE_ERRORNO(SYSID_FWK, MODID_OME, name, value) #define DECLARE_ERRORNO_OME(name, value) DECLARE_ERRORNO(SYSID_FWK, MODID_OME, name, value)
@@ -37,6 +33,10 @@
// Interface for Obtaining Error Code Description // Interface for Obtaining Error Code Description
#define GET_ERRORNO_STR(value) domi::StatusFactory::Instance()->GetErrDesc(value) #define GET_ERRORNO_STR(value) domi::StatusFactory::Instance()->GetErrDesc(value)


const int MODID_OMG = 1; // OMG module ID
const int MODID_OME = 2; // OME module ID
const int MODID_CALIBRATION = 3; // Calibration module ID

namespace domi { namespace domi {
class StatusFactory { class StatusFactory {
public: public:


Loading…
Cancel
Save