Browse Source

!45 Synchronize Ascend software suite 18 Jul 2020

Merge pull request !45 from yanghaoran/master
tags/v0.6.0-beta
mindspore-ci-bot Gitee 4 years ago
parent
commit
103f2d1019
100 changed files with 2241 additions and 877 deletions
  1. +11
    -0
      inc/common/opskernel/ge_task_info.h
  2. +1
    -0
      inc/common/util/compress/compress.h
  3. +33
    -0
      inc/common/util/compress/compress_weight.h
  4. +2
    -2
      inc/common/util/platform_info.h
  5. +8
    -0
      inc/external/ge/ge_api_types.h
  6. +2
    -1
      inc/external/graph/types.h
  7. +2
    -0
      inc/external/register/register.h
  8. +0
    -24
      inc/framework/common/debug/ge_log.h
  9. +19
    -36
      inc/framework/common/debug/log.h
  10. +5
    -3
      inc/framework/common/ge_types.h
  11. +0
    -2
      inc/framework/common/helper/model_helper.h
  12. +7
    -0
      inc/framework/common/types.h
  13. +2
    -2
      inc/framework/executor/ge_executor.h
  14. +6
    -1
      inc/framework/ge_runtime/model_runner.h
  15. +80
    -55
      inc/framework/ge_runtime/task_info.h
  16. +1
    -0
      inc/framework/generator/ge_generator.h
  17. +4
    -3
      inc/framework/omg/omg.h
  18. +2
    -0
      inc/framework/omg/omg_inner_types.h
  19. +9
    -1
      inc/graph/compute_graph.h
  20. +22
    -1
      inc/graph/debug/ge_attr_define.h
  21. +0
    -1
      inc/graph/detail/attributes_holder.h
  22. +1
    -0
      inc/graph/ge_context.h
  23. +5
    -2
      inc/graph/ge_tensor.h
  24. +0
    -1
      inc/graph/model_serialize.h
  25. +4
    -0
      inc/graph/op_desc.h
  26. +11
    -4
      inc/graph/utils/graph_utils.h
  27. +25
    -0
      inc/graph/utils/node_utils.h
  28. +1
    -0
      inc/graph/utils/tensor_adapter.h
  29. +1
    -0
      inc/graph/utils/tensor_utils.h
  30. +1
    -0
      src/common/graph/CMakeLists.txt
  31. +13
    -0
      src/common/graph/compute_graph.cc
  32. +4
    -0
      src/common/graph/debug/ge_op_types.h
  33. +69
    -21
      src/common/graph/format_refiner.cc
  34. +4
    -4
      src/common/graph/format_refiner.h
  35. +22
    -1
      src/common/graph/ge_attr_define.cc
  36. +11
    -0
      src/common/graph/ge_tensor.cc
  37. +1
    -1
      src/common/graph/graph.cc
  38. +108
    -1
      src/common/graph/graph.mk
  39. +12
    -13
      src/common/graph/model_serialize.cc
  40. +34
    -13
      src/common/graph/node.cc
  41. +40
    -13
      src/common/graph/op_desc.cc
  42. +33
    -40
      src/common/graph/operator.cc
  43. +2
    -0
      src/common/graph/option/ge_context.cc
  44. +4
    -0
      src/common/graph/ref_relation.cc
  45. +156
    -16
      src/common/graph/shape_refiner.cc
  46. +5
    -5
      src/common/graph/utils/ge_ir_utils.h
  47. +55
    -24
      src/common/graph/utils/graph_utils.cc
  48. +153
    -18
      src/common/graph/utils/node_utils.cc
  49. +46
    -20
      src/common/graph/utils/op_desc_utils.cc
  50. +6
    -2
      src/common/graph/utils/tensor_utils.cc
  51. +2
    -1
      src/common/graph/utils/type_utils.cc
  52. +5
    -2
      src/ge/CMakeLists.txt
  53. +14
    -43
      src/ge/client/ge_api.cc
  54. +6
    -5
      src/ge/common/convert/pb2json.cc
  55. +0
    -1
      src/ge/common/formats/format_transfers/datatype_transfer.cc
  56. +0
    -1
      src/ge/common/formats/format_transfers/datatype_transfer.h
  57. +0
    -1
      src/ge/common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc
  58. +0
    -1
      src/ge/common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc
  59. +1
    -1
      src/ge/common/formats/format_transfers/format_transfer_fractal_nz.cc
  60. +55
    -40
      src/ge/common/formats/format_transfers/format_transfer_fractal_z.cc
  61. +1
    -3
      src/ge/common/formats/format_transfers/format_transfer_nchw_fz_c04.cc
  62. +0
    -2
      src/ge/common/formats/utils/formats_definitions.h
  63. +0
    -1
      src/ge/common/formats/utils/formats_trans_utils.h
  64. +1
    -1
      src/ge/common/fp16_t.h
  65. +81
    -0
      src/ge/common/ge/op_tiling_manager.cc
  66. +15
    -6
      src/ge/common/ge/op_tiling_manager.h
  67. +2
    -101
      src/ge/common/helper/model_helper.cc
  68. +1
    -1
      src/ge/common/math/fp16_math.h
  69. +0
    -2
      src/ge/common/math_util.h
  70. +3
    -4
      src/ge/common/model_saver.cc
  71. +13
    -7
      src/ge/common/profiling/profiling_manager.cc
  72. +200
    -120
      src/ge/common/properties_manager.cc
  73. +47
    -21
      src/ge/common/properties_manager.h
  74. +0
    -1
      src/ge/common/tbe_kernel_store.h
  75. +102
    -93
      src/ge/common/types.cc
  76. +31
    -35
      src/ge/common/util.cc
  77. +21
    -1
      src/ge/engine_manager/dnnengine_manager.cc
  78. +3
    -0
      src/ge/engine_manager/dnnengine_manager.h
  79. +1
    -1
      src/ge/executor/CMakeLists.txt
  80. +3
    -4
      src/ge/executor/ge_executor.cc
  81. +2
    -1
      src/ge/executor/module.mk
  82. +30
    -4
      src/ge/ge_inference.mk
  83. +36
    -3
      src/ge/ge_runner.mk
  84. +31
    -0
      src/ge/ge_runtime/model_runner.cc
  85. +1
    -1
      src/ge/ge_runtime/output.cc
  86. +57
    -24
      src/ge/ge_runtime/runtime_model.cc
  87. +7
    -2
      src/ge/ge_runtime/runtime_model.h
  88. +9
    -5
      src/ge/ge_runtime/task/aicpu_task.cc
  89. +6
    -0
      src/ge/ge_runtime/task/aicpu_task.h
  90. +0
    -1
      src/ge/ge_runtime/task/hccl_task.cc
  91. +70
    -0
      src/ge/ge_runtime/task/label_goto_task.cc
  92. +41
    -0
      src/ge/ge_runtime/task/label_goto_task.h
  93. +70
    -0
      src/ge/ge_runtime/task/label_set_task.cc
  94. +41
    -0
      src/ge/ge_runtime/task/label_set_task.h
  95. +131
    -0
      src/ge/ge_runtime/task/label_switch_task.cc
  96. +44
    -0
      src/ge/ge_runtime/task/label_switch_task.h
  97. +1
    -1
      src/ge/ge_runtime/task/stream_switch_task.cc
  98. +6
    -0
      src/ge/ge_runtime/task/task.h
  99. +3
    -4
      src/ge/ge_runtime/task/tbe_task.cc
  100. +4
    -0
      src/ge/ge_runtime/task/tbe_task.h

+ 11
- 0
inc/common/opskernel/ge_task_info.h View File

@@ -52,5 +52,16 @@ struct GETaskInfo {

std::vector<GETaskKernelHcclInfo> kernelHcclInfo;
};

struct HcomOpertion {
std::string hcclType;
void *inputPtr;
void *outputPtr;
uint64_t count;
int32_t dataType;
int32_t opType;
int32_t root;
};

} // namespace ge
#endif // INC_COMMON_OPSKERNEL_GE_TASK_INFO_H_

+ 1
- 0
inc/common/util/compress/compress.h View File

@@ -28,6 +28,7 @@ struct CompressConfig {
size_t channel; // channels of L2 or DDR. For load balance
size_t fractalSize; // size of compressing block
bool isTight; // whether compose compressed data tightly
size_t init_offset;
};

CmpStatus CompressWeights(char* input, const CompressConfig& compressConfig, char* indexs, char* output,


+ 33
- 0
inc/common/util/compress/compress_weight.h View File

@@ -0,0 +1,33 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef COMPRESS_WEIGHT_H
#define COMPRESS_WEIGHT_H

#include "compress.h"

const int SHAPE_SIZE_WEIGHT = 4;

struct CompressOpConfig {
int64_t wShape[SHAPE_SIZE_WEIGHT];
size_t compressTilingK;
size_t compressTilingN;
struct CompressConfig compressConfig;
};

extern "C" CmpStatus CompressWeightsConv2D(const char *const input, char *const zipBuffer, char *const infoBuffer,
CompressOpConfig *const param);
#endif // COMPRESS_WEIGHT_H

+ 2
- 2
inc/common/util/platform_info.h View File

@@ -27,7 +27,6 @@ using std::string;
using std::vector;

namespace fe {

class PlatformInfoManager {
public:
PlatformInfoManager(const PlatformInfoManager &) = delete;
@@ -39,6 +38,8 @@ class PlatformInfoManager {

uint32_t GetPlatformInfo(const string SoCVersion, PlatformInfo &platformInfo, OptionalInfo &optiCompilationInfo);

uint32_t GetPlatformInfoWithOutSocVersion(PlatformInfo &platformInfo, OptionalInfo &optiCompilationInfo);

void SetOptionalCompilationInfo(OptionalInfo &optiCompilationInfo);

private:
@@ -94,6 +95,5 @@ class PlatformInfoManager {
map<string, PlatformInfo> platformInfoMap_;
OptionalInfo optiCompilationInfo_;
};

} // namespace fe
#endif

+ 8
- 0
inc/external/ge/ge_api_types.h View File

@@ -44,8 +44,12 @@ const char *const OPTION_EXEC_ENABLE_DUMP = "ge.exec.enableDump";
const char *const OPTION_EXEC_DUMP_PATH = "ge.exec.dumpPath";
const char *const OPTION_EXEC_DUMP_STEP = "ge.exec.dumpStep";
const char *const OPTION_EXEC_DUMP_MODE = "ge.exec.dumpMode";
const char *const OPTION_EXEC_ENABLE_DUMP_DEBUG = "ge.exec.enableDumpDebug";
const char *const OPTION_EXEC_DUMP_DEBUG_MODE = "ge.exec.dumpDebugMode";
const char *const OPTION_EXEC_OP_DEBUG_LEVEL = "ge.exec.opDebugLevel";
const char *const OPTION_EXEC_ENABLE_INCRE_BUILD = "ge.exec.enableIncreBuild";
const char *const OPTION_EXEC_INCRE_BUILD_CACHE_PATH = "ge.exec.increBuildCachePath";
const char *const OPTION_EXEC_ENABLE_SCOPE_FUSION_PASSES = "ge.exec.enableScopeFusionPasses";
// profiling flag
const char *const OPTION_EXEC_PROFILING_MODE = "ge.exec.profilingMode";
const char *const OPTION_EXEC_PROFILING_OPTIONS = "ge.exec.profilingOptions";
@@ -219,6 +223,10 @@ const char *const ENABLE_SINGLE_STREAM = "ge.enableSingleStream";
// Configure input fp16 nodes
const std::string INPUT_FP16_NODES = "ge.INPUT_NODES_SET_FP16";

// Configure debug level, its value should be 0(default), 1 or 2.
// 0: close debug; 1: open TBE compiler; 2: open ccec compiler
const std::string OP_DEBUG_LEVEL = "ge.opDebugLevel";

// Graph run mode
enum GraphRunMode { PREDICTION = 0, TRAIN };



+ 2
- 1
inc/external/graph/types.h View File

@@ -145,7 +145,8 @@ enum Format {
FORMAT_FRACTAL_ZN_LSTM,
FORMAT_FRACTAL_Z_G,
FORMAT_RESERVED,
FORMAT_ALL
FORMAT_ALL,
FORMAT_NULL
};

// for unknown shape op type


+ 2
- 0
inc/external/register/register.h View File

@@ -98,6 +98,8 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData {

OpRegistrationData &DelInputWithOriginalType(int input_idx, const std::string &ori_type);

OpRegistrationData &InputReorderVector(const vector<int> &input_order);

domi::ImplyType GetImplyType() const;
std::string GetOmOptype() const;
std::set<std::string> GetOriginOpTypeSet() const;


+ 0
- 24
inc/framework/common/debug/ge_log.h View File

@@ -51,30 +51,6 @@ inline pid_t GetTid() {
return tid;
}

#define GE_TIMESTAMP_START(stage) uint64_t startUsec_##stage = ge::GetCurrentTimestap()

#define GE_TIMESTAMP_END(stage, stage_name) \
do { \
uint64_t endUsec_##stage = ge::GetCurrentTimestap(); \
GEEVENT("[GEPERFTRACE] The time cost of %s is [%lu] micro second.", (stage_name), \
(endUsec_##stage - startUsec_##stage)); \
} while (0);

#define GE_TIMESTAMP_CALLNUM_START(stage) \
uint64_t startUsec_##stage = ge::GetCurrentTimestap(); \
uint64_t call_num_of##stage = 0; \
uint64_t time_of##stage = 0

#define GE_TIMESTAMP_RESTART(stage) (startUsec_##stage = ge::GetCurrentTimestap())

#define GE_TIMESTAMP_ADD(stage) \
time_of##stage += ge::GetCurrentTimestap() - startUsec_##stage; \
call_num_of##stage++

#define GE_TIMESTAMP_CALLNUM_END(stage, stage_name) \
GEEVENT("[GEPERFTRACE] The time cost of %s is [%lu] micro second, call num is %lu", (stage_name), time_of##stage, \
call_num_of##stage)

#define GE_LOG_ERROR(MOD_NAME, ERROR_CODE, fmt, ...) \
dlog_error(MOD_NAME, "%lu %s: ErrorNo: %d(%s) " fmt, GetTid(), __FUNCTION__, ERROR_CODE, \
((GE_GET_ERRORNO_STR(ERROR_CODE)).c_str()), ##__VA_ARGS__)


+ 19
- 36
inc/framework/common/debug/log.h View File

@@ -19,15 +19,12 @@

#include <string>

#include "cce/cce_def.hpp"
#include "runtime/rt.h"
#include "common/string_util.h"
#include "common/util.h"
#include "framework/common/debug/ge_log.h"
#include "ge/ge_api_error_codes.h"

using cce::CC_STATUS_SUCCESS;
using cce::ccStatus_t;

#if !defined(__ANDROID__) && !defined(ANDROID)
#define DOMI_LOGE(...) GE_LOG_ERROR(GE_MODULE_NAME, ge::FAILED, __VA_ARGS__)
#else
@@ -102,17 +99,13 @@ using cce::ccStatus_t;
} while (0);

// If expr is not true, print the log and return the specified status
#define GE_CHK_BOOL_RET_STATUS(expr, _status, ...) \
do { \
bool b = (expr); \
if (!b) { \
std::string msg; \
(void)msg.append(ge::StringUtils::FormatString(__VA_ARGS__)); \
(void)msg.append( \
ge::StringUtils::FormatString(" Error Code:0x%X(%s)", _status, GET_ERRORNO_STR(_status).c_str())); \
DOMI_LOGE("%s", msg.c_str()); \
return _status; \
} \
#define GE_CHK_BOOL_RET_STATUS(expr, _status, ...) \
do { \
bool b = (expr); \
if (!b) { \
GELOGE(_status, __VA_ARGS__); \
return _status; \
} \
} while (0);

// If expr is not true, print the log and return the specified status
@@ -132,7 +125,7 @@ using cce::ccStatus_t;
DOMI_LOGE(__VA_ARGS__); \
exec_expr; \
} \
};
}

// If expr is not true, print the log and execute a custom statement
#define GE_CHK_BOOL_EXEC_WARN(expr, exec_expr, ...) \
@@ -142,7 +135,7 @@ using cce::ccStatus_t;
GELOGW(__VA_ARGS__); \
exec_expr; \
} \
};
}
// If expr is not true, print the log and execute a custom statement
#define GE_CHK_BOOL_EXEC_INFO(expr, exec_expr, ...) \
{ \
@@ -151,7 +144,7 @@ using cce::ccStatus_t;
GELOGI(__VA_ARGS__); \
exec_expr; \
} \
};
}

// If expr is not true, print the log and execute a custom statement
#define GE_CHK_BOOL_TRUE_EXEC_INFO(expr, exec_expr, ...) \
@@ -161,7 +154,7 @@ using cce::ccStatus_t;
GELOGI(__VA_ARGS__); \
exec_expr; \
} \
};
}

// If expr is true, print logs and execute custom statements
#define GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(expr, exec_expr, ...) \
@@ -171,7 +164,7 @@ using cce::ccStatus_t;
DOMI_LOGE(__VA_ARGS__); \
exec_expr; \
} \
};
}
// If expr is true, print the Information log and execute a custom statement
#define GE_CHK_TRUE_EXEC_INFO(expr, exec_expr, ...) \
{ \
@@ -180,7 +173,7 @@ using cce::ccStatus_t;
GELOGI(__VA_ARGS__); \
exec_expr; \
} \
};
}

// If expr is not SUCCESS, print the log and execute the expression + return
#define GE_CHK_BOOL_TRUE_RET_VOID(expr, exec_expr, ...) \
@@ -191,7 +184,7 @@ using cce::ccStatus_t;
exec_expr; \
return; \
} \
};
}

// If expr is not SUCCESS, print the log and execute the expression + return _status
#define GE_CHK_BOOL_TRUE_EXEC_RET_STATUS(expr, _status, exec_expr, ...) \
@@ -202,7 +195,7 @@ using cce::ccStatus_t;
exec_expr; \
return _status; \
} \
};
}

// If expr is not true, execute a custom statement
#define GE_CHK_BOOL_EXEC_NOLOG(expr, exec_expr) \
@@ -211,7 +204,7 @@ using cce::ccStatus_t;
if (!b) { \
exec_expr; \
} \
};
}

// -----------------runtime related macro definitions-------------------------------
// If expr is not RT_ERROR_NONE, print the log
@@ -231,7 +224,7 @@ using cce::ccStatus_t;
DOMI_LOGE("Call rt api failed, ret: 0x%X", _rt_ret); \
exec_expr; \
} \
};
}

// If expr is not RT_ERROR_NONE, print the log and return
#define GE_CHK_RT_RET(expr) \
@@ -243,23 +236,13 @@ using cce::ccStatus_t;
} \
} while (0);

// ------------------------cce related macro definitions----------------------------
// If expr is not CC_STATUS_SUCCESS, print the log
#define GE_CHK_CCE(expr) \
do { \
ccStatus_t _cc_ret = (expr); \
if (_cc_ret != CC_STATUS_SUCCESS) { \
DOMI_LOGE("Call cce api failed, ret: 0x%X", _cc_ret); \
} \
} while (0);

// If expr is true, execute exec_expr without printing logs
#define GE_IF_BOOL_EXEC(expr, exec_expr) \
{ \
if (expr) { \
exec_expr; \
} \
};
}

// If make_shared is abnormal, print the log and execute the statement
#define GE_MAKE_SHARED(exec_expr0, exec_expr1) \


+ 5
- 3
inc/framework/common/ge_types.h View File

@@ -54,9 +54,9 @@ const char *const GE_ENGINE_ATTR_MEM_TYPE_HBM = "HBM";
struct DataBuffer {
public:
void *data; // Data address
uint32_t length; // Data length
uint64_t length; // Data length
bool isDataSupportMemShare = false;
DataBuffer(void *dataIn, uint32_t len, bool isSupportMemShare)
DataBuffer(void *dataIn, uint64_t len, bool isSupportMemShare)
: data(dataIn), length(len), isDataSupportMemShare(isSupportMemShare) {}

DataBuffer() : data(nullptr), length(0), isDataSupportMemShare(false) {}
@@ -106,7 +106,7 @@ struct ShapeDescription {
// Definition of input and output description information
struct InputOutputDescInfo {
std::string name;
uint32_t size;
uint64_t size;
uint32_t data_type;
ShapeDescription shape_info;
};
@@ -231,6 +231,7 @@ struct Options {

// Profiling info of task
struct TaskDescInfo {
std::string model_name;
std::string op_name;
uint32_t block_dim;
uint32_t task_id;
@@ -239,6 +240,7 @@ struct TaskDescInfo {

// Profiling info of graph
struct ComputeGraphDescInfo {
std::string model_name;
std::string op_name;
std::string op_type;
std::vector<Format> input_format;


+ 0
- 2
inc/framework/common/helper/model_helper.h View File

@@ -44,8 +44,6 @@ class ModelHelper {
void SetSaveMode(bool val) { is_offline_ = val; }
bool GetSaveMode(void) const { return is_offline_; }

static Status TransModelToGeModel(const ModelPtr& model, GeModelPtr& ge_model);
static Status TransGeModelToModel(const GeModelPtr& geModelPtr, ModelPtr& modelPtr);
Status GetBaseNameFromFileName(const std::string& file_name, std::string& base_name);
Status GetModelNameFromMergedGraphName(const std::string& graph_name, std::string& model_name);



+ 7
- 0
inc/framework/common/types.h View File

@@ -48,6 +48,9 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_S
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_LAYER;
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_FILE_PATH;
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_MODE;
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string OP_DEBUG_AICORE;
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string OP_DEBUG_ATOMIC;
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string OP_DEBUG_ALL;

// Supported public properties name
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROP_OME_START_TIME; // Start time
@@ -335,6 +338,7 @@ REGISTER_OPTYPE_DECLARE(BASICLSTMCELL, "BasicLSTMCell");
REGISTER_OPTYPE_DECLARE(GETNEXT, "GetNext");
REGISTER_OPTYPE_DECLARE(INITDATA, "InitData");
REGISTER_OPTYPE_DECLARE(TRANSSHAPE, "TransShape")
REGISTER_OPTYPE_DECLARE(REFIDENTITY, "RefIdentity");

// ANN dedicated operator
REGISTER_OPTYPE_DECLARE(ANN_MEAN, "AnnMean");
@@ -631,6 +635,9 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string NODE_N

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string NODE_NAME_END_GRAPH;

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string NODE_NAME_OP_DEBUG;
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string OP_TYPE_OP_DEBUG;

// convolution node type
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string OP_TYPE_CONVOLUTION;
// adds a convolutional node name for the hard AIPP


+ 2
- 2
inc/framework/executor/ge_executor.h View File

@@ -21,12 +21,12 @@
#include <string>
#include <vector>

#include "common/dynamic_aipp.h"
#include "common/ge_inner_error_codes.h"
#include "common/ge_types.h"
#include "common/types.h"
#include "graph/tensor.h"
#include "runtime/base.h"
#include "common/dynamic_aipp.h"

namespace ge {
class ModelListenerAdapter;
@@ -62,7 +62,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeExecutor {

// Get input and output descriptor
ge::Status GetModelDescInfo(uint32_t model_id, std::vector<ge::TensorDesc> &input_desc,
std::vector<ge::TensorDesc> &output_desc);
std::vector<ge::TensorDesc> &output_desc, bool new_model_desc = false);

///
/// @ingroup ge


+ 6
- 1
inc/framework/ge_runtime/model_runner.h View File

@@ -28,16 +28,21 @@
namespace ge {
namespace model_runner {
class RuntimeModel;
using RuntimeInfo = std::tuple<uint32_t, uint32_t, void *>;
class ModelRunner {
public:
static ModelRunner &Instance();

bool LoadDavinciModel(uint32_t device_id, uint64_t session_id, uint32_t model_id,
std::shared_ptr<DavinciModel> davinci_model, std::shared_ptr<ModelListener> listener);
bool LoadModelComplete(uint32_t model_id);

const std::vector<uint32_t> &GetTaskIdList(uint32_t model_id) const;

const std::vector<uint32_t> &GetStreamIdList(uint32_t model_id) const;

const std::map<std::string, std::shared_ptr<RuntimeInfo>> &GetRuntimeInfoMap(uint32_t model_id) const;

bool UnloadModel(uint32_t model_id);

bool RunModel(uint32_t model_id, const InputData &input_data, OutputData *output_data);


+ 80
- 55
inc/framework/ge_runtime/task_info.h View File

@@ -21,6 +21,7 @@
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "cce/taskdown_api.h"
@@ -52,21 +53,27 @@ class TaskInfo {
virtual ~TaskInfo() {}
uint32_t stream_id() const { return stream_id_; }
TaskInfoType type() const { return type_; }
std::string op_name() const { return op_name_; }
bool dump_flag() const { return dump_flag_; }

protected:
TaskInfo(uint32_t stream_id, TaskInfoType type) : stream_id_(stream_id), type_(type) {}
TaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, bool dump_flag)
: op_name_(op_name), stream_id_(stream_id), type_(type), dump_flag_(dump_flag) {}

private:
std::string op_name_;
uint32_t stream_id_;
TaskInfoType type_;
bool dump_flag_;
};

class CceTaskInfo : public TaskInfo {
public:
CceTaskInfo(uint32_t stream_id, const cce::ccOpContext &ctx, const std::string &stub_func, uint32_t block_dim,
const std::vector<uint8_t> &args, uint32_t args_size, const std::vector<uint8_t> &sm_desc,
const std::vector<uint8_t> &flow_table, const std::vector<uint8_t> &args_offset, bool is_flowtable)
: TaskInfo(stream_id, TaskInfoType::CCE),
CceTaskInfo(const std::string &op_name, uint32_t stream_id, const cce::ccOpContext &ctx, const std::string &stub_func,
uint32_t block_dim, const std::vector<uint8_t> &args, uint32_t args_size,
const std::vector<uint8_t> &sm_desc, const std::vector<uint8_t> &flow_table,
const std::vector<uint8_t> &args_offset, bool is_flowtable)
: TaskInfo(op_name, stream_id, TaskInfoType::CCE, false),
ctx_(ctx),
stub_func_(stub_func),
block_dim_(block_dim),
@@ -102,11 +109,11 @@ class CceTaskInfo : public TaskInfo {

class TbeTaskInfo : public TaskInfo {
public:
TbeTaskInfo(uint32_t stream_id, const std::string &stub_func, uint32_t block_dim, const std::vector<uint8_t> &args,
uint32_t args_size, const std::vector<uint8_t> &sm_desc, void *binary, uint32_t binary_size,
const std::vector<uint8_t> &meta_data, const std::vector<void *> &input_data_addrs,
const std::vector<void *> &output_data_addrs, const std::vector<void *> &workspace_addrs)
: TaskInfo(stream_id, TaskInfoType::TBE),
TbeTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string &stub_func, uint32_t block_dim,
const std::vector<uint8_t> &args, uint32_t args_size, const std::vector<uint8_t> &sm_desc, void *binary,
uint32_t binary_size, const std::vector<uint8_t> &meta_data, const std::vector<void *> &input_data_addrs,
const std::vector<void *> &output_data_addrs, const std::vector<void *> &workspace_addrs, bool dump_flag)
: TaskInfo(op_name, stream_id, TaskInfoType::TBE, dump_flag),
stub_func_(stub_func),
block_dim_(block_dim),
args_(args),
@@ -153,9 +160,10 @@ class TbeTaskInfo : public TaskInfo {

class AicpuTaskInfo : public TaskInfo {
public:
AicpuTaskInfo(uint32_t stream_id, const string &so_name, const std::string &kernel_name, const std::string &node_def,
const std::vector<void *> &input_data_addrs, const std::vector<void *> &output_data_addrs)
: TaskInfo(stream_id, TaskInfoType::AICPU),
AicpuTaskInfo(const std::string &op_name, uint32_t stream_id, const string &so_name, const std::string &kernel_name,
const std::string &node_def, const std::vector<void *> &input_data_addrs,
const std::vector<void *> &output_data_addrs, bool dump_flag)
: TaskInfo(op_name, stream_id, TaskInfoType::AICPU, dump_flag),
so_name_(so_name),
kernel_name_(kernel_name),
node_def_(node_def),
@@ -177,37 +185,45 @@ class AicpuTaskInfo : public TaskInfo {
std::vector<void *> output_data_addrs_;
};

class LabelTaskInfo : public TaskInfo {
class LabelSetTaskInfo : public TaskInfo {
public:
LabelSetTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_id)
: TaskInfo(op_name, stream_id, TaskInfoType::LABEL_SET, false), label_id_(label_id) {}
~LabelSetTaskInfo() override {}
uint32_t label_id() const { return label_id_; }

protected:
LabelTaskInfo(uint32_t stream_id, TaskInfoType type, uint32_t label_id)
: TaskInfo(stream_id, type), label_id_(label_id) {}
virtual ~LabelTaskInfo() override {}

private:
uint32_t label_id_;
};

class LabelSetTaskInfo : public LabelTaskInfo {
class LabelGotoTaskInfo : public TaskInfo {
public:
LabelSetTaskInfo(uint32_t stream_id, uint32_t label_id)
: LabelTaskInfo(stream_id, TaskInfoType::LABEL_SET, label_id) {}
~LabelSetTaskInfo() override {}
LabelGotoTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_id)
: TaskInfo(op_name, stream_id, TaskInfoType::LABEL_GOTO, false), label_id_(label_id) {}
~LabelGotoTaskInfo() override {}
uint32_t label_id() const { return label_id_; }

private:
uint32_t label_id_;
};

class LabelSwitchTaskInfo : public LabelTaskInfo {
class LabelSwitchTaskInfo : public TaskInfo {
public:
LabelSwitchTaskInfo(uint32_t stream_id, uint32_t label_id)
: LabelTaskInfo(stream_id, TaskInfoType::LABEL_SWITCH, label_id) {}
LabelSwitchTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_size,
const std::vector<uint32_t> &label_list, void *cond)
: TaskInfo(op_name, stream_id, TaskInfoType::LABEL_SWITCH, false),
label_size_(label_size),
label_list_(label_list),
cond_(cond) {}
~LabelSwitchTaskInfo() override {}
};
uint32_t label_size() { return label_size_; };
const std::vector<uint32_t> &label_list() { return label_list_; };
void *cond() { return cond_; };

class LabelGotoTaskInfo : public LabelTaskInfo {
public:
LabelGotoTaskInfo(uint32_t stream_id, uint32_t label_id)
: LabelTaskInfo(stream_id, TaskInfoType::LABEL_GOTO, label_id) {}
~LabelGotoTaskInfo() override {}
private:
uint32_t label_size_;
std::vector<uint32_t> label_list_;
void *cond_;
};

class EventTaskInfo : public TaskInfo {
@@ -215,8 +231,8 @@ class EventTaskInfo : public TaskInfo {
uint32_t event_id() const { return event_id_; }

protected:
EventTaskInfo(uint32_t stream_id, TaskInfoType type, uint32_t event_id)
: TaskInfo(stream_id, type), event_id_(event_id) {}
EventTaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, uint32_t event_id)
: TaskInfo(op_name, stream_id, type, false), event_id_(event_id) {}
virtual ~EventTaskInfo() override {}

uint32_t event_id_;
@@ -224,39 +240,41 @@ class EventTaskInfo : public TaskInfo {

class EventRecordTaskInfo : public EventTaskInfo {
public:
EventRecordTaskInfo(uint32_t stream_id, uint32_t event_id)
: EventTaskInfo(stream_id, TaskInfoType::EVENT_RECORD, event_id) {}
EventRecordTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t event_id)
: EventTaskInfo(op_name, stream_id, TaskInfoType::EVENT_RECORD, event_id) {}
~EventRecordTaskInfo() override {}
};

class EventWaitTaskInfo : public EventTaskInfo {
public:
EventWaitTaskInfo(uint32_t stream_id, uint32_t event_id)
: EventTaskInfo(stream_id, TaskInfoType::EVENT_WAIT, event_id) {}
EventWaitTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t event_id)
: EventTaskInfo(op_name, stream_id, TaskInfoType::EVENT_WAIT, event_id) {}
~EventWaitTaskInfo() override {}
};

class FusionStartTaskInfo : public TaskInfo {
public:
explicit FusionStartTaskInfo(uint32_t stream_id) : TaskInfo(stream_id, TaskInfoType::FUSION_START) {}
explicit FusionStartTaskInfo(const std::string &op_name, uint32_t stream_id)
: TaskInfo(op_name, stream_id, TaskInfoType::FUSION_START, false) {}
~FusionStartTaskInfo() override {}
};

class FusionEndTaskInfo : public TaskInfo {
public:
explicit FusionEndTaskInfo(uint32_t stream_id) : TaskInfo(stream_id, TaskInfoType::FUSION_END) {}
explicit FusionEndTaskInfo(const std::string &op_name, uint32_t stream_id)
: TaskInfo(op_name, stream_id, TaskInfoType::FUSION_END, false) {}
~FusionEndTaskInfo() override {}
};

class HcclTaskInfo : public TaskInfo {
public:
HcclTaskInfo(uint32_t stream_id, const std::string hccl_type, void *input_data_addr, void *output_data_addr,
void *workspace_addr, int64_t workspace_size, int64_t hccl_stream_num,
HcclTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string hccl_type, void *input_data_addr,
void *output_data_addr, void *workspace_addr, int64_t workspace_size, int64_t hccl_stream_num,
const std::vector<uint8_t> &private_def, void *ops_kernel_store, int32_t count, int64_t root_id,
int64_t op_type, int64_t data_type, std::function<bool(void *, void *)> hcom_bind_model,
std::function<bool(void *)> hcom_unbind_model,
std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task)
: TaskInfo(stream_id, TaskInfoType::HCCL),
int64_t op_type, int64_t data_type, const std::string &group,
std::function<bool(void *, void *)> hcom_bind_model, std::function<bool(void *)> hcom_unbind_model,
std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task, bool dump_flag)
: TaskInfo(op_name, stream_id, TaskInfoType::HCCL, dump_flag),
hccl_type_(hccl_type),
input_data_addr_(input_data_addr),
output_data_addr_(output_data_addr),
@@ -269,6 +287,7 @@ class HcclTaskInfo : public TaskInfo {
root_id_(root_id),
op_type_(op_type),
data_type_(data_type),
group_(group),
hcom_bind_model_(hcom_bind_model),
hcom_unbind_model_(hcom_unbind_model),
hcom_distribute_task_(hcom_distribute_task) {}
@@ -286,6 +305,7 @@ class HcclTaskInfo : public TaskInfo {
int64_t root_id() const { return root_id_; }
int64_t op_type() const { return op_type_; }
int64_t data_type() const { return data_type_; }
const std::string &group() const { return group_; }
std::function<bool(void *, void *)> hcom_bind_model() const { return hcom_bind_model_; }
std::function<bool(void *)> hcom_unbind_model() const { return hcom_unbind_model_; }
std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task() const {
@@ -305,6 +325,7 @@ class HcclTaskInfo : public TaskInfo {
int64_t root_id_;
int64_t op_type_;
int64_t data_type_;
std::string group_;
std::function<bool(void *, void *)> hcom_bind_model_;
std::function<bool(void *)> hcom_unbind_model_;
std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task_;
@@ -312,8 +333,11 @@ class HcclTaskInfo : public TaskInfo {

class ProfilerTraceTaskInfo : public TaskInfo {
public:
ProfilerTraceTaskInfo(uint32_t stream_id, uint64_t log_id, bool notify, uint32_t flat)
: TaskInfo(stream_id, TaskInfoType::PROFILER_TRACE), log_id_(log_id), notify_(notify), flat_(flat) {}
ProfilerTraceTaskInfo(const std::string &op_name, uint32_t stream_id, uint64_t log_id, bool notify, uint32_t flat)
: TaskInfo(op_name, stream_id, TaskInfoType::PROFILER_TRACE, false),
log_id_(log_id),
notify_(notify),
flat_(flat) {}
~ProfilerTraceTaskInfo() override {}

uint64_t log_id() const { return log_id_; }
@@ -328,8 +352,9 @@ class ProfilerTraceTaskInfo : public TaskInfo {

class MemcpyAsyncTaskInfo : public TaskInfo {
public:
MemcpyAsyncTaskInfo(uint32_t stream_id, void *dst, uint64_t dst_max, void *src, uint64_t count, uint32_t kind)
: TaskInfo(stream_id, TaskInfoType::MEMCPY_ASYNC),
MemcpyAsyncTaskInfo(const std::string &op_name, uint32_t stream_id, void *dst, uint64_t dst_max, void *src,
uint64_t count, uint32_t kind, bool dump_flag)
: TaskInfo(op_name, stream_id, TaskInfoType::MEMCPY_ASYNC, dump_flag),
dst_(dst),
dst_max_(dst_max),
src_(src),
@@ -353,9 +378,9 @@ class MemcpyAsyncTaskInfo : public TaskInfo {

class StreamSwitchTaskInfo : public TaskInfo {
public:
StreamSwitchTaskInfo(uint32_t stream_id, int64_t true_stream_id, void *input_addr, void *value_addr, int64_t cond,
int64_t data_type)
: TaskInfo(stream_id, TaskInfoType::STREAM_SWITCH),
StreamSwitchTaskInfo(const std::string &op_name, uint32_t stream_id, int64_t true_stream_id, void *input_addr,
void *value_addr, int64_t cond, int64_t data_type)
: TaskInfo(op_name, stream_id, TaskInfoType::STREAM_SWITCH, false),
true_stream_id_(true_stream_id),
input_addr_(input_addr),
value_addr_(value_addr),
@@ -379,8 +404,8 @@ class StreamSwitchTaskInfo : public TaskInfo {

class StreamActiveTaskInfo : public TaskInfo {
public:
StreamActiveTaskInfo(uint32_t stream_id, uint32_t active_stream_id)
: TaskInfo(stream_id, TaskInfoType::STREAM_ACTIVE), active_stream_id_(active_stream_id) {}
StreamActiveTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t active_stream_id)
: TaskInfo(op_name, stream_id, TaskInfoType::STREAM_ACTIVE, false), active_stream_id_(active_stream_id) {}
~StreamActiveTaskInfo() override {}

uint32_t active_stream_id() const { return active_stream_id_; }


+ 1
- 0
inc/framework/generator/ge_generator.h View File

@@ -27,6 +27,7 @@
#include "graph/ge_tensor.h"
#include "graph/graph.h"
#include "graph/op_desc.h"
#include "graph/detail/attributes_holder.h"

namespace ge {
class GeGenerator {


+ 4
- 3
inc/framework/omg/omg.h View File

@@ -98,13 +98,14 @@ Status DumpInfershapeJson(const ge::Graph &graph, const char *json_file);

Status SetOutputNodeInfo(ge::Graph &graph, const std::string &output_type, const std::string &output_format);

Status GetOutputLeaf(ge::NodePtr node, std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info,
std::vector<std::string> &output_nodes_name);
Status GetOutputLeaf(ge::NodePtr node, std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info);

void GetOutputNodesNameAndIndex(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info,
std::vector<std::string> &output_nodes_name);

void UpdateOmgCtxWithParserCtx();

void UpdateParserCtxWithOmgCtx();

} // namespace ge

namespace domi {


+ 2
- 0
inc/framework/omg/omg_inner_types.h View File

@@ -94,6 +94,8 @@ struct OmgContext {
std::vector<std::pair<std::string, int32_t>> user_out_nodes;
// net out nodes (where user_out_nodes or leaf nodes)
std::vector<std::string> net_out_nodes;
// net out nodes top names(only caffe has top)
std::vector<std::string> out_top_names;
// path for the aicpu custom operator so_file
std::vector<std::string> aicpu_op_run_paths;
// ddk version


+ 9
- 1
inc/graph/compute_graph.h View File

@@ -74,6 +74,9 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A

size_t GetAllNodesSize() const;
Vistor<NodePtr> GetAllNodes() const;
// is_unknown_shape: false, same with GetAllNodes func
// is_unknown_shape: true, same with GetDirectNodes func
Vistor<NodePtr> GetNodes(bool is_unknown_shape) const;
size_t GetDirectNodesSize() const;
Vistor<NodePtr> GetDirectNode() const;
Vistor<NodePtr> GetInputNodes() const;
@@ -174,6 +177,10 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A
void SetInputSize(uint32_t size) { input_size_ = size; }
uint32_t GetInputSize() const { return input_size_; }

// false: known shape true: unknow shape
bool GetGraphUnknownFlag() const { return is_unknown_shape_graph_; }
void SetGraphUnknownFlag(bool flag) { is_unknown_shape_graph_ = flag; }

///
/// Set is need train iteration.
/// If set true, it means this graph need to be run iteration some
@@ -282,7 +289,8 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A
std::map<uint32_t, std::string> op_name_map_;
uint64_t session_id_ = 0;
ge::Format data_format_ = ge::FORMAT_ND;
// unknown graph indicator, default is false, mean known shape
bool is_unknown_shape_graph_ = false;
};
} // namespace ge

#endif // INC_GRAPH_COMPUTE_GRAPH_H_

+ 22
- 1
inc/graph/debug/ge_attr_define.h View File

@@ -139,6 +139,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string NEW_AIPP
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AIPP_INPUTS;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AIPP_OUTPUTS;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_DIMS;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SESSION_GRAPH_ID;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PARENT_GRAPH_NAME;

@@ -776,6 +778,10 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MOD

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_CORE_TYPE;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_ATC_VERSION;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_OPP_VERSION;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_MODE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_VALUE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_OFFSET;
@@ -994,7 +1000,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE;

// used for l1 fusion and other fusion in future
// used for lX fusion
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_GROUP_ID;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_GROUP_KEY;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_GROUP_KEY;
@@ -1008,9 +1014,17 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_FOR_L1_FUSION;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_N_BATCH_SPILT;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NO_TASK_AND_DUMP_NEEDED;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DATA_DUMP_REF;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_BUFFER_FUSION;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L2_FUSION_GROUP_ID;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_FOR_L2_FUSION;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_INPUT_L1_FLAG;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_INPUT_L1_ADDR;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_INPUT_L1_VALID_SIZE;

// op overflow dump
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_OP_DEBUG_FLAG;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_OP_DEBUG_MODE;

// functional ops attr
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IF_THEN_BRANCH;
@@ -1056,6 +1070,13 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_HOR
// for gradient group
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_HCCL_FUSED_GROUP;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_HCCL_FUSED_FLAG;

// dynamic shape attrs
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DYNAMIC_SHAPE_FIXED_ADDR;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DYNAMIC_SHAPE_FIXED_ADDR_INDEX;

// for fusion op plugin
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSIONOP_ORIGINAL_TYPE;
} // namespace ge

#endif // INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_

+ 0
- 1
inc/graph/detail/attributes_holder.h View File

@@ -149,5 +149,4 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrHolder {
AnyMap extAttrs_;
};
} // namespace ge

#endif // INC_GRAPH_DETAIL_ATTRIBUTES_HOLDER_H_

+ 1
- 0
inc/graph/ge_context.h View File

@@ -28,6 +28,7 @@ class GEContext {
uint32_t DeviceId();
uint64_t TraceId();
void Init();
void SetSessionId(uint64_t session_id);
void SetCtxDeviceId(uint32_t device_id);

private:


+ 5
- 2
inc/graph/ge_tensor.h View File

@@ -25,6 +25,7 @@
#include "graph/buffer.h"
#include "graph/ge_error_codes.h"
#include "graph/types.h"

namespace ge {
class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeShape {
public:
@@ -108,8 +109,11 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc : public AttrH
DataType GetDataType() const;
void SetDataType(DataType dt);

void SetOriginDataType(DataType originDataType);
DataType GetOriginDataType() const;
void SetOriginDataType(DataType originDataType);

std::vector<uint32_t> GetRefPortIndex() const;
void SetRefPortByIndex(const std::vector<uint32_t> &index);

GeTensorDesc Clone() const;
GeTensorDesc &operator=(const GeTensorDesc &desc);
@@ -186,5 +190,4 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensor {
GeTensorDesc &DescReference() const;
};
} // namespace ge

#endif // INC_GRAPH_GE_TENSOR_H_

+ 0
- 1
inc/graph/model_serialize.h View File

@@ -49,5 +49,4 @@ class ModelSerialize {
friend class GraphDebugImp;
};
} // namespace ge

#endif // INC_GRAPH_MODEL_SERIALIZE_H_

+ 4
- 0
inc/graph/op_desc.h View File

@@ -105,6 +105,8 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder {

GeTensorDescPtr MutableInputDesc(uint32_t index) const;

GeTensorDescPtr MutableInputDesc(const string &name) const;

Vistor<GeTensorDesc> GetAllInputsDesc() const;

Vistor<GeTensorDescPtr> GetAllInputsDescPtr() const;
@@ -127,6 +129,8 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder {

GeTensorDescPtr MutableOutputDesc(uint32_t index) const;

GeTensorDescPtr MutableOutputDesc(const string &name) const;

uint32_t GetAllOutputsDescSize() const;

Vistor<GeTensorDesc> GetAllOutputsDesc() const;


+ 11
- 4
inc/graph/utils/graph_utils.h View File

@@ -130,7 +130,7 @@ struct NodeIndexIO {
IOType io_type_ = kOut;
std::string value_;

std::string ToString() const { return value_; }
const std::string &ToString() const { return value_; }
};

class GraphUtils {
@@ -188,8 +188,8 @@ class GraphUtils {
/// @param [in] output_index
/// @return graphStatus
///
static graphStatus InsertNodeBefore(const OutDataAnchorPtr &src, const std::vector<InDataAnchorPtr> &dsts,
const NodePtr &insert_node, uint32_t input_index = 0, uint32_t output_index = 0);
static graphStatus InsertNodeAfter(const OutDataAnchorPtr &src, const std::vector<InDataAnchorPtr> &dsts,
const NodePtr &insert_node, uint32_t input_index = 0, uint32_t output_index = 0);

static graphStatus RemoveJustNode(ComputeGraphPtr compute_graph, const NodePtr &node);

@@ -303,6 +303,14 @@ class GraphUtils {
///
static graphStatus MoveOutCtrlEdges(NodePtr &src_node, NodePtr &dst_node);

///
/// Copy all in-data edges from `src_node` to `dst_node`
/// @param src_node
/// @param dst_node
/// @return
///
static graphStatus CopyInDataEdges(const NodePtr &src_node, NodePtr &dst_node);

static ComputeGraphPtr FindRootGraph(ComputeGraphPtr graph);

static graphStatus TopologicalSortingByName(const ge::ComputeGraphPtr &compute_graph, vector<NodePtr> &node_vec);
@@ -728,5 +736,4 @@ class PartialGraphBuilder : public ComputeGraphBuilder {
std::vector<NodePtr> exist_nodes_;
};
} // namespace ge

#endif // INC_GRAPH_UTILS_GRAPH_UTILS_H_

+ 25
- 0
inc/graph/utils/node_utils.h View File

@@ -99,6 +99,13 @@ class NodeUtils {
///
static NodePtr GetParentInput(const NodePtr &node);

///
/// @brief Check is varying_input for while node
/// @param [in] node: Data node for subgraph
/// @return bool
///
static bool IsWhileVaryingInput(const ge::NodePtr &node);

///
/// @brief Get subgraph input is constant.
/// @param [in] node
@@ -114,6 +121,24 @@ class NodeUtils {
///
static graphStatus RemoveSubgraphsOnNode(const NodePtr &node);

///
/// @brief Get subgraph input data node by index.
/// @param [in] node
/// @return Node
///
static vector<NodePtr> GetSubgraphDataNodesByIndex(const Node &node, int index);

///
/// @brief Get subgraph input data node by index.
/// @param [in] node
/// @return Node
///
static vector<NodePtr> GetSubgraphOutputNodes(const Node &node);

static NodePtr GetInDataNodeByIndex(const Node &node, int index);

static vector<NodePtr> GetOutDataNodesByIndex(const Node &node, int index);

private:
static std::map<NodePtr, std::vector<uint32_t>> map_send_info_;
static std::map<NodePtr, std::vector<uint32_t>> map_recv_info_;


+ 1
- 0
inc/graph/utils/tensor_adapter.h View File

@@ -20,6 +20,7 @@
#include <memory>
#include "graph/ge_tensor.h"
#include "graph/tensor.h"

namespace ge {
using GeTensorPtr = std::shared_ptr<GeTensor>;
using ConstGeTensorPtr = std::shared_ptr<const GeTensor>;


+ 1
- 0
inc/graph/utils/tensor_utils.h View File

@@ -21,6 +21,7 @@
#include "graph/def_types.h"
#include "graph/ge_error_codes.h"
#include "graph/ge_tensor.h"

namespace ge {
class TensorUtils {
public:


+ 1
- 0
src/common/graph/CMakeLists.txt View File

@@ -71,5 +71,6 @@ target_link_libraries(graph PRIVATE
${PROTOBUF_LIBRARY}
${c_sec}
${slog}
${error_manager}
rt
dl)

+ 13
- 0
src/common/graph/compute_graph.cc View File

@@ -106,6 +106,15 @@ ComputeGraph::Vistor<NodePtr> ComputeGraph::AllGraphNodes(std::vector<std::share
return Vistor<NodePtr>(shared_from_this(), all_nodes);
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor<NodePtr> ComputeGraph::GetNodes(
bool is_unknown_shape) const {
if (is_unknown_shape) {
return GetDirectNode();
} else {
return GetAllNodes();
}
}

size_t ComputeGraph::GetDirectNodesSize() const { return nodes_.size(); }

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor<NodePtr> ComputeGraph::GetDirectNode() const {
@@ -497,6 +506,10 @@ ComputeGraph::AddSubgraph(const std::string &name, const std::shared_ptr<Compute
if (name != subgraph->GetName()) {
GELOGW("The subgraph name %s is different with input %s", subgraph->GetName().c_str(), name.c_str());
}
if (names_to_subgraph_.find(name) != names_to_subgraph_.end()) {
GE_LOGE("The subgraph %s existed", name.c_str());
return GRAPH_PARAM_INVALID;
}
sub_graph_.push_back(subgraph);
names_to_subgraph_[name] = subgraph;
return GRAPH_SUCCESS;


+ 4
- 0
src/common/graph/debug/ge_op_types.h View File

@@ -34,12 +34,16 @@ GE_REGISTER_OPTYPE(EXPANDDIMS, "ExpandDims");
GE_REGISTER_OPTYPE(SWITCH, "Switch");
GE_REGISTER_OPTYPE(MERGE, "Merge");
GE_REGISTER_OPTYPE(STREAMMERGE, "StreamMerge");
GE_REGISTER_OPTYPE(ENTER, "Enter");
GE_REGISTER_OPTYPE(REFENTER, "RefEnter");
GE_REGISTER_OPTYPE(NEXTITERATION, "NextIteration");
GE_REGISTER_OPTYPE(REFNEXTITERATION, "RefNextIteration");
GE_REGISTER_OPTYPE(CONSTANT, "Const");
GE_REGISTER_OPTYPE(PLACEHOLDER, "PlaceHolder");
GE_REGISTER_OPTYPE(FRAMEWORKOP, "FrameworkOp");
GE_REGISTER_OPTYPE(GETNEXT, "GetNext");
GE_REGISTER_OPTYPE(INITDATA, "InitData");
GE_REGISTER_OPTYPE(REFIDENTITY, "RefIdentity");
GE_REGISTER_OPTYPE(ANN_DATA, "AnnData");

GE_REGISTER_OPTYPE(CONSTANTOP, "Constant");


+ 69
- 21
src/common/graph/format_refiner.cc View File

@@ -41,11 +41,9 @@ using namespace ge;
using namespace std;
namespace ge {
namespace {
static const std::unordered_set<string> kChangeDimNodes = {PERMUTE, EXPANDDIMS, SQUEEZE};
static bool net_format_is_nd = true;
static Format g_user_set_format = FORMAT_ND;
static bool is_first_infer = true;
static RefRelations reflection_builder;
const std::unordered_set<string> kChangeDimNodes = {PERMUTE, EXPANDDIMS, SQUEEZE};
const string kIsGraphInferred = "_is_graph_inferred";
RefRelations reflection_builder;
} // namespace

graphStatus ReflectionProcess(const std::unordered_set<RefCell, RefCellHash> &reflection,
@@ -72,9 +70,49 @@ graphStatus ReflectionProcess(const std::unordered_set<RefCell, RefCellHash> &re
return GRAPH_SUCCESS;
}

graphStatus FormatRefiner::RefreshConstantOutProcess(const OpDescPtr &op_desc) {
graphStatus BiasAddFormatFixProcess(ge::NodePtr &node_ptr) {
// 5 meas dim num
if (node_ptr->GetType() != "BiasAdd") {
return GRAPH_SUCCESS;
}
std::unordered_map<string, Format> kTfFormatFix = {{"NHWC", FORMAT_NDHWC}, {"NCHW", FORMAT_NCDHW}};
for (size_t i = 0; i < node_ptr->GetOpDesc()->GetInputsSize(); i++) {
auto in_desc = node_ptr->GetOpDesc()->MutableInputDesc(i);
GE_CHECK_NOTNULL(in_desc);
if (in_desc->MutableShape().GetDimNum() != 5) { // 5 means dim num
continue;
}
auto format = in_desc->GetOriginFormat();
auto key = TypeUtils::FormatToSerialString(format);
auto fixed_format = (kTfFormatFix.count(key) == 0) ? format : kTfFormatFix[key];
in_desc->SetOriginFormat(fixed_format);
in_desc->SetFormat(fixed_format);
GELOGD("fix the %zu'th input of node[%s]. Origin format is %s , after fixed it is %s", i,
node_ptr->GetName().c_str(), TypeUtils::FormatToSerialString(format).c_str(),
TypeUtils::FormatToSerialString(fixed_format).c_str());
}
for (size_t i = 0; i < node_ptr->GetOpDesc()->GetOutputsSize(); i++) {
auto out_desc = node_ptr->GetOpDesc()->MutableOutputDesc(i);
GE_CHECK_NOTNULL(out_desc);
if (out_desc->MutableShape().GetDimNum() != 5) { // 5 means dim num
continue;
}
auto format = out_desc->GetOriginFormat();
auto key = TypeUtils::FormatToSerialString(format);
auto fixed_format = (kTfFormatFix.count(key) == 0) ? format : kTfFormatFix[key];
out_desc->SetOriginFormat(fixed_format);
out_desc->SetFormat(fixed_format);
GELOGD("fix the %zu'th output of node[%s]. Origin format is %s , after fixed it is %s", i,
node_ptr->GetName().c_str(), TypeUtils::FormatToSerialString(format).c_str(),
TypeUtils::FormatToSerialString(fixed_format).c_str());
}
return GRAPH_SUCCESS;
}

graphStatus FormatRefiner::RefreshConstantOutProcess(const ComputeGraphPtr &graph, const OpDescPtr &op_desc) {
GE_CHECK_NOTNULL(graph);
GE_CHECK_NOTNULL(op_desc);
if (op_desc->GetType() == CONSTANTOP && is_first_infer == true) {
if (op_desc->GetType() == CONSTANTOP && !IsGraphInferred(graph)) {
ConstGeTensorPtr tensor_value;
if (!AttrUtils::GetTensor(op_desc, "value", tensor_value)) {
GELOGE(GRAPH_FAILED, "Get value failed, node name:%s.", op_desc->GetName().c_str());
@@ -95,7 +133,7 @@ graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std
}
anchor_points.clear();
// Get all anchor point nodes and switch nodes
for (const auto &node_ptr : graph->GetAllNodes()) {
for (auto &node_ptr : graph->GetAllNodes()) {
if (node_ptr == nullptr) {
return GRAPH_FAILED;
}
@@ -103,7 +141,7 @@ graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std
if (op_desc == nullptr) {
return GRAPH_FAILED;
}
graphStatus status = RefreshConstantOutProcess(op_desc);
graphStatus status = RefreshConstantOutProcess(graph, op_desc);
if (status != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "refresh constant out process failed!");
return GRAPH_FAILED;
@@ -135,6 +173,16 @@ graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std
if (!node_is_all_nd) {
continue;
}
// special process for biasAdd op
// In tensorflow, biasAdd's format is alwayse NHWC even though set the arg
// "data_format" to NDHWC or NCDHW.It will destroy our format-infer mechanism
// so here do special process
status = BiasAddFormatFixProcess(node_ptr);
if (status != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "fix biasAdd process failed!");
return GRAPH_FAILED;
}

GELOGD("Node[%s] is anchor point!", node_ptr->GetName().c_str());
anchor_points.push_back(node_ptr);
}
@@ -344,14 +392,11 @@ void FormatRefiner::RefreshOriginFormatOfAnchor(std::vector<ge::NodePtr> &anchor
}
}

void FormatRefiner::SetInferOrigineFormatFlag(bool is_first) { is_first_infer = is_first; }

graphStatus FormatRefiner::DataNodeFormatProcess(std::vector<ge::NodePtr> &data_nodes, ge::Format data_format,
graphStatus FormatRefiner::DataNodeFormatProcess(const ComputeGraphPtr &graph, std::vector<ge::NodePtr> &data_nodes,
ge::Format data_format,
std::unordered_map<ge::NodePtr, bool> &node_status) {
bool is_internal_format = TypeUtils::IsInternalFormat(data_format);
bool need_process = (!is_first_infer) && (!is_internal_format) && (data_format != FORMAT_ND);
if (!need_process) {
GELOGI("no necessary to do DataNodeFormatProcess.is_first_infer:%d, data_format:%s", is_first_infer,
if (!(IsGraphInferred(graph) && (!TypeUtils::IsInternalFormat(data_format)) && (data_format != FORMAT_ND))) {
GELOGI("no necessary to do DataNodeFormatProcess. is_graph_inferred:%d, data_format:%s", IsGraphInferred(graph),
TypeUtils::FormatToSerialString(data_format).c_str());
return GRAPH_SUCCESS;
}
@@ -410,8 +455,6 @@ graphStatus FormatRefiner::InferOrigineFormat(const ge::ComputeGraphPtr &graph)
std::vector<ge::NodePtr> anchor_points;
std::vector<ge::NodePtr> data_nodes;
// global net format
net_format_is_nd = true;
g_user_set_format = FORMAT_ND;

if (graph == nullptr) {
GELOGE(GRAPH_FAILED, "input graph is null");
@@ -448,10 +491,15 @@ graphStatus FormatRefiner::InferOrigineFormat(const ge::ComputeGraphPtr &graph)
/// format for these data nodes.
/// Notice: ignore 5D formats
auto data_format = graph->GetDataFormat();
status = DataNodeFormatProcess(data_nodes, data_format, node_status);
// Set infer flag to false
SetInferOrigineFormatFlag(false);
status = DataNodeFormatProcess(graph, data_nodes, data_format, node_status);
(void)AttrUtils::SetBool(graph, kIsGraphInferred, true);

return status;
}

bool FormatRefiner::IsGraphInferred(const ComputeGraphPtr &graph) {
bool is_graph_inferred = false;
return (AttrUtils::GetBool(graph, kIsGraphInferred, is_graph_inferred) && is_graph_inferred);
}
} // namespace ge

+ 4
- 4
src/common/graph/format_refiner.h View File

@@ -30,10 +30,9 @@ namespace ge {
class FormatRefiner {
public:
static graphStatus InferOrigineFormat(const ge::ComputeGraphPtr &graph);
static void SetInferOrigineFormatFlag(bool is_first = true);

private:
static graphStatus RefreshConstantOutProcess(const OpDescPtr &op_desc);
static graphStatus RefreshConstantOutProcess(const ComputeGraphPtr &graph, const OpDescPtr &op_desc);
static graphStatus GetAnchorPoints(const ge::ComputeGraphPtr &graph, std::vector<ge::NodePtr> &anchor_points,
std::vector<ge::NodePtr> &data_nodes,
std::unordered_map<ge::NodePtr, bool> &node_status);
@@ -43,8 +42,9 @@ class FormatRefiner {
std::unordered_map<ge::NodePtr, bool> &node_status);
static graphStatus ForwardInferProcess(std::deque<ge::NodePtr> &nodes, ge::NodePtr &node,
std::unordered_map<ge::NodePtr, bool> &node_status);
static graphStatus DataNodeFormatProcess(std::vector<ge::NodePtr> &data_nodes, ge::Format data_format,
std::unordered_map<ge::NodePtr, bool> &node_status);
static graphStatus DataNodeFormatProcess(const ComputeGraphPtr &graph, std::vector<ge::NodePtr> &data_nodes,
ge::Format data_format, std::unordered_map<ge::NodePtr, bool> &node_status);
static bool IsGraphInferred(const ComputeGraphPtr &graph);
};
} // namespace ge
#endif // COMMON_GRAPH_FORMAT_REFINER_H_

+ 22
- 1
src/common/graph/ge_attr_define.cc View File

@@ -121,6 +121,8 @@ const std::string NEW_AIPP_CONV_OP = "new_conv_op_for_aipp";
const std::string ATTR_NAME_AIPP_INPUTS = "_aipp_inputs";
const std::string ATTR_NAME_AIPP_OUTPUTS = "_aipp_outputs";

const std::string ATTR_NAME_INPUT_DIMS = "input_dims";

const std::string ATTR_NAME_SESSION_GRAPH_ID = "_session_graph_id";
const std::string ATTR_NAME_PARENT_GRAPH_NAME = "_parent_graph_name";

@@ -723,6 +725,10 @@ const std::string ATTR_MODEL_TASK_INDEX_OP_NAME = "task_index_op_name";

const std::string ATTR_MODEL_CORE_TYPE = "core_type";

const std::string ATTR_MODEL_ATC_VERSION = "atc_version";

const std::string ATTR_MODEL_OPP_VERSION = "opp_version";

// Public attribute
const std::string ATTR_NAME_IMPLY_TYPE = "imply_type";

@@ -932,7 +938,7 @@ const std::string ATTR_NAME_MEMORY_TYPE_WORKSPACE = "memory_type_workspace";

const std::string MODEL_ATTR_SESSION_ID = "session_id";

// l1 fusion and other fusion in future
// lx fusion
const std::string ATTR_NAME_L1_FUSION_GROUP_ID = "_l1_fusion_group_id";
const std::string ATTR_NAME_FUSION_GROUP_KEY = "_fusion_group_key";
const std::string ATTR_NAME_L1_FUSION_GROUP_KEY = "_l1_fusion_group_key";
@@ -946,9 +952,17 @@ const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_L1_FUSION = "_output_offset_for_l1
const std::string ATTR_NAME_SWITCH_FOR_L1_FUSION = "_enable_l1_fusion";
const std::string ATTR_N_BATCH_SPILT = "_is_n_batch_split";
const std::string ATTR_NO_TASK_AND_DUMP_NEEDED = "_no_task_and_dump_needed";
const std::string ATTR_DATA_DUMP_REF = "_datadump_ref";
const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_BUFFER_FUSION = "_output_offset_for_buffer_fusion";
const std::string ATTR_NAME_L2_FUSION_GROUP_ID = "_l2_fusion_group_id";
const std::string ATTR_NAME_SWITCH_FOR_L2_FUSION = "_enable_l2_fusion";
const std::string ATTR_NAME_OP_INPUT_L1_FLAG = "_op_input_l1_flag";
const std::string ATTR_NAME_OP_INPUT_L1_ADDR = "_op_input_l1_addr";
const std::string ATTR_NAME_OP_INPUT_L1_VALID_SIZE = "_op_input_l1_valid_size";

// Op debug attrs
const std::string ATTR_OP_DEBUG_FLAG = "_op_debug_flag";
const std::string ATTR_OP_DEBUG_MODE = "_op_debug_mode";

// Atomic addr clean attrs
const std::string ATOMIC_ATTR_INPUT_INDEX = "atomic_input_index";
@@ -1013,4 +1027,11 @@ const std::string ATTR_HOROVOD_ATTR_REDUCE_TYPE = "reduce_op";
// used for allreduce tailing optimization
const std::string ATTR_NAME_HCCL_FUSED_GROUP = "_hccl_fused_group";
const std::string ATTR_NAME_HCCL_FUSED_FLAG = "_hccl_fused_node";

// dynamic shape attr
const std::string ATTR_DYNAMIC_SHAPE_FIXED_ADDR = "_alloc_fixed_addr";
const std::string ATTR_DYNAMIC_SHAPE_FIXED_ADDR_INDEX = "_alloc_fixed_addr_index";

// for fusion op plugin
const std::string ATTR_NAME_FUSIONOP_ORIGINAL_TYPE = "_fusionop_original_type";
} // namespace ge

+ 11
- 0
src/common/graph/ge_tensor.cc View File

@@ -220,6 +220,7 @@ const string TENSOR_UTILS_ORIGIN_SHAPE = "origin_shape";
const string TENSOR_UTILS_ORIGIN_FORMAT = "origin_format";
const string TENSOR_UTILS_ORIGIN_DATA_TYPE = "origin_data_type";
const string TENSOR_UTILS_SHAPE_RANGE = "shape_range";
const string TENSOR_UTILS_REF_PORT_INDEX = "ref_port_index";

GeShape::GeShape(const ProtoMsgOwner &proto_owner, proto::ShapeDef *proto_msg) : shape_def_(proto_owner, proto_msg) {}

@@ -567,6 +568,16 @@ DataType GeTensorDesc::GetOriginDataType() const {
return TypeUtils::SerialStringToDataType(origin_data_type_str);
}

std::vector<uint32_t> GeTensorDesc::GetRefPortIndex() const {
vector<uint32_t> ref_port_index;
(void)AttrUtils::GetListInt(this, TENSOR_UTILS_REF_PORT_INDEX, ref_port_index);
return ref_port_index;
}

void GeTensorDesc::SetRefPortByIndex(const std::vector<uint32_t> &index) {
(void)AttrUtils::SetListInt(this, TENSOR_UTILS_REF_PORT_INDEX, index);
}

graphStatus GeTensorDesc::IsValid() const {
auto dtype = this->GetDataType();
auto format = this->GetFormat();


+ 1
- 1
src/common/graph/graph.cc View File

@@ -210,7 +210,7 @@ class GraphImpl {

graphStatus FindOpByName(const string &name, ge::Operator &op) const {
auto it = op_list_.find(name);
GE_CHK_BOOL_EXEC(it != op_list_.end(), return GRAPH_FAILED, "Error: there is no op: %s.", name.c_str());
GE_CHK_BOOL_EXEC(it != op_list_.end(), return GRAPH_FAILED, "there is no op: %s.", name.c_str());
op = it->second;
return GRAPH_SUCCESS;
}


+ 108
- 1
src/common/graph/graph.mk View File

@@ -1,5 +1,5 @@
LOCAL_PATH := $(call my-dir)
include $(LOCAL_PATH)/stub/Makefile
COMMON_LOCAL_SRC_FILES := \
./proto/om.proto \
./proto/ge_ir.proto \
@@ -77,6 +77,7 @@ LOCAL_SHARED_LIBRARIES := \
libc_sec \
libprotobuf \
libslog \
liberror_manager \

LOCAL_LDFLAGS := -lrt -ldl

@@ -85,6 +86,54 @@ LOCAL_PROPRIETARY_MODULE := true

include $(BUILD_HOST_SHARED_LIBRARY)

#compiler for host
include $(CLEAR_VARS)
LOCAL_MODULE := stub/libgraph

LOCAL_CFLAGS += -DFMK_SUPPORT_DUMP -O2
LOCAL_CPPFLAGS += -fexceptions

LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES)
LOCAL_SRC_FILES := \
../../out/graph/lib64/stub/graph.cc \
../../out/graph/lib64/stub/operator.cc \
../../out/graph/lib64/stub/tensor.cc \
../../out/graph/lib64/stub/operator_factory.cc \


LOCAL_SHARED_LIBRARIES :=

LOCAL_LDFLAGS := -lrt -ldl

LOCAL_MULTILIB := 64
LOCAL_PROPRIETARY_MODULE := true

include $(BUILD_HOST_SHARED_LIBRARY)

#compiler for host
include $(CLEAR_VARS)
LOCAL_MODULE := fwk_stub/libgraph

LOCAL_CFLAGS += -DFMK_SUPPORT_DUMP -O2
LOCAL_CPPFLAGS += -fexceptions

LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES)
LOCAL_SRC_FILES := \
../../out/graph/lib64/stub/attr_value.cc \
../../out/graph/lib64/stub/graph.cc \
../../out/graph/lib64/stub/operator.cc \
../../out/graph/lib64/stub/operator_factory.cc \
../../out/graph/lib64/stub/tensor.cc \


LOCAL_SHARED_LIBRARIES :=

LOCAL_LDFLAGS := -lrt -ldl

LOCAL_MULTILIB := 64
LOCAL_PROPRIETARY_MODULE := true

include $(BUILD_HOST_SHARED_LIBRARY)

#compiler for device
include $(CLEAR_VARS)
@@ -99,6 +148,7 @@ LOCAL_SHARED_LIBRARIES := \
libc_sec \
libprotobuf \
libslog \
liberror_manager \

LOCAL_LDFLAGS := -lrt -ldl

@@ -111,6 +161,60 @@ LOCAL_PROPRIETARY_MODULE := true

include $(BUILD_SHARED_LIBRARY)

#compiler for device
include $(CLEAR_VARS)
LOCAL_MODULE := stub/libgraph

LOCAL_CFLAGS += -O2

LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES)
LOCAL_SRC_FILES := \
../../out/graph/lib64/stub/graph.cc \
../../out/graph/lib64/stub/operator.cc \
../../out/graph/lib64/stub/tensor.cc \
../../out/graph/lib64/stub/operator_factory.cc \


LOCAL_SHARED_LIBRARIES :=

LOCAL_LDFLAGS := -lrt -ldl

ifeq ($(device_os),android)
LOCAL_LDFLAGS := -ldl
endif

LOCAL_MULTILIB := 64
LOCAL_PROPRIETARY_MODULE := true

include $(BUILD_SHARED_LIBRARY)

#compiler for device
include $(CLEAR_VARS)
LOCAL_MODULE := fwk_stub/libgraph

LOCAL_CFLAGS += -O2

LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES)
LOCAL_SRC_FILES := \
../../out/graph/lib64/stub/attr_value.cc \
../../out/graph/lib64/stub/graph.cc \
../../out/graph/lib64/stub/operator.cc \
../../out/graph/lib64/stub/operator_factory.cc \
../../out/graph/lib64/stub/tensor.cc \


LOCAL_SHARED_LIBRARIES :=

LOCAL_LDFLAGS := -lrt -ldl

ifeq ($(device_os),android)
LOCAL_LDFLAGS := -ldl
endif

LOCAL_MULTILIB := 64
LOCAL_PROPRIETARY_MODULE := true

include $(BUILD_SHARED_LIBRARY)

# compile for ut/st
include $(CLEAR_VARS)
@@ -125,6 +229,7 @@ LOCAL_SHARED_LIBRARIES := \
libc_sec \
libprotobuf \
libslog \
liberror_manager \

LOCAL_LDFLAGS := -lrt -ldl

@@ -150,6 +255,7 @@ LOCAL_STATIC_LIBRARIES := \
LOCAL_SHARED_LIBRARIES := \
libc_sec \
libslog \
liberror_manager \

LOCAL_LDFLAGS := -lrt -ldl

@@ -173,6 +279,7 @@ LOCAL_STATIC_LIBRARIES := \
LOCAL_SHARED_LIBRARIES := \
libc_sec \
libslog \
liberror_manager \

LOCAL_LDFLAGS := -lrt -ldl



+ 12
- 13
src/common/graph/model_serialize.cc View File

@@ -88,10 +88,8 @@ bool ModelSerializeImp::SerializeEdge(const NodePtr &node, proto::OpDef *op_def_
}

bool ModelSerializeImp::SerializeOpDesc(const ConstOpDescPtr &op_desc, proto::OpDef *op_def_proto, bool is_dump) {
if (op_desc == nullptr || op_def_proto == nullptr) {
GELOGE(GRAPH_FAILED, "Input Para Invalid");
return false;
}
GE_CHK_BOOL_EXEC(op_desc != nullptr, return false, "op_desc is null.");
GE_CHK_BOOL_EXEC(op_def_proto != nullptr, return false, "op_def_proto is null.");
if (op_desc->op_def_.GetProtoMsg() != nullptr) {
*op_def_proto = *op_desc->op_def_.GetProtoMsg();
// Delete unnecessary attr
@@ -130,16 +128,17 @@ bool ModelSerializeImp::SerializeOpDesc(const ConstOpDescPtr &op_desc, proto::Op
for (const std::string &name : op_desc->GetSubgraphInstanceNames()) {
op_def_proto->add_subgraph_name(name);
}

proto::AttrDef key;
proto::AttrDef value;
for (auto &item : op_desc->output_name_idx_) {
key.mutable_list()->add_s(item.first);
value.mutable_list()->add_i(item.second);
if (!op_desc->output_name_idx_.empty()) {
proto::AttrDef key;
proto::AttrDef value;
for (auto &item : op_desc->output_name_idx_) {
key.mutable_list()->add_s(item.first);
value.mutable_list()->add_i(item.second);
}
auto op_desc_attr = op_def_proto->mutable_attr();
op_desc_attr->insert({"_output_name_key", key});
op_desc_attr->insert({"_output_name_value", value});
}
auto op_desc_attr = op_def_proto->mutable_attr();
op_desc_attr->insert({"_output_name_key", key});
op_desc_attr->insert({"_output_name_value", value});
}
return true;
}


+ 34
- 13
src/common/graph/node.cc View File

@@ -26,6 +26,7 @@
#include "utils/ge_ir_utils.h"
#include "utils/node_utils.h"
#include "utils/op_desc_utils.h"
#include "common/util/error_manager/error_manager.h"

using std::string;
using std::vector;
@@ -154,7 +155,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::NodeAnchorIsEqual(cons
const auto &peer_node = left_anchor->GetPeerAnchors().at(j)->GetOwnerNode();
const auto &r_peer_node = right_anchor->GetPeerAnchors().at(j)->GetOwnerNode();
if (peer_node == nullptr || r_peer_node == nullptr) {
GELOGE(GRAPH_FAILED, "Error: anchor's peer node is null, node name: %s index[%zu] peer node index[%zu]. ",
GELOGE(GRAPH_FAILED, "anchor's peer node is null, node name: %s index[%zu] peer node index[%zu]. ",
this->GetName().c_str(), i, j);
return false;
}
@@ -434,8 +435,11 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor<AnchorPtr> Node::Get

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InDataAnchorPtr Node::GetInDataAnchor(int idx) const {
if (idx < 0 || idx >= static_cast<int>(in_data_anchors_.size())) {
GELOGE(GRAPH_FAILED, "the node doesn't have %d th in_data_anchor, node %s:%s", idx, GetType().c_str(),
GetName().c_str());
ErrorManager::GetInstance().ATCReportErrMessage(
"E19019", {"opname", "index", "anchorname", "optype"},
{GetName().c_str(), std::to_string(idx), "in_data_anchor", GetType().c_str()});
GELOGE(GRAPH_FAILED, "Op[%s] doesn't have index[%d]'s in_data_anchor which optype is %s.", GetName().c_str(), idx,
GetType().c_str());
return nullptr;
} else {
return in_data_anchors_[idx];
@@ -445,7 +449,10 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InDataAnchorPtr Node::GetInDataAn
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorPtr Node::GetInAnchor(int idx) const {
// Idx can't be less than -1 or >= in_data_anchors_.size(), -1 means index of control anchor_
if (idx < -1 || idx >= static_cast<int>(in_data_anchors_.size())) {
GELOGW("the node doesn't have %d th in_anchor, node %s:%s", idx, GetType().c_str(), GetName().c_str());
ErrorManager::GetInstance().ATCReportErrMessage(
"E19019", {"opname", "index", "anchorname", "optype"},
{GetName().c_str(), std::to_string(idx), "in_anchor", GetType().c_str()});
GELOGW("Op[%s] doesn't have index[%d]'s in_anchor which optype is %s.", GetName().c_str(), idx, GetType().c_str());
return nullptr;
} else {
// Return control anchor
@@ -461,8 +468,15 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorPtr Node::GetInAnchor(int i
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorPtr Node::GetOutAnchor(int idx) const {
// Idx can't be less than -1 or >= out_data_anchors_.size(), -1 means index of control anchor_
if (idx < -1 || idx >= static_cast<int>(out_data_anchors_.size())) {
GELOGE(GRAPH_FAILED, "the node doesn't have %d th out_anchor, node %s:%s", idx, GetType().c_str(),
GetName().c_str());
ErrorManager::GetInstance().ATCReportErrMessage("E19019", {"opname", "index", "anchorname", "optype"},
{
GetName().c_str(),
std::to_string(idx),
"out_anchor",
GetType().c_str(),
});
GELOGE(GRAPH_FAILED, "Op[%s] doesn't have index[%d]'s out_anchor which optype is %s.", GetName().c_str(), idx,
GetType().c_str());
return nullptr;
} else {
// Return control anchor
@@ -477,8 +491,11 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorPtr Node::GetOutAnchor(int

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OutDataAnchorPtr Node::GetOutDataAnchor(int idx) const {
if (idx < 0 || idx >= static_cast<int>(out_data_anchors_.size())) {
GELOGE(GRAPH_FAILED, "the node doesn't have %d th out_data_anchor, node %s:%s", idx, GetType().c_str(),
GetName().c_str());
ErrorManager::GetInstance().ATCReportErrMessage(
"E19019", {"opname", "index", "anchorname", "optype"},
{GetName().c_str(), std::to_string(idx), "out_data_anchor", GetType().c_str()});
GELOGE(GRAPH_FAILED, "Op[%s] doesn't have index[%d]'s out_data_anchor which optype is %s.", GetName().c_str(), idx,
GetType().c_str());
return nullptr;
} else {
return out_data_anchors_[idx];
@@ -733,11 +750,15 @@ graphStatus Node::Verify() const {
GELOGW("in anchor ptr is null");
continue;
}
GE_CHK_BOOL_RET_STATUS(
op_->GetType() == data_type || op_->GetType() == aipp_data_type || op_->GetType() == const_type ||
op_->GetType() == variable_type || op_->IsOptionalInput(in_anchor_ptr->GetIdx()) ||
in_anchor_ptr->GetPeerAnchors().size() > 0,
GRAPH_FAILED, "operator %s's input %d is not linked.", GetName().c_str(), in_anchor_ptr->GetIdx());
bool valid_anchor = op_->GetType() == data_type || op_->GetType() == aipp_data_type ||
op_->GetType() == const_type || op_->GetType() == variable_type ||
op_->IsOptionalInput(in_anchor_ptr->GetIdx()) || in_anchor_ptr->GetPeerAnchors().size() > 0;
if (!valid_anchor) {
ErrorManager::GetInstance().ATCReportErrMessage("E11019", {"name", "index"},
{GetName(), std::to_string(in_anchor_ptr->GetIdx())});
GELOGE(GRAPH_FAILED, "operator %s's input %d is not linked.", GetName().c_str(), in_anchor_ptr->GetIdx());
return GRAPH_FAILED;
}
}

string frameworkop_type = "FrameworkOp";


+ 40
- 13
src/common/graph/op_desc.cc View File

@@ -19,6 +19,7 @@
#include "debug/ge_util.h"
#include "external/graph/operator.h"
#include "framework/common/debug/ge_log.h"
#include "common/util/error_manager/error_manager.h"
#include "graph/ge_attr_value.h"
#include "graph/ge_tensor.h"
#include "graph/operator_factory_impl.h"
@@ -470,6 +471,25 @@ GeTensorDesc OpDesc::GetInputDesc(const string &name) const {
return *(inputs_desc_[it->second].get());
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableInputDesc(uint32_t index) const {
GE_CHK_BOOL_RET_STATUS(index < inputs_desc_.size(), nullptr, "Can't find the input desc %u", index);
if (inputs_desc_[index] == nullptr) {
return nullptr;
}
GE_CHK_BOOL_RET_STATUS(inputs_desc_[index]->IsValid() == GRAPH_SUCCESS, nullptr, "input desc is invalid");
return inputs_desc_[index];
}

GeTensorDescPtr OpDesc::MutableInputDesc(const string &name) const {
auto input_name_idx = GetAllInputName();
auto it = input_name_idx.find(name);
if (it == input_name_idx.end()) {
GELOGW("Failed to get [%s] input desc", name.c_str());
return nullptr;
}
return MutableInputDesc(it->second);
}

GE_FUNC_HOST_VISIBILITY OpDesc::Vistor<string> OpDesc::GetAllInputNames() const {
auto input_name_idx = GetAllInputName();
vector<string> names;
@@ -496,15 +516,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetOpEngineName(cons

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string OpDesc::GetOpEngineName() const { return engine_name_; }

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableInputDesc(uint32_t index) const {
GE_CHK_BOOL_RET_STATUS(index < inputs_desc_.size(), nullptr, "Can't find the input desc %u", index);
if (inputs_desc_[index] == nullptr) {
return nullptr;
}
GE_CHK_BOOL_RET_STATUS(inputs_desc_[index]->IsValid() == GRAPH_SUCCESS, nullptr, "input desc is invalid");
return inputs_desc_[index];
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::Vistor<GeTensorDesc> OpDesc::GetAllInputsDesc() const {
vector<GeTensorDesc> temp{};
for (const auto &it : inputs_desc_) {
@@ -609,6 +620,15 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableOu
return outputs_desc_[index];
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableOutputDesc(const string &name) const {
auto it = output_name_idx_.find(name);
if (it == output_name_idx_.end()) {
GELOGW("Failed to get [%s] output desc", name.c_str());
return nullptr;
}
return MutableOutputDesc(it->second);
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t OpDesc::GetAllOutputsDescSize() const {
return static_cast<uint32_t>(outputs_desc_.size());
}
@@ -882,15 +902,22 @@ graphStatus OpDesc::CommonVerify() const {
// Checking shape of all inputs
vector<int64_t> ishape = GetInputDescPtr(iname)->GetShape().GetDims();
for (int64_t dim : ishape) {
GE_CHK_BOOL_RET_STATUS(dim >= -2, GRAPH_FAILED, "operator input %s shape contains negative or zero dimension.",
iname.c_str());
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
dim < -2, ErrorManager::GetInstance().ATCReportErrMessage(
"E19014", {"opname", "value", "reason"},
{GetName(), "input " + iname + " shape", "contains negative or zero dimension"});
return GRAPH_FAILED, "Op[%s]'s input %s shape contains negative or zero dimension.", GetName().c_str(),
iname.c_str());
}
}
// Check all attributes defined
const auto &all_attributes = GetAllAttrs();
for (const auto &name : GetAllAttrNames()) {
GE_CHK_BOOL_RET_STATUS(all_attributes.find(name) != all_attributes.end(), GRAPH_FAILED,
"operator attribute %s is empty.", name.c_str());
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
all_attributes.find(name) == all_attributes.end(),
ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"},
{GetName(), "attribute " + name, "is empty"});
return GRAPH_FAILED, "operator attribute %s is empty.", name.c_str());
}

return GRAPH_SUCCESS;


+ 33
- 40
src/common/graph/operator.cc View File

@@ -36,6 +36,8 @@
#include "graph/op_desc.h"
#include "graph/runtime_inference_context.h"
#include "graph/usr_types.h"
#include "graph/utils/node_utils.h"
#include "graph/debug/ge_attr_define.h"
#include "utils/graph_utils.h"
#include "utils/op_desc_utils.h"
#include "utils/tensor_adapter.h"
@@ -57,8 +59,7 @@ using std::vector;
namespace ge {
class OpIO {
public:
explicit OpIO(const string &name, int index, const OperatorImplPtr &owner)
: name_(name), index_(index), owner_(owner) {}
OpIO(const string &name, int index, const OperatorImplPtr &owner) : name_(name), index_(index), owner_(owner) {}

~OpIO() = default;

@@ -546,56 +547,46 @@ Operator &Operator::AddControlInput(const Operator &src_oprt) {
}

graphStatus Operator::GetInputConstData(const string &dst_name, Tensor &data) const {
if (operator_impl_ == nullptr) {
GELOGE(GRAPH_FAILED, "operator impl is nullptr.");
return GRAPH_FAILED;
}
ge::ConstNodePtr node_ptr = operator_impl_->GetNode();
if (node_ptr) {
GE_CHECK_NOTNULL(operator_impl_);
auto node_ptr = operator_impl_->GetNode();
if (node_ptr != nullptr) {
// For inner compute graph
auto op_desc = node_ptr->GetOpDesc();
if (op_desc == nullptr) {
GELOGE(GRAPH_FAILED, "op_desc is nullptr.");
return GRAPH_FAILED;
}
GE_CHECK_NOTNULL(op_desc);
auto index = op_desc->GetInputIndexByName(dst_name);
auto in_data_anchor = node_ptr->GetInDataAnchor(index);
if (in_data_anchor == nullptr) {
GELOGE(GRAPH_FAILED, "in_data_anchor is nullptr.");
return GRAPH_FAILED;
}
GE_CHECK_NOTNULL(in_data_anchor);
auto out_data_anchor = in_data_anchor->GetPeerOutAnchor();
if (out_data_anchor == nullptr) {
GELOGE(GRAPH_FAILED, "out_data_anchor is nullptr.");
return GRAPH_FAILED;
}
std::shared_ptr<Node> peer_node_ptr = out_data_anchor->GetOwnerNode();
if (peer_node_ptr == nullptr) {
GELOGE(GRAPH_FAILED, "peer_node_ptr is nullptr.");
return GRAPH_FAILED;
}
ge::OperatorImplPtr operator_impl_ptr = nullptr;
operator_impl_ptr = ComGraphMakeShared<OperatorImpl>(peer_node_ptr);
if (operator_impl_ptr == nullptr) {
GELOGE(GRAPH_FAILED, "OperatorImpl make shared failed");
return GRAPH_FAILED;
}
Operator const_op(std::move(operator_impl_ptr));
if (peer_node_ptr->GetOpDesc() != nullptr) {
const auto &op_descType = peer_node_ptr->GetOpDesc()->GetType();
if (op_descType == CONSTANTOP) {
return const_op.GetAttr(op::Constant::name_attr_value(), data);
} else if (op_descType == CONSTANT) {
return const_op.GetAttr(op::Const::name_attr_value(), data);
GE_CHECK_NOTNULL(out_data_anchor);
auto peer_node = out_data_anchor->GetOwnerNode();
GE_CHECK_NOTNULL(peer_node);
auto peer_op_desc = peer_node->GetOpDesc();
GE_CHECK_NOTNULL(peer_op_desc);
auto peer_op_type = peer_op_desc->GetType();
if (peer_op_type == CONSTANTOP || peer_op_type == CONSTANT) {
auto const_op_impl = ComGraphMakeShared<OperatorImpl>(peer_node);
GE_CHECK_NOTNULL(const_op_impl);
Operator const_op(std::move(const_op_impl));
return const_op.GetAttr(ATTR_NAME_WEIGHTS, data);
} else if (peer_op_type == DATA) {
auto parent_node = NodeUtils::GetParentInput(peer_node);
while ((parent_node != nullptr) && (parent_node->GetType() == DATA)) {
parent_node = NodeUtils::GetParentInput(parent_node);
}
if ((parent_node != nullptr) &&
((parent_node->GetType() == CONSTANT) || (parent_node->GetType() == CONSTANTOP))) {
auto const_op_impl = ComGraphMakeShared<OperatorImpl>(parent_node);
GE_CHECK_NOTNULL(const_op_impl);
Operator const_op(std::move(const_op_impl));
return const_op.GetAttr(ATTR_NAME_WEIGHTS, data);
}
}

// Try get from runtime inference context
auto session_id = std::to_string(GetContext().SessionId());
RuntimeInferenceContext *runtime_infer_ctx = nullptr;
if (RuntimeInferenceContext::GetContext(session_id, &runtime_infer_ctx) == GRAPH_SUCCESS) {
GELOGD("To get constant from runtime inference context. session_id = %s", session_id.c_str());
auto ret = runtime_infer_ctx->GetTensor(peer_node_ptr->GetOpDesc()->GetId(), out_data_anchor->GetIdx(), data);
auto ret = runtime_infer_ctx->GetTensor(peer_node->GetOpDesc()->GetId(), out_data_anchor->GetIdx(), data);
if (ret == GRAPH_SUCCESS) {
return GRAPH_SUCCESS;
}
@@ -604,6 +595,8 @@ graphStatus Operator::GetInputConstData(const string &dst_name, Tensor &data) co
// For outer graph
return GetInputConstDataOut(dst_name, data);
}
auto op_name = operator_impl_->GetName();
GELOGW("node[%s]'s input[%s]'s peer node is not const", op_name.c_str(), dst_name.c_str());
return GRAPH_FAILED;
}
graphStatus Operator::GetInputConstDataOut(const string &dst_name, Tensor &data) const {


+ 2
- 0
src/common/graph/option/ge_context.cc View File

@@ -85,6 +85,8 @@ uint32_t GEContext::DeviceId() { return device_id_; }

uint64_t GEContext::TraceId() { return trace_id_; }

void GEContext::SetSessionId(uint64_t session_id) { session_id_ = session_id; }

void GEContext::SetCtxDeviceId(uint32_t device_id) { device_id_ = device_id; }

} // namespace ge

+ 4
- 0
src/common/graph/ref_relation.cc View File

@@ -242,6 +242,10 @@ void RefRelations::Impl::GetDataAndNetoutputOfSubGraph(const ge::ComputeGraph &r
int sub_graph_idx = 0;
for (const auto &name : sub_graph_names) {
auto sub_graph = root_graph.GetSubgraph(name);
if (sub_graph == nullptr) {
GELOGW("Can not find the sub graph %s for root graph %s.", name.c_str(), root_graph.GetName().c_str());
continue;
}
for (const auto &sub_graph_node : sub_graph->GetDirectNode()) {
auto sub_graph_node_type = sub_graph_node->GetType();



+ 156
- 16
src/common/graph/shape_refiner.cc View File

@@ -37,6 +37,115 @@

namespace ge {
namespace {
const uint32_t kWhileBodySubGraphIdx = 1;

graphStatus ReverseBrushWhileBodySubGraph(const ConstNodePtr &node) {
GELOGD("Enter reverse brush while body subgraph process!");

auto sub_graph_body = NodeUtils::GetSubgraph(*node, kWhileBodySubGraphIdx);
if (sub_graph_body == nullptr) {
GELOGE(GRAPH_FAILED, "Get while body graph failed!");
return GRAPH_FAILED;
}

for (const auto &node_sub : sub_graph_body->GetAllNodes()) {
if (node_sub->GetInDataNodes().size() == 0) {
continue;
}

for (size_t i = 0; i < node_sub->GetAllInDataAnchorsSize(); i++) {
auto input_desc = node_sub->GetOpDesc()->MutableInputDesc(i);
(void)input_desc->SetUnknownDimNumShape();
}
for (size_t i = 0; i < node_sub->GetAllOutDataAnchorsSize(); i++) {
auto output_desc = node_sub->GetOpDesc()->MutableOutputDesc(i);
(void)output_desc->SetUnknownDimNumShape();
}
}

return GRAPH_SUCCESS;
}

graphStatus UpdateParentNodeForBranch(const ConstNodePtr &node,
std::vector<std::vector<GeTensorDesc>> &ref_out_tensors) {
GELOGD("Enter update parent node shape for class branch op process");
// check sub_graph shape.If not same ,do unknown shape process
for (size_t i = 0; i < ref_out_tensors.size(); i++) {
if (ref_out_tensors[i].empty()) {
continue;
}
auto ref_out_tensor = ref_out_tensors[i].at(0);
ge::GeShape &ref_out_tensor_shape = ref_out_tensor.MutableShape();
for (auto &tensor : ref_out_tensors[i]) {
if (ref_out_tensor.GetDataType() != tensor.GetDataType()) {
GELOGE(GRAPH_FAILED, "node[%s] does not support diff dtype output", node->GetName().c_str());
return GRAPH_FAILED;
}
auto shape = tensor.MutableShape();
if (shape.GetDims().size() != ref_out_tensor_shape.GetDims().size()) {
GELOGD("node is %s, i : %d, shape size: %lu, ref_out_tensor_shape size: %lu", node->GetName().c_str(), i,
shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize());
ref_out_tensor_shape = GeShape(UNKNOWN_RANK);
break;
}
for (size_t j = 0; j < ref_out_tensor_shape.GetDims().size(); j++) {
if (ref_out_tensor_shape.GetDim(j) == shape.GetDim(j)) {
continue;
}
GELOGD("node is %s, i : %d, j: %d ,shape size: %lu, ref_out_tensor_shape size: %lu", node->GetName().c_str(), i,
j, shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize());
(void)ref_out_tensor_shape.SetDim(j, UNKNOWN_DIM);
}
}
(void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensor);
}
return GRAPH_SUCCESS;
}

graphStatus UpdateParentNodeForWhile(const ConstNodePtr &node, std::vector<std::vector<GeTensorDesc>> &ref_data_tensors,
std::vector<std::vector<GeTensorDesc>> &ref_out_tensors) {
GELOGD("Enter update parent node shape for class while op process");
if (ref_data_tensors.size() != ref_out_tensors.size()) {
GELOGE(GRAPH_FAILED, "while op [%s] input number[%zu] and output number[%zu] is not same!", node->GetName().c_str(),
ref_data_tensors.size(), ref_out_tensors.size());
return GRAPH_FAILED;
}
for (size_t i = 0; i < ref_data_tensors.size(); i++) {
if (ref_out_tensors[i].size() != 1) {
GELOGE(GRAPH_FAILED, "while op, every output should only find one output tensor in all graph!");
return GRAPH_FAILED;
}
}
bool is_need_reverse_brush = false;
// check input and output
for (size_t i = 0; i < ref_out_tensors.size(); i++) {
if (ref_out_tensors[i].empty()) {
continue;
}
auto ref_out_tensor = ref_out_tensors[i].at(0);
auto tmp_shape = ref_out_tensor.MutableShape();
// ref_i's data and output tensor shape should be same
for (auto &tensor : ref_data_tensors[i]) {
if (ref_out_tensor.GetDataType() != tensor.GetDataType()) {
GELOGE(GRAPH_FAILED, "node[%s] does not support diff dtype or format output.", node->GetName().c_str());
return GRAPH_FAILED;
}
auto shape = tensor.MutableShape();
if (shape.GetDims() != tmp_shape.GetDims()) {
ref_out_tensor.SetUnknownDimNumShape();
is_need_reverse_brush = true;
break;
}
}
(void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensor);
}
// reverse refresh while body shape
if (is_need_reverse_brush) {
return ReverseBrushWhileBodySubGraph(node);
}
return GRAPH_SUCCESS;
}

graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) {
auto op_desc = node->GetOpDesc();
auto sub_graph_names = op_desc->GetSubgraphInstanceNames();
@@ -98,6 +207,37 @@ graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) {
}
return GRAPH_SUCCESS;
}

graphStatus FindSubgraphDataAndNetoutput(std::shared_ptr<ComputeGraph> &sub_graph, NodePtr &netoutput,
const ConstNodePtr &node,
std::vector<std::vector<GeTensorDesc>> &ref_data_tensors) {
auto sub_nodes = sub_graph->GetDirectNode();
for (size_t i = sub_nodes.size(); i > 0; --i) {
auto sub_node = sub_nodes.at(i - 1);
if (sub_node->GetType() == NETOUTPUT) {
netoutput = sub_node;
}
if (sub_node->GetType() == DATA) {
if (sub_node->GetOpDesc() == nullptr) {
return GRAPH_FAILED;
}

int ref_i;
if (!AttrUtils::GetInt(sub_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, ref_i)) {
GELOGE(GRAPH_FAILED, "subgraph data node[%s] has no parent node!", sub_node->GetName().c_str());
return GRAPH_FAILED;
}
if (ref_i < 0 || static_cast<uint32_t>(ref_i) >= node->GetAllInDataAnchorsSize()) {
GELOGE(GRAPH_FAILED, "data node[%s]'s ref index[%d] is not in range [0, %zu)!", sub_node->GetName().c_str(),
ref_i, node->GetAllInDataAnchorsSize());
return GRAPH_FAILED;
}
ref_data_tensors[ref_i].emplace_back(sub_node->GetOpDesc()->GetOutputDesc(0));
}
}
return GRAPH_SUCCESS;
}

graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) {
auto op_desc = node->GetOpDesc();
auto sub_graph_names = op_desc->GetSubgraphInstanceNames();
@@ -105,7 +245,10 @@ graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) {
return GRAPH_SUCCESS;
}

std::vector<std::vector<GeTensorDesc>> ref_data_tensors(node->GetAllInDataAnchorsSize());
std::vector<std::vector<GeTensorDesc>> ref_out_tensors(node->GetAllOutDataAnchorsSize());
auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph());

for (const auto &name : sub_graph_names) {
if (name.empty()) {
GELOGW("The node %s contains empty subgraph instance name", node->GetName().c_str());
@@ -117,13 +260,9 @@ graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) {
return GRAPH_FAILED;
}
NodePtr netoutput = nullptr;
auto sub_nodes = sub_graph->GetDirectNode();
for (size_t i = sub_nodes.size(); i > 0; --i) {
auto sub_node = sub_nodes.at(i - 1);
if (sub_node->GetType() == NETOUTPUT) {
netoutput = sub_node;
break;
}
auto ret = FindSubgraphDataAndNetoutput(sub_graph, netoutput, node, ref_data_tensors);
if (ret != GRAPH_SUCCESS) {
return ret;
}
if (netoutput == nullptr) {
GE_LOGE("No NetOutput node on sub graph %s, parent node %s", name.c_str(), node->GetName().c_str());
@@ -150,19 +289,17 @@ graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) {
continue;
}
GELOGI("Parent node index of edge desc is %d", ref_i);
auto output_desc = op_desc->MutableOutputDesc(static_cast<uint32_t>(ref_i));
if (output_desc == nullptr) {
GE_LOGE(
"The ref index(%d) on the input %d of netoutput %s on the sub graph %s "
"parent node %s are incompatible, outputs num %u",
ref_i, edge_anchor->GetIdx(), netoutput->GetName().c_str(), name.c_str(), node->GetName().c_str(),
node->GetAllOutDataAnchorsSize());
if (ref_i < 0 || static_cast<uint32_t>(ref_i) >= node->GetAllOutDataAnchorsSize()) {
return GRAPH_FAILED;
}
op_desc->UpdateOutputDesc(edge_anchor->GetIdx(), *edge_desc);
ref_out_tensors[ref_i].emplace_back(*edge_desc);
}
}
return GRAPH_SUCCESS;

if (node->GetType() == WHILE) {
return UpdateParentNodeForWhile(node, ref_data_tensors, ref_out_tensors);
}
return UpdateParentNodeForBranch(node, ref_out_tensors);
}
} // namespace
void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::string &phase) {
@@ -170,6 +307,9 @@ void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::str
GELOGE(GRAPH_FAILED, "node is null");
return;
}
if (!IsLogEnable(GE, DLOG_DEBUG)) {
return;
}
ge::OpDescPtr op_desc = node->GetOpDesc();
GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return );
std::string str;


+ 5
- 5
src/common/graph/utils/ge_ir_utils.h View File

@@ -1,18 +1,18 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
*/

#ifndef COMMON_GRAPH_UTILS_GE_IR_UTILS_H_
#define COMMON_GRAPH_UTILS_GE_IR_UTILS_H_


+ 55
- 24
src/common/graph/utils/graph_utils.cc View File

@@ -38,6 +38,7 @@
#include "utils/ge_ir_utils.h"
#include "utils/node_utils.h"
#include "debug/ge_op_types.h"
#include "external/ge/ge_api_types.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/utils/op_desc_utils.h"
#include "graph/utils/tensor_utils.h"
@@ -410,8 +411,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::InsertTra
/// @return graphStatus
///
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
GraphUtils::InsertNodeBefore(const OutDataAnchorPtr &src, const std::vector<InDataAnchorPtr> &dsts,
const NodePtr &insert_node, uint32_t input_index, uint32_t output_index) {
GraphUtils::InsertNodeAfter(const OutDataAnchorPtr &src, const std::vector<InDataAnchorPtr> &dsts,
const NodePtr &insert_node, uint32_t input_index, uint32_t output_index) {
GE_CHECK_NOTNULL(src);
GE_CHECK_NOTNULL(insert_node);

@@ -570,7 +571,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGraph(cons
static int max_dumpfile_num = 0;
if (max_dumpfile_num == 0) {
string opt = "0";
(void)GetContext().GetOption("ge.maxDumpFileNum", opt);
(void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_NUM, opt);
max_dumpfile_num = std::strtol(opt.c_str(), nullptr, kBaseOfIntegerValue);
}
if (max_dumpfile_num != 0 && file_idx > max_dumpfile_num) {
@@ -670,7 +671,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::WriteProtoToText
if (maxDumpFileSize == 0) {
string opt = "0";
// Can not check return value
(void)GetContext().GetOption("ge.maxDumpFileSize", opt);
(void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_SIZE, opt);
maxDumpFileSize = atol(opt.c_str());
}
if (maxDumpFileSize != 0 && fileSize != -1 && fileSize > maxDumpFileSize) {
@@ -740,7 +741,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGraphToOnn
static int max_dumpfile_num = 0;
if (max_dumpfile_num == 0) {
string opt = "0";
(void)GetContext().GetOption("ge.maxDumpFileNum", opt);
(void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_NUM, opt);
max_dumpfile_num = std::strtol(opt.c_str(), nullptr, kBaseOfIntegerValue);
}
if (max_dumpfile_num != 0 && file_index > max_dumpfile_num) {
@@ -920,7 +921,7 @@ graphStatus RelinkDataIO(const NodePtr &node, const std::vector<int> &io_map, In
InNodesToOut GetFullConnectIONodes(const NodePtr &node) {
InNodesToOut in_nodes_to_out;
if (node == nullptr) {
GELOGE(GRAPH_FAILED, "Node is nullptr,node is %s", node->GetName().c_str());
GELOGE(GRAPH_FAILED, "Node is nullptr");
return in_nodes_to_out;
}
auto in_nodes_list = node->GetInNodes();
@@ -1308,6 +1309,36 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::MoveOutCt
return GRAPH_SUCCESS;
}

///
/// Copy all in-data edges from `src_node` to `dst_node`.
/// @param src_node
/// @param dst_node
/// @return
///
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::CopyInDataEdges(const NodePtr &src_node,
NodePtr &dst_node) {
if ((src_node == nullptr) || (dst_node == nullptr)) {
GELOGE(GRAPH_FAILED, "Parameter is nullptr");
return GRAPH_PARAM_INVALID;
}
auto src_data_in_nodes = src_node->GetInDataNodes();
if (src_data_in_nodes.empty()) {
return GRAPH_SUCCESS;
}
for (const auto &in_data_anchor : src_node->GetAllInDataAnchors()) {
auto input_desc = src_node->GetOpDesc()->GetInputDesc(in_data_anchor->GetIdx());
auto ret =
GraphUtils::AddEdge(in_data_anchor->GetPeerOutAnchor(), dst_node->GetInDataAnchor(in_data_anchor->GetIdx()));
if (ret != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "Failed to add data edge from %s to %s when copy in data edge from %s to %s",
in_data_anchor->GetPeerOutAnchor()->GetOwnerNode()->GetName().c_str(), dst_node->GetName().c_str(),
src_node->GetName().c_str(), dst_node->GetName().c_str());
return ret;
}
}
return GRAPH_SUCCESS;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::AppendInputNode(const ComputeGraphPtr &graph,
const NodePtr &node) {
if (graph->AddInputNode(node) == nullptr) {
@@ -1339,7 +1370,7 @@ graphStatus GraphUtils::GetRefMapping(const ComputeGraphPtr &graph,
std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors,
std::map<std::string, std::string> &anchor_to_symbol) {
GE_CHECK_NOTNULL(graph);
for (auto &node : graph->GetAllNodes()) {
for (const auto &node : graph->GetAllNodes()) {
// in_data_anchor
if (HandleInAnchorMapping(node, symbol_to_anchors, anchor_to_symbol) != GRAPH_SUCCESS) {
GE_LOGE("Find ref_mapping for in_data_anchors of node %s failed.", node->GetName().c_str());
@@ -1396,16 +1427,16 @@ graphStatus GraphUtils::HandleInAnchorMapping(const NodePtr &node,
return HandleSubgraphInput(node, symbol_to_anchors, anchor_to_symbol);
}

std::string type = node->GetType();
const std::string &type = node->GetType();
if ((type == MERGE) || (type == STREAMMERGE)) {
return HandleMergeInput(node, symbol_to_anchors, anchor_to_symbol);
}

for (auto &in_data_anchor : node->GetAllInDataAnchors()) {
for (const auto &in_data_anchor : node->GetAllInDataAnchors()) {
NodeIndexIO cur_node_info(node, in_data_anchor->GetIdx(), kIn);
OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
if (peer_out_anchor == nullptr) {
std::string symbol = cur_node_info.ToString();
const std::string &symbol = cur_node_info.ToString();
GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str());
symbol_to_anchors[symbol] = {cur_node_info};
anchor_to_symbol[symbol] = symbol;
@@ -1432,7 +1463,7 @@ graphStatus GraphUtils::HandleOutAnchorMapping(const NodePtr &node,
std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors,
std::map<std::string, std::string> &anchor_to_symbol) {
GE_CHECK_NOTNULL(node);
for (auto &out_data_anchor : node->GetAllOutDataAnchors()) {
for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) {
NodeIndexIO cur_node_info(node, out_data_anchor->GetIdx(), kOut);
if (anchor_to_symbol.find(cur_node_info.ToString()) != anchor_to_symbol.end()) {
continue;
@@ -1446,7 +1477,7 @@ graphStatus GraphUtils::HandleOutAnchorMapping(const NodePtr &node,
return GRAPH_FAILED;
}
} else {
std::string symbol = cur_node_info.ToString();
const std::string &symbol = cur_node_info.ToString();
GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str());
symbol_to_anchors.emplace(std::make_pair(symbol, std::list<NodeIndexIO>{cur_node_info}));
anchor_to_symbol.emplace(std::make_pair(symbol, symbol));
@@ -1506,7 +1537,7 @@ graphStatus GraphUtils::HandleMergeInput(const NodePtr &node,
GE_CHECK_NOTNULL(node);
std::vector<NodeIndexIO> exist_node_infos;
std::vector<NodeIndexIO> cur_node_infos;
for (auto &in_data_anchor : node->GetAllInDataAnchors()) {
for (const auto &in_data_anchor : node->GetAllInDataAnchors()) {
auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
if (peer_out_anchor == nullptr) {
std::string next_name;
@@ -1529,10 +1560,10 @@ graphStatus GraphUtils::HandleMergeInput(const NodePtr &node,

size_t anchor_nums = 0;
NodeIndexIO max_node_index_io(nullptr, 0, kOut);
for (auto &temp_node_info : exist_node_infos) {
for (const auto &temp_node_info : exist_node_infos) {
auto iter1 = anchor_to_symbol.find(temp_node_info.ToString());
if (iter1 != anchor_to_symbol.end()) {
std::string temp_symbol = iter1->second;
const std::string &temp_symbol = iter1->second;
auto iter2 = symbol_to_anchors.find(temp_symbol);
if (iter2 != symbol_to_anchors.end()) {
if (iter2->second.size() > anchor_nums) {
@@ -1544,7 +1575,7 @@ graphStatus GraphUtils::HandleMergeInput(const NodePtr &node,
}

std::string symbol;
for (auto &temp_node_info : exist_node_infos) {
for (const auto &temp_node_info : exist_node_infos) {
if ((UnionSymbolMapping(max_node_index_io, temp_node_info, symbol_to_anchors, anchor_to_symbol, symbol) !=
GRAPH_SUCCESS) ||
symbol.empty()) {
@@ -1556,7 +1587,7 @@ graphStatus GraphUtils::HandleMergeInput(const NodePtr &node,

auto iter = symbol_to_anchors.find(symbol);
if (iter != symbol_to_anchors.end()) {
for (auto &temp_node_info : cur_node_infos) {
for (const auto &temp_node_info : cur_node_infos) {
GELOGD("Add anchor %s, symbol %s.", temp_node_info.ToString().c_str(), symbol.c_str());
iter->second.emplace_back(temp_node_info);
anchor_to_symbol.emplace(std::make_pair(temp_node_info.ToString(), symbol));
@@ -1584,7 +1615,7 @@ graphStatus GraphUtils::HandleSubgraphOutput(const NodePtr &node,

OpDescPtr op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
for (auto &in_data_anchor : node->GetAllInDataAnchors()) {
for (const auto &in_data_anchor : node->GetAllInDataAnchors()) {
OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
GE_CHECK_NOTNULL(peer_out_anchor);

@@ -1627,8 +1658,8 @@ graphStatus GraphUtils::HandleSubgraphOutput(const NodePtr &node,
graphStatus GraphUtils::UnionSymbolMapping(const NodeIndexIO &exist_node_info1, const NodeIndexIO &exist_node_info2,
std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors,
std::map<std::string, std::string> &anchor_to_symbol, std::string &symbol) {
std::string symbol1 = anchor_to_symbol[exist_node_info1.ToString()];
std::string symbol2 = anchor_to_symbol[exist_node_info2.ToString()];
const std::string &symbol1 = anchor_to_symbol[exist_node_info1.ToString()];
const std::string &symbol2 = anchor_to_symbol[exist_node_info2.ToString()];
if (symbol1 == symbol2) {
symbol = symbol1;
GELOGI("no need to union.");
@@ -1684,7 +1715,7 @@ graphStatus GraphUtils::UpdateRefMapping(const NodeIndexIO &cur_node_info, const
return GRAPH_FAILED;
}

std::string symbol = iter1->second;
const std::string &symbol = iter1->second;
auto iter2 = symbol_to_anchors.find(symbol);
if (iter2 == symbol_to_anchors.end()) {
GE_LOGE("symbol %s not found.", symbol.c_str());
@@ -1712,7 +1743,7 @@ bool GraphUtils::IsRefFromInput(const OutDataAnchorPtr &out_data_anchor, int32_t

// pass-through op
NodePtr node = out_data_anchor->GetOwnerNode();
std::string type = node->GetType();
const std::string &type = node->GetType();
const std::set<std::string> pass_through_set = {NETOUTPUT, WHILE, _WHILE, STATELESSWHILE};
if ((pass_through_set.count(type) > 0) || (NodeUtils::IsSubgraphInput(node))) {
reuse_in_index = output_index;
@@ -1755,7 +1786,7 @@ bool GraphUtils::IsRefFromInput(const OutDataAnchorPtr &out_data_anchor, int32_t
uint32_t reuse_input_index = 0;
if (TensorUtils::GetReuseInputIndex(*output_op_desc, reuse_input_index) == GRAPH_SUCCESS) {
reuse_in_index = static_cast<int32_t>(reuse_input_index);
GELOGI("ReuseInput name[%s] output[%u] reuse input[%d].", op_desc->GetName().c_str(), output_index,
GELOGI("ReuseInput name[%s] output[%d] reuse input[%d].", op_desc->GetName().c_str(), output_index,
reuse_in_index);
return true;
}
@@ -2297,7 +2328,7 @@ void CompleteGraphBuilder::AddRetValNodes(graphStatus &error_code, std::string &
return;
}

std::string name = node->GetName() + "_RetVal";
std::string name = node->GetName() + "_RetVal_" + std::to_string(index);
OpDescPtr ret_val_desc = shared_ptr<OpDesc>(new (std::nothrow) OpDesc(name, FRAMEWORKOP));
if (ret_val_desc == nullptr) {
error_code = GRAPH_FAILED;


+ 153
- 18
src/common/graph/utils/node_utils.cc View File

@@ -296,15 +296,18 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::UpdatePeer
return GRAPH_FAILED;
}
for (const auto &out_anchor : node_ptr->GetAllOutDataAnchors()) {
GeTensorDesc output_tensor = op_desc->GetOutputDesc(out_anchor->GetIdx());
ge::TensorUtils::SetRealDimCnt(output_tensor, static_cast<uint32_t>(output_tensor.GetShape().GetDims().size()));
output_tensor.SetOriginShape(output_tensor.GetShape());
output_tensor.SetOriginDataType(output_tensor.GetDataType());
auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx());
ge::TensorUtils::SetRealDimCnt(*output_tensor, static_cast<uint32_t>(output_tensor->GetShape().GetDims().size()));
bool is_unknown_graph = node_ptr->GetOwnerComputeGraph()->GetGraphUnknownFlag();
if (!is_unknown_graph) {
output_tensor->SetOriginShape(output_tensor->GetShape());
output_tensor->SetOriginDataType(output_tensor->GetDataType());
}
GELOGD("node name is %s, origin shape is %ld, origin format is %s, origin data type is %s",
node_ptr->GetName().c_str(), output_tensor.GetOriginShape().GetShapeSize(),
TypeUtils::FormatToSerialString(output_tensor.GetOriginFormat()).c_str(),
TypeUtils::DataTypeToSerialString(output_tensor.GetOriginDataType()).c_str());
(void)op_desc->UpdateOutputDesc(out_anchor->GetIdx(), output_tensor);
node_ptr->GetName().c_str(), output_tensor->GetOriginShape().GetShapeSize(),
TypeUtils::FormatToSerialString(output_tensor->GetOriginFormat()).c_str(),
TypeUtils::DataTypeToSerialString(output_tensor->GetOriginDataType()).c_str());
for (const auto &peer_anchor : out_anchor->GetPeerInDataAnchors()) {
if (peer_anchor->GetOwnerNode()->GetOpDesc() == nullptr) {
GELOGE(GRAPH_FAILED, "peer_anchor opdesc is null");
@@ -316,17 +319,17 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::UpdatePeer
continue;
}
GELOGI("Peer input opdesc name is %s, need to flush: shape size is %zu, datatype is %d, original datatype is %d",
peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), output_tensor.GetShape().GetDimNum(),
output_tensor.GetDataType(), output_tensor.GetOriginDataType());
peer_input_desc->SetShape(output_tensor.GetShape());
peer_input_desc->SetOriginShape(output_tensor.GetOriginShape());
peer_input_desc->SetDataType(output_tensor.GetDataType());
peer_input_desc->SetOriginDataType(output_tensor.GetOriginDataType());
peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), output_tensor->GetShape().GetDimNum(),
output_tensor->GetDataType(), output_tensor->GetOriginDataType());
peer_input_desc->SetShape(output_tensor->GetShape());
peer_input_desc->SetOriginShape(output_tensor->GetOriginShape());
peer_input_desc->SetDataType(output_tensor->GetDataType());
peer_input_desc->SetOriginDataType(output_tensor->GetOriginDataType());
std::vector<std::pair<int64_t, int64_t>> shape_range;
(void)output_tensor.GetShapeRange(shape_range);
(void)output_tensor->GetShapeRange(shape_range);
peer_input_desc->SetShapeRange(shape_range);
ge::TensorUtils::SetRealDimCnt(*peer_input_desc,
static_cast<uint32_t>(output_tensor.GetShape().GetDims().size()));
static_cast<uint32_t>(output_tensor->GetShape().GetDims().size()));
GELOGI("Peer input opdesc name is %s, shape size is %zu, datatype is %d, original datatype is %d",
peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), peer_input_desc->GetShape().GetDimNum(),
peer_input_desc->GetDataType(), peer_input_desc->GetOriginDataType());
@@ -401,10 +404,13 @@ graphStatus NodeUtils::UpdateInputShape(const Node &node, uint32_t index, const
graphStatus NodeUtils::GetNodeUnknownShapeStatus(const Node &node, bool &is_unknow) {
auto desc = node.GetOpDesc();
GE_CHECK_NOTNULL(desc);

// check self
is_unknow = OpShapeIsUnknown(desc);
if (is_unknow) {
return GRAPH_SUCCESS;
}
auto sub_graph_names = desc->GetSubgraphInstanceNames();
if (sub_graph_names.empty()) {
is_unknow = OpShapeIsUnknown(desc);
return GRAPH_SUCCESS;
} else {
auto owner_graph = node.GetOwnerComputeGraph();
@@ -555,6 +561,53 @@ NodePtr NodeUtils::GetParentInput(const NodePtr &node) {
return peer_out_anchor->GetOwnerNode();
}

///
/// @brief Check is varying_input for while node
/// @param [in] node: Data node for subgraph
/// @return bool
///
bool NodeUtils::IsWhileVaryingInput(const ge::NodePtr &node) {
if (node == nullptr) {
return false;
}
if (node->GetType() != DATA) {
return false; // not input_node for subgraph
}

const NodePtr &parent_node = node->GetOwnerComputeGraph()->GetParentNode();
if (parent_node == nullptr) {
return false; // root graph
}

if (kWhileOpTypes.count(parent_node->GetType()) == 0) {
return false; // not input_node for while subgraph
}

uint32_t index_i = 0;
if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, index_i)) {
GELOGW("Node %s has no attr PARENT_NODE_INDEX.", node->GetName().c_str());
return false;
}
bool varying_flag = true;
for (const auto &item : node->GetOutDataNodesAndAnchors()) {
if (item.first->GetType() != NETOUTPUT) {
continue;
}
OpDescPtr op_desc = item.first->GetOpDesc();
uint32_t index_o = 0;
if ((op_desc == nullptr) ||
!AttrUtils::GetInt(op_desc->GetInputDesc(item.second->GetIdx()), ATTR_NAME_PARENT_NODE_INDEX, index_o)) {
continue; // input for while-cond subgraph
}
if (index_i != index_o) {
continue; // varying input for while-body subgraph
}
varying_flag = false;
break;
}
return varying_flag;
}

///
/// @brief Get subgraph input is constant.
/// @param [in] node
@@ -637,4 +690,86 @@ Status NodeUtils::RemoveSubgraphsOnNode(const NodePtr &node) {

return GRAPH_SUCCESS;
}
///
/// @brief Get subgraph input data node by index.
/// @param [in] node
/// @return Node
///
vector<NodePtr> NodeUtils::GetSubgraphDataNodesByIndex(const Node &node, int index) {
vector<NodePtr> in_data_node_vec;
auto op_desc = node.GetOpDesc();
GE_CHECK_NOTNULL_EXEC(op_desc, return in_data_node_vec);
auto subgraph_names = op_desc->GetSubgraphInstanceNames();
if (subgraph_names.empty()) {
GELOGW("Node %s is single node without sub graph.", node.GetName().c_str());
return in_data_node_vec;
}
auto compute_graph = node.GetOwnerComputeGraph();
for (const std::string &instance_name : subgraph_names) {
auto subgraph = compute_graph->GetSubgraph(instance_name);
for (const auto &node_in_subgraph : subgraph->GetDirectNode()) {
int parent_index = 0;
if (NodeUtils::IsSubgraphInput(node_in_subgraph)) {
(void)AttrUtils::GetInt(node_in_subgraph->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index);
if (parent_index == index) {
in_data_node_vec.emplace_back(node_in_subgraph);
}
}
}
}
return in_data_node_vec;
}
///
/// @brief Get subgraph input data node by index.
/// @param [in] node
/// @return Node
///
vector<NodePtr> NodeUtils::GetSubgraphOutputNodes(const Node &node) {
vector<NodePtr> out_data_node_vec;
auto op_desc = node.GetOpDesc();
GE_CHECK_NOTNULL_EXEC(op_desc, return out_data_node_vec);
auto subgraph_names = op_desc->GetSubgraphInstanceNames();
if (subgraph_names.empty()) {
GELOGI("Node %s is single node without sub graph.", node.GetName().c_str());
return out_data_node_vec;
}
auto compute_graph = node.GetOwnerComputeGraph();
for (const std::string &instance_name : subgraph_names) {
auto subgraph = compute_graph->GetSubgraph(instance_name);
for (const auto &node_in_subgraph : subgraph->GetDirectNode()) {
if (NodeUtils::IsSubgraphOutput(node_in_subgraph)) {
out_data_node_vec.emplace_back(node_in_subgraph);
}
}
}
return out_data_node_vec;
}

NodePtr NodeUtils::GetInDataNodeByIndex(const Node &node, int index) {
if (node.GetInDataAnchor(index) == nullptr) {
return nullptr;
}
if (node.GetInDataAnchor(index)->GetPeerOutAnchor() == nullptr) {
return nullptr;
}
return node.GetInDataAnchor(index)->GetPeerOutAnchor()->GetOwnerNode();
}

vector<NodePtr> NodeUtils::GetOutDataNodesByIndex(const Node &node, int index) {
vector<NodePtr> out_data_nodes;
auto out_data_anchor = node.GetOutDataAnchor(index);
if (out_data_anchor == nullptr) {
return out_data_nodes;
}
for (const auto peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) {
if (peer_in_anchor == nullptr) {
continue;
}
if (peer_in_anchor->GetOwnerNode() == nullptr) {
continue;
}
out_data_nodes.emplace_back(peer_in_anchor->GetOwnerNode());
}
return out_data_nodes;
}
} // namespace ge

+ 46
- 20
src/common/graph/utils/op_desc_utils.cc View File

@@ -197,24 +197,33 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::NodePtr> OpDescUtils::
continue;
}
auto in_node = out_anchor->GetOwnerNode();
if ((in_node->GetType() == CONSTANT) || (in_node->GetType() == CONSTANTOP)) {
ret.push_back(in_node);
} else if (in_node->GetType() == DATA) {
const ComputeGraphPtr &graph = node.GetOwnerComputeGraph();
GE_CHK_BOOL_EXEC(graph != nullptr, continue, "Owner graph is null");

const NodePtr &parent_node = graph->GetParentNode();
if (parent_node == nullptr) {
continue; // Root graph.
}

if (kWhileOpTypes.count(parent_node->GetType()) > 0) {
continue; // Subgraph of While cond or body.
while (true) {
if (in_node == nullptr) {
break;
}

NodePtr input_node = NodeUtils::GetParentInput(in_node);
if ((input_node != nullptr) && ((input_node->GetType() == CONSTANT) || (input_node->GetType() == CONSTANTOP))) {
ret.push_back(input_node);
if ((in_node->GetType() == CONSTANT) || (in_node->GetType() == CONSTANTOP)) {
ret.push_back(in_node);
break;
} else if (in_node->GetType() == DATA) {
if (NodeUtils::IsWhileVaryingInput(in_node)) {
break;
}
in_node = NodeUtils::GetParentInput(in_node);
} else if ((in_node->GetType() == ENTER) || (in_node->GetType() == REFENTER)) {
bool is_constant = false;
(void)AttrUtils::GetBool(in_node->GetOpDesc(), ENTER_ATTR_CONSTANT_FLAG, is_constant);
if (!is_constant) {
break;
}
// Enter node has and only has one input
if (in_node->GetInDataNodes().size() != 1) {
GELOGW("Check number of input_nodes for Enter node %s failed, size=%zu.", node.GetName().c_str(),
in_node->GetInDataNodes().size());
break;
}
in_node = in_node->GetInDataNodes().at(0);
} else {
break;
}
}
}
@@ -435,10 +444,27 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::NodePtr> OpDescUtils::

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<GeTensorPtr> OpDescUtils::MutableWeights(const ge::Node &node) {
vector<GeTensorPtr> ret;
GE_CHK_BOOL_EXEC(node.GetOpDesc() != nullptr, return ret, "node.GetOpDesc is nullptr!");
auto op_desc = node.GetOpDesc();
GE_CHK_BOOL_EXEC(op_desc != nullptr, return ret, "op_desc is nullptr!");
// Place holder operator, try to get the weight from parent node
// when parent node is const operator
if (node.GetType() == PLACEHOLDER) {
std::string parent_op;
(void)AttrUtils::GetStr(op_desc, "parentOpType", parent_op);
// This if judgment is necessary because the current subgraph optimization is multithreaded
// and the parent node of the PLD operation should be a stable type, such as const
if (parent_op == CONSTANT || parent_op == CONSTANTOP) {
NodePtr parent_node = nullptr;
parent_node = op_desc->TryGetExtAttr("parentNode", parent_node);
if (parent_node != nullptr) {
op_desc = parent_node->GetOpDesc();
GELOGD("pld[%s] get weight from const[%s]", node.GetName().c_str(), op_desc->GetName().c_str());
}
}
}
// Const operator, take the weight directly
if (node.GetOpDesc()->GetType() == CONSTANT || (node.GetOpDesc()->GetType() == CONSTANTOP)) {
auto weight = MutableWeights(node.GetOpDesc());
if (op_desc->GetType() == CONSTANT || (op_desc->GetType() == CONSTANTOP)) {
auto weight = MutableWeights(op_desc);
if (weight == nullptr) {
GELOGI("const op has no weight, op name:%s", node.GetName().c_str());
return ret;


+ 6
- 2
src/common/graph/utils/tensor_utils.cc View File

@@ -19,6 +19,7 @@

#include "debug/ge_log.h"
#include "framework/common/debug/ge_log.h"
#include "common/util/error_manager/error_manager.h"
#include "graph/ge_tensor.h"
#include "graph/types.h"
#include "graph/utils/type_utils.h"
@@ -105,7 +106,10 @@ static graphStatus CalcElementCntByDims(const std::vector<int64_t> &dims, int64_
element_cnt = 1;
for (int64_t dim : dims) {
if (CheckMultiplyOverflowInt64(element_cnt, dim)) {
GELOGE(GRAPH_FAILED, "CalcElementCntByDims failed, as when multiplying %ld and %ld.", element_cnt, dim);
ErrorManager::GetInstance().ATCReportErrMessage(
"E19013", {"function", "var1", "var2"},
{"CheckMultiplyOverflowInt64", std::to_string(element_cnt), std::to_string(dim)});
GELOGE(GRAPH_FAILED, "CalcElementCntByDims failed, when multiplying %ld and %ld.", element_cnt, dim);
return GRAPH_FAILED;
}
element_cnt *= dim;
@@ -273,7 +277,6 @@ static graphStatus CalcTensorElementCnt(const std::vector<int64_t> &dims, Format
case FORMAT_FRACTAL_Z:
graph_status = CalcElementCntOfFractalZ(dims, data_type, element_cnt);
break;
case FORMAT_NC1HWC0_C04:
case FORMAT_FRACTAL_NZ:
case FORMAT_FRACTAL_ZZ:
case FORMAT_NDHWC:
@@ -285,6 +288,7 @@ static graphStatus CalcTensorElementCnt(const std::vector<int64_t> &dims, Format
case FORMAT_NDC1HWC0:
case FORMAT_FRACTAL_Z_C04:
case FORMAT_FRACTAL_ZN_LSTM:
case FORMAT_NC1HWC0_C04:
graph_status = CalcElementCntByDims(dims, element_cnt);
break;
default:


+ 2
- 1
src/common/graph/utils/type_utils.cc View File

@@ -147,7 +147,8 @@ static const std::map<std::string, Format> kStringToFormatMap = {
{"FRACTAL_ZN_LSTM", FORMAT_FRACTAL_ZN_LSTM},
{"FRACTAL_Z_G", FORMAT_FRACTAL_Z_G},
{"FORMAT_RESERVED", FORMAT_RESERVED},
{"ALL", FORMAT_ALL}};
{"ALL", FORMAT_ALL},
{"NULL", FORMAT_NULL}};

static const std::map<DataType, std::string> kDataTypeToStringMap = {
{DT_UNDEFINED, "DT_UNDEFINED"}, // Used to indicate a DataType field has not been set.


+ 5
- 2
src/ge/CMakeLists.txt View File

@@ -60,6 +60,7 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR}
"common/formats/formats.cc"
"common/formats/utils/formats_trans_utils.cc"
"common/fp16_t.cc"
"common/ge/op_tiling_manager.cc"
"common/ge/plugin_manager.cc"
"common/helper/model_cache_helper.cc"
"common/profiling/profiling_manager.cc"
@@ -94,7 +95,6 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR}
"graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc"
"graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc"
"graph/load/new_model_manager/task_info/task_info.cc"
"graph/load/output/output.cc"
"graph/manager/*.cc"
"graph/manager/model_manager/event_manager.cc"
"graph/manager/util/debug.cc"
@@ -159,8 +159,11 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR}
"hybrid/node_executor/aicpu/aicpu_ext_info.cc"
"hybrid/node_executor/aicpu/aicpu_node_executor.cc"
"hybrid/node_executor/compiledsubgraph/known_node_executor.cc"
"hybrid/node_executor/controlop/control_op_executor.cc"
"hybrid/node_executor/hccl/hccl_node_executor.cc"
"hybrid/node_executor/hostcpu/ge_local_node_executor.cc"
"hybrid/node_executor/node_executor.cc"
"hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc"
"hybrid/node_executor/task_context.cc"
"init/gelib.cc"
"model/ge_model.cc"
@@ -204,6 +207,7 @@ file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR}
"common/formats/formats.cc"
"common/formats/utils/formats_trans_utils.cc"
"common/fp16_t.cc"
"common/ge/op_tiling_manager.cc"
"common/ge/plugin_manager.cc"
"common/helper/model_cache_helper.cc"
"common/profiling/profiling_manager.cc"
@@ -236,7 +240,6 @@ file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR}
"graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc"
"graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc"
"graph/load/new_model_manager/task_info/task_info.cc"
"graph/load/output/output.cc"
"graph/manager/*.cc"
"graph/manager/model_manager/event_manager.cc"
"graph/manager/util/debug.cc"


+ 14
- 43
src/ge/client/ge_api.cc View File

@@ -28,6 +28,7 @@
#include "graph/opsproto_manager.h"
#include "graph/utils/type_utils.h"
#include "graph/manager/util/rt_context_util.h"
#include "graph/common/ge_call_wrapper.h"
#include "register/op_registry.h"
#include "common/ge/tbe_plugin_manager.h"

@@ -41,8 +42,8 @@ namespace {
const int32_t kMaxStrLen = 128;
}

static bool kGeInitialized = false;
static std::mutex kGeReleaseMutex; // GEFinalize and ~Session use
static bool g_ge_initialized = false;
static std::mutex g_ge_release_mutex; // GEFinalize and ~Session use

namespace ge {
void GetOpsProtoPath(std::string &opsproto_path) {
@@ -61,31 +62,6 @@ void GetOpsProtoPath(std::string &opsproto_path) {
opsproto_path = (path_base + "ops/op_proto/custom/" + ":") + (path_base + "ops/op_proto/built-in/");
}

Status CheckDumpAndReuseMemory(const std::map<string, string> &options) {
const int kDecimal = 10;
auto dump_op_env = std::getenv("DUMP_OP");
int dump_op_flag = (dump_op_env != nullptr) ? std::strtol(dump_op_env, nullptr, kDecimal) : 0;
auto disableReuseMemoryIter = options.find("ge.exec.disableReuseMemory");
if (disableReuseMemoryIter != options.end()) {
if (disableReuseMemoryIter->second == "0") {
GELOGD("ge.exec.disableReuseMemory=0, reuse memory is open");
if (dump_op_flag) {
GELOGW("Will dump incorrect op data with GE Option ge.exec.disableReuseMemory=0");
}
} else if (disableReuseMemoryIter->second == "1") {
GELOGD("ge.exec.disableReuseMemory=1, reuse memory is close");
} else {
GELOGE(PARAM_INVALID, "CheckDumpAndReuseMemory ge.exec.disableReuseMemory is valid");
return FAILED;
}
} else {
if (dump_op_flag) {
GELOGW("Will dump incorrect op data with default reuse memory");
}
}
return SUCCESS;
}

Status CheckOptionsValid(const std::map<string, string> &options) {
// check job_id is valid
auto job_id_iter = options.find(OPTION_EXEC_JOB_ID);
@@ -96,11 +72,6 @@ Status CheckOptionsValid(const std::map<string, string> &options) {
}
}

// Check ge.exec.disableReuseMemory and env DUMP_OP
if (CheckDumpAndReuseMemory(options) != SUCCESS) {
return FAILED;
}

return SUCCESS;
}

@@ -108,7 +79,7 @@ Status CheckOptionsValid(const std::map<string, string> &options) {
Status GEInitialize(const std::map<string, string> &options) {
GELOGT(TRACE_INIT, "GEInitialize start");
// 0.check init status
if (kGeInitialized) {
if (g_ge_initialized) {
GELOGW("GEInitialize is called more than once");
return SUCCESS;
}
@@ -147,9 +118,9 @@ Status GEInitialize(const std::map<string, string> &options) {
}

// 7.check return status, return
if (!kGeInitialized) {
if (!g_ge_initialized) {
// Initialize success, first time calling initialize
kGeInitialized = true;
g_ge_initialized = true;
}

GELOGT(TRACE_STOP, "GEInitialize finished");
@@ -160,12 +131,12 @@ Status GEInitialize(const std::map<string, string> &options) {
Status GEFinalize() {
GELOGT(TRACE_INIT, "GEFinalize start");
// check init status
if (!kGeInitialized) {
if (!g_ge_initialized) {
GELOGW("GEFinalize is called before GEInitialize");
return SUCCESS;
}

std::lock_guard<std::mutex> lock(kGeReleaseMutex);
std::lock_guard<std::mutex> lock(g_ge_release_mutex);
// call Finalize
Status ret = SUCCESS;
Status middle_ret;
@@ -187,10 +158,10 @@ Status GEFinalize() {
ret = middle_ret;
}

if (kGeInitialized && ret == SUCCESS) {
if (g_ge_initialized && ret == SUCCESS) {
// Unified destruct rt_context
RtContextUtil::GetInstance().DestroyrtContexts();
kGeInitialized = false;
RtContextUtil::GetInstance().DestroyAllRtContexts();
g_ge_initialized = false;
}

GELOGT(TRACE_STOP, "GEFinalize finished");
@@ -202,7 +173,7 @@ Session::Session(const std::map<string, string> &options) {
GELOGT(TRACE_INIT, "Session Constructor start");
// check init status
sessionId_ = 0;
if (!kGeInitialized) {
if (!g_ge_initialized) {
GELOGE(GE_CLI_GE_NOT_INITIALIZED);
return;
}
@@ -232,13 +203,13 @@ Session::Session(const std::map<string, string> &options) {
Session::~Session() {
GELOGT(TRACE_INIT, "Session Destructor start");
// 0.check init status
if (!kGeInitialized) {
if (!g_ge_initialized) {
GELOGW("GE is not yet initialized or is finalized.");
return;
}

Status ret = FAILED;
std::lock_guard<std::mutex> lock(kGeReleaseMutex);
std::lock_guard<std::mutex> lock(g_ge_release_mutex);
try {
uint64_t session_id = sessionId_;
// call DestroySession


+ 6
- 5
src/ge/common/convert/pb2json.cc View File

@@ -72,9 +72,6 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void Pb2Json::Message2Json(cons
void Pb2Json::OneField2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field,
const ProtobufReflection *reflection, const set<string> &black_fields, Json &json,
bool enum2str) {
if (field == nullptr || reflection == nullptr) {
return;
}
switch (field->type()) {
case ProtobufFieldDescriptor::TYPE_MESSAGE: {
const ProtobufMsg &tmp_message = reflection->GetMessage(message, field);
@@ -118,8 +115,12 @@ void Pb2Json::OneField2Json(const ProtobufMsg &message, const ProtobufFieldDescr

case ProtobufFieldDescriptor::TYPE_FLOAT:
char str[kSignificantDigits];
sprintf_s(str, kSignificantDigits, "%g", reflection->GetFloat(message, field));
json[field->name()] = str;
if (sprintf_s(str, kSignificantDigits, "%g", reflection->GetFloat(message, field)) != -1) {
json[field->name()] = str;
} else {
json[field->name()] = reflection->GetFloat(message, field);
}

break;

case ProtobufFieldDescriptor::TYPE_STRING:


+ 0
- 1
src/ge/common/formats/format_transfers/datatype_transfer.cc View File

@@ -29,7 +29,6 @@

namespace ge {
namespace formats {

namespace {
enum DataTypeTransMode {
kTransferWithDatatypeFloatToFloat16,


+ 0
- 1
src/ge/common/formats/format_transfers/datatype_transfer.h View File

@@ -27,7 +27,6 @@

namespace ge {
namespace formats {

struct CastArgs {
const uint8_t *data;
size_t src_data_size;


+ 0
- 1
src/ge/common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc View File

@@ -179,6 +179,5 @@ Status FormatTransferDhwcnFractalZ3D::TransShape(Format src_format, const std::v
}

REGISTER_FORMAT_TRANSFER(FormatTransferDhwcnFractalZ3D, FORMAT_DHWCN, FORMAT_FRACTAL_Z_3D)

} // namespace formats
} // namespace ge

+ 0
- 1
src/ge/common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc View File

@@ -180,6 +180,5 @@ Status FormatTransferDhwncFractalZ3DTranspose::TransShape(Format src_format, con
}

REGISTER_FORMAT_TRANSFER(FormatTransferDhwncFractalZ3DTranspose, FORMAT_DHWNC, FORMAT_FRACTAL_Z_3D_TRANSPOSE)

} // namespace formats
} // namespace ge

+ 1
- 1
src/ge/common/formats/format_transfers/format_transfer_fractal_nz.cc View File

@@ -56,7 +56,7 @@ Status TransShapeToFracNz(const ShapeVector &src_shape, DataType data_type, Shap
dst_shape.clear();
hw_shape.clear();
auto w0 = GetCubeSizeByDataType(data_type);
auto h0 = GetCubeSizeByDataType(data_type);
int64_t h0 = kCubeSize;
switch (src_shape.size()) {
case 1:
dst_shape.push_back(Ceil(src_shape[0], w0));


+ 55
- 40
src/ge/common/formats/format_transfers/format_transfer_fractal_z.cc View File

@@ -19,6 +19,7 @@
#include <securec.h>
#include <memory>

#include "common/debug/log.h"
#include "common/formats/utils/formats_definitions.h"
#include "common/formats/utils/formats_trans_utils.h"
#include "framework/common/debug/ge_log.h"
@@ -107,8 +108,8 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) {

int64_t hw = h * w;
int64_t chw = c * hw;
int64_t hwc0 = hw * c0;
int64_t nchw = n * chw;
int64_t hwc0 = hw * c0;

// horizontal fractal matrix count (N)
int64_t hf_cnt = Ceil(n, static_cast<int64_t>(kNiSize));
@@ -119,18 +120,15 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) {
int64_t total_ele_cnt = hf_cnt * vf_cnt * fractal_ele_cnt;
int size = GetSizeByDataType(args.src_data_type);
int64_t dst_size = total_ele_cnt * size;
if (dst_size == 0) {
result.length = static_cast<size_t>(dst_size);
return SUCCESS;
}
GE_CHK_BOOL_EXEC_NOLOG(dst_size != 0, result.length = static_cast<size_t>(dst_size); return SUCCESS;);

std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>());
if (dst == nullptr) {
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
dst == nullptr,
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld",
TypeUtils::FormatToSerialString(args.src_format).c_str(),
TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size);
return OUT_OF_MEMORY;
}
return OUT_OF_MEMORY;);

for (int64_t vfi = 0; vfi < vf_cnt; vfi++) {
// vertical fractal matrix base index
@@ -156,12 +154,20 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) {
auto protected_size = dst_size - offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN)
? dst_size - offset
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN);
errno_t ret;
errno_t ret = EOK;
if (need_pad_zero) {
ret = memset_s(dst.get() + offset, static_cast<size_t>(protected_size), 0, static_cast<size_t>(size));
} else {
ret = memcpy_s(dst.get() + offset, static_cast<size_t>(protected_size), args.data + src_offset * size,
static_cast<size_t>(size));
if (protected_size < size) {
GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory, protected_size is %ld and size is %ld",
protected_size, size);
return INTERNAL_ERROR;
}
char *dst_data = reinterpret_cast<char *>(dst.get() + offset);
const char *src_data = reinterpret_cast<const char *>(args.data + src_offset * size);
for (int64_t index = 0; index < size; index++) {
*dst_data++ = *src_data++;
}
}
if (ret != EOK) {
GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory at offset %ld, error-code %d pad mode %d", offset,
@@ -199,18 +205,15 @@ Status TransFormatHwcnToFz(const TransArgs &args, TransResult &result) {
dst_size *= dim;
}
dst_size *= data_size;
if (dst_size == 0) {
result.length = static_cast<size_t>(dst_size);
return SUCCESS;
}
GE_CHK_BOOL_EXEC_NOLOG(dst_size != 0, result.length = static_cast<size_t>(dst_size); return SUCCESS;);

std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>());
if (dst == nullptr) {
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
dst == nullptr,
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld",
TypeUtils::FormatToSerialString(args.src_format).c_str(),
TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size);
return OUT_OF_MEMORY;
}
return OUT_OF_MEMORY;);

for (int64_t c1i = 0; c1i < c1; c1i++) {
for (int64_t hi = 0; hi < h; hi++) {
@@ -223,14 +226,22 @@ Status TransFormatHwcnToFz(const TransArgs &args, TransResult &result) {
? dst_size - dst_offset
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN);
auto pad_zero = ((c1i * c0 + c0i) >= c) || (n1n0i >= n);
errno_t ret;
errno_t ret = EOK;
if (pad_zero) {
ret = memset_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), 0,
static_cast<size_t>(data_size));
} else {
if (protected_size < data_size) {
GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory, protected_size is %ld and size is %ld",
protected_size, data_size);
return INTERNAL_ERROR;
}
int64_t src_idx = hi * wcn + wi * cn + (c1i * c0 + c0i) * n + n1n0i;
ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size),
args.data + src_idx * data_size, static_cast<size_t>(data_size));
char *dst_data = reinterpret_cast<char *>(dst.get() + dst_offset);
const char *src_data = reinterpret_cast<const char *>(args.data + src_idx * data_size);
for (int64_t index = 0; index < data_size; index++) {
*dst_data++ = *src_data++;
}
}
if (ret != EOK) {
GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory at offset %ld, error-code %d, pad mode %d",
@@ -269,18 +280,15 @@ Status TransFormatNhwcToFz(const TransArgs &args, TransResult &result) {
dst_size *= dim;
}
dst_size *= data_size;
if (dst_size == 0) {
result.length = static_cast<size_t>(dst_size);
return SUCCESS;
}
GE_CHK_BOOL_EXEC_NOLOG(dst_size != 0, result.length = static_cast<size_t>(dst_size); return SUCCESS;);

std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>());
if (dst == nullptr) {
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
dst == nullptr,
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld",
TypeUtils::FormatToSerialString(args.src_format).c_str(),
TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size);
return OUT_OF_MEMORY;
}
return OUT_OF_MEMORY;);

for (int64_t c1i = 0; c1i < c1; c1i++) {
for (int64_t hi = 0; hi < h; hi++) {
@@ -293,14 +301,22 @@ Status TransFormatNhwcToFz(const TransArgs &args, TransResult &result) {
? dst_size - dst_offset
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN);
auto pad_zero = ((c1i * c0 + c0i) >= c) || (n1n0i >= n);
errno_t ret;
errno_t ret = EOK;
if (pad_zero) {
ret = memset_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), 0,
static_cast<size_t>(data_size));
} else {
if (protected_size < data_size) {
GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory, protected_size is %ld and size is %ld",
protected_size, data_size);
return INTERNAL_ERROR;
}
int64_t src_idx = n1n0i * hwc + hi * wc + wi * c + (c1i * c0 + c0i);
ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size),
args.data + src_idx * data_size, static_cast<size_t>(data_size));
char *dst_data = reinterpret_cast<char *>(dst.get() + dst_offset);
const char *src_data = reinterpret_cast<const char *>(args.data + src_idx * data_size);
for (int64_t index = 0; index < data_size; index++) {
*dst_data++ = *src_data++;
}
}
if (ret != EOK) {
GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory at offset %ld, error-code %d, pad mode %d",
@@ -337,16 +353,16 @@ Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &r
return PARAM_INVALID;
}

if (args.src_format == FORMAT_NCHW && args.dst_format == FORMAT_FRACTAL_Z) {
return TransFormatFromNchwToFz(args, result);
if (args.src_format == FORMAT_NHWC && args.dst_format == FORMAT_FRACTAL_Z) {
return TransFormatNhwcToFz(args, result);
}

if (args.src_format == FORMAT_HWCN && args.dst_format == FORMAT_FRACTAL_Z) {
return TransFormatHwcnToFz(args, result);
}

if (args.src_format == FORMAT_NHWC && args.dst_format == FORMAT_FRACTAL_Z) {
return TransFormatNhwcToFz(args, result);
if (args.src_format == FORMAT_NCHW && args.dst_format == FORMAT_FRACTAL_Z) {
return TransFormatFromNchwToFz(args, result);
}

return UNSUPPORTED;
@@ -358,14 +374,14 @@ Status FormatTransferFractalZ::TransShape(Format src_format, const std::vector<i
return UNSUPPORTED;
}

if (src_format == FORMAT_NCHW && dst_format == FORMAT_FRACTAL_Z) {
return TransShapeNchwToFz(src_shape, data_type, dst_shape);
if (src_format == FORMAT_NHWC && dst_format == FORMAT_FRACTAL_Z) {
return TransShapeNhwcToFz(src_shape, data_type, dst_shape);
}
if (src_format == FORMAT_HWCN && dst_format == FORMAT_FRACTAL_Z) {
return TransShapeHwcnToFz(src_shape, data_type, dst_shape);
}
if (src_format == FORMAT_NHWC && dst_format == FORMAT_FRACTAL_Z) {
return TransShapeNhwcToFz(src_shape, data_type, dst_shape);
if (src_format == FORMAT_NCHW && dst_format == FORMAT_FRACTAL_Z) {
return TransShapeNchwToFz(src_shape, data_type, dst_shape);
}

return UNSUPPORTED;
@@ -374,6 +390,5 @@ Status FormatTransferFractalZ::TransShape(Format src_format, const std::vector<i
REGISTER_FORMAT_TRANSFER(FormatTransferFractalZ, FORMAT_NCHW, FORMAT_FRACTAL_Z)
REGISTER_FORMAT_TRANSFER(FormatTransferFractalZ, FORMAT_HWCN, FORMAT_FRACTAL_Z)
REGISTER_FORMAT_TRANSFER(FormatTransferFractalZ, FORMAT_NHWC, FORMAT_FRACTAL_Z)

} // namespace formats
} // namespace ge

+ 1
- 3
src/ge/common/formats/format_transfers/format_transfer_nchw_fz_c04.cc View File

@@ -39,7 +39,6 @@
namespace ge {
namespace formats {
namespace {

constexpr int64_t kMaxDimsNumC = 4;

Status CheckDataTypeSupport(DataType data_type) { return GetSizeByDataType(data_type) > 0 ? SUCCESS : UNSUPPORTED; }
@@ -109,7 +108,7 @@ Status TransFormatFromNchwToFzC04(const TransArgs &args, TransResult &result) {
return NOT_CHANGED;
}

/* prepare for padding in chw*/
// prepare for padding in chw
int64_t tmp = h * w * c;
int64_t n_o = Ceil(n, static_cast<int64_t>(c0));
int64_t c_o = c0;
@@ -309,6 +308,5 @@ Status FormatTransferNchwToFZC04::TransShape(Format src_format, const std::vecto
}

REGISTER_FORMAT_TRANSFER(FormatTransferNchwToFZC04, FORMAT_NCHW, FORMAT_FRACTAL_Z_C04)

} // namespace formats
} // namespace ge

+ 0
- 2
src/ge/common/formats/utils/formats_definitions.h View File

@@ -19,7 +19,6 @@

namespace ge {
namespace formats {

static const int kCubeSize = 16;
static const int kNiSize = 16;
static const int64_t kShapeItemNumMAX = 1024UL * 1024UL * 1024UL * 1024UL;
@@ -47,7 +46,6 @@ enum FracZDimIndex { kFracZHWC1, kFracZN0, kFracZNi, kFracZC0, kFracZDimsNum };
enum DhwcnDimIndex { kDhwcnD, kDhwcnH, kDhwcnW, kDhwcnC, kDhwcnN, kDhwcnDimsNum };

enum DhwncDimIndex { kDhwncD, kDhwncH, kDhwncW, kDhwncN, kDhwncC, kDhwncDimsNum };

} // namespace formats
} // namespace ge
#endif // GE_COMMON_FORMATS_UTILS_FORMATS_DEFINITIONS_H_

+ 0
- 1
src/ge/common/formats/utils/formats_trans_utils.h View File

@@ -69,7 +69,6 @@ T Ceil(T n1, T n2) {
}
return (n2 != 0) ? (n1 - 1) / n2 + 1 : 0;
}

} // namespace formats
} // namespace ge
#endif // GE_COMMON_FORMATS_UTILS_FORMATS_TRANS_UTILS_H_

+ 1
- 1
src/ge/common/fp16_t.h View File

@@ -600,5 +600,5 @@ int16_t GetManBitLength(T man) {
}
return len;
}
}; // namespace ge
} // namespace ge
#endif // GE_COMMON_FP16_T_H_

+ 81
- 0
src/ge/common/ge/op_tiling_manager.cc View File

@@ -0,0 +1,81 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "common/ge/op_tiling_manager.h"
#include "framework/common/debug/log.h"
#include <string>

namespace {
const char *const kEnvName = "ASCEND_OPP_PATH";
const std::string kDefaultPath = "/usr/local/Ascend/opp";
const std::string kDefaultBuiltInTilingPath = "/op_impl/built-in/liboptiling.so";
const std::string kDefaultCustomTilingPath = "/op_impl/custom/liboptiling.so";
const uint8_t kPrefixIndex = 9;
} // namespace

namespace ge {
void OpTilingManager::ClearHandles() noexcept {
for (const auto &handle : handles_) {
if (dlclose(handle.second) != 0) {
GELOGE(FAILED, "Failed to close handle of %s: %s", handle.first.c_str(), dlerror());
}
}
handles_.clear();
}

OpTilingManager::~OpTilingManager() { ClearHandles(); }

std::string OpTilingManager::GetPath() {
const char *opp_path_env = std::getenv(kEnvName);
std::string opp_path = kDefaultPath;
if (opp_path_env != nullptr) {
char resolved_path[PATH_MAX];
if (realpath(opp_path_env, resolved_path) == NULL) {
GELOGE(PARAM_INVALID, "Failed load tiling lib as env 'ASCEND_OPP_PATH'(%s) is invalid path.", opp_path_env);
return std::string();
}
opp_path = resolved_path;
}
return opp_path;
}

void OpTilingManager::LoadSo() {
std::string opp_path = GetPath();
if (opp_path.empty()) {
GELOGW("Skip load tiling lib.");
return;
}
std::string built_in_tiling_lib = opp_path + kDefaultBuiltInTilingPath;
std::string custom_tiling_lib = opp_path + kDefaultCustomTilingPath;
std::string built_in_name = kDefaultBuiltInTilingPath.substr(kPrefixIndex);
std::string custom_name = kDefaultCustomTilingPath.substr(kPrefixIndex);

void *handle_bi = dlopen(built_in_tiling_lib.c_str(), RTLD_NOW | RTLD_GLOBAL);
if (handle_bi == nullptr) {
GELOGW("Failed to dlopen %s!", dlerror());
} else {
handles_[built_in_name] = handle_bi;
}

void *handle_ct = dlopen(custom_tiling_lib.c_str(), RTLD_NOW | RTLD_GLOBAL);
if (handle_ct == nullptr) {
GELOGW("Failed to dlopen %s!", dlerror());
} else {
handles_[custom_name] = handle_ct;
}
}

} // namespace ge

src/ge/graph/passes/identify_reference_pass.h → src/ge/common/ge/op_tiling_manager.h View File

@@ -14,16 +14,25 @@
* limitations under the License.
*/

#ifndef GE_GRAPH_PASSES_IDENTIFY_REFERENCE_PASS_H_
#define GE_GRAPH_PASSES_IDENTIFY_REFERENCE_PASS_H_
#ifndef GE_COMMON_GE_OP_TILING_MANAGER_H_
#define GE_COMMON_GE_OP_TILING_MANAGER_H_

#include "graph/passes/base_pass.h"
#include <map>

namespace ge {
class IdentifyReferencePass : public BaseNodePass {
using SoToHandleMap = std::map<std::string, void *>;

class OpTilingManager {
public:
Status Run(NodePtr &node) override;
OpTilingManager() = default;
~OpTilingManager();
void LoadSo();

private:
static std::string GetPath();
void ClearHandles() noexcept;
SoToHandleMap handles_;
};
} // namespace ge

#endif // GE_GRAPH_PASSES_IDENTIFY_REFERENCE_PASS_H_
#endif // GE_COMMON_GE_OP_TILING_MANAGER_H_

+ 2
- 101
src/ge/common/helper/model_helper.cc View File

@@ -17,6 +17,7 @@
#include "framework/common/helper/model_helper.h"

#include "common/ge/ge_util.h"
#include "common/util/error_manager/error_manager.h"
#include "framework/common/debug/log.h"
#include "framework/common/util.h"
#include "framework/common/debug/ge_log.h"
@@ -267,6 +268,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::LoadModel(c
}
auto partition_table = reinterpret_cast<ModelPartitionTable *>(model_addr_tmp_);
if (partition_table->num == kOriginalOmPartitionNum) {
model_addr_tmp_ = nullptr;
GELOGE(FAILED, "om model is error,please use executable om model");
return FAILED;
}
@@ -390,107 +392,6 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY GeModelPtr ModelHelper::GetGeMo
return out_model;
}

// Transit func for model to ge_model. It will be removed when load and build support ge_model in future
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::TransModelToGeModel(const ModelPtr &model,
GeModelPtr &ge_model) {
if (model == nullptr) {
GELOGE(FAILED, "Model is null");
return FAILED;
}
ge_model = ge::MakeShared<ge::GeModel>();
GE_CHECK_NOTNULL(ge_model);
ge_model->SetGraph(model->GetGraph());
ge_model->SetName(model->GetName());
ge_model->SetVersion(model->GetVersion());
ge_model->SetPlatformVersion(model->GetPlatformVersion());
ge_model->SetAttr(model->MutableAttrMap());

// Copy weight info
auto compute_graph = ge::GraphUtils::GetComputeGraph(model->GetGraph());
// ge::Buffer weight;
ge::Buffer weight;
(void)ge::AttrUtils::GetZeroCopyBytes(compute_graph, ge::ATTR_NAME_WEIGHTS_DATA, weight);
ge_model->SetWeight(weight);
// Copy task info
if (model->HasAttr(MODEL_ATTR_TASKS)) {
ge::Buffer task_buffer;
GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::GetZeroCopyBytes(model, MODEL_ATTR_TASKS, task_buffer), FAILED,
"Get bytes failed.");

std::shared_ptr<ModelTaskDef> task = ge::MakeShared<ModelTaskDef>();
GE_CHECK_NOTNULL(task);
GE_IF_BOOL_EXEC(task_buffer.GetData() == nullptr, GELOGE(FAILED, "Get data fail"); return FAILED);
GE_IF_BOOL_EXEC(task_buffer.GetSize() == 0, GELOGE(FAILED, "Get size fail"); return FAILED);

GE_CHK_BOOL_EXEC(ReadProtoFromArray(task_buffer.GetData(), static_cast<int>(task_buffer.GetSize()), task.get()),
return INTERNAL_ERROR, "ReadProtoFromArray failed.");

ge_model->SetModelTaskDef(task);
}
// Copy tbe kernel info
// TBEKernelStore kernel_store;
TBEKernelStore kernel_store;
if (compute_graph != nullptr && compute_graph->GetDirectNodesSize() != 0) {
for (const ge::NodePtr &n : compute_graph->GetDirectNode()) {
auto node_op_desc = n->GetOpDesc();
GE_IF_BOOL_EXEC(node_op_desc == nullptr, continue);
TBEKernelPtr tbe_kernel = node_op_desc->TryGetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, TBEKernelPtr());
GE_IF_BOOL_EXEC(tbe_kernel == nullptr, continue);
kernel_store.AddTBEKernel(tbe_kernel);
GELOGI("Add tbe kernel bin %s", tbe_kernel->GetName().c_str());
}
}
if (!kernel_store.Build()) {
GELOGE(FAILED, "TBE Kernels store build failed!");
return FAILED;
}
ge_model->SetTBEKernelStore(kernel_store);

return SUCCESS;
}

// trasit func for ge_model to Model. will be removed when load and build support ge_model in future
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::TransGeModelToModel(const GeModelPtr &ge_model,
ModelPtr &model) {
if (ge_model == nullptr) {
GELOGE(FAILED, "Ge_model is null");
return FAILED;
}
model = ge::MakeShared<ge::Model>();
GE_CHECK_NOTNULL(model);
model->SetGraph(ge_model->GetGraph());
model->SetName(ge_model->GetName());
model->SetVersion(ge_model->GetVersion());
model->SetPlatformVersion(ge_model->GetPlatformVersion());
model->SetAttr(ge_model->MutableAttrMap());
// Copy weight info
auto compute_graph = ge::GraphUtils::GetComputeGraph(model->GetGraph());
bool ret = ge::AttrUtils::SetZeroCopyBytes(compute_graph, ge::ATTR_NAME_WEIGHTS_DATA, ge_model->GetWeight());
if (!ret) {
GELOGE(FAILED, "Copy weight buffer failed!");
return FAILED;
}
// Copy task info
std::shared_ptr<ModelTaskDef> model_task = ge_model->GetModelTaskDefPtr();

if (model_task != nullptr) {
int size = model_task->ByteSize();
ge::Buffer buffer(static_cast<size_t>(size));
if (buffer.GetSize() == 0) {
GELOGE(MEMALLOC_FAILED, "alloc model attr task buffer failed!");
return MEMALLOC_FAILED;
}
// no need to check value
(void)model_task->SerializePartialToArray(buffer.GetData(), size);
ret = ge::AttrUtils::SetZeroCopyBytes(model, MODEL_ATTR_TASKS, std::move(buffer));
if (!ret) {
GELOGE(FAILED, "Copy task buffer failed!");
return FAILED;
}
}
return SUCCESS;
}

Status ModelHelper::ReleaseLocalModelData() noexcept {
Status result = SUCCESS;
if (model_addr_tmp_ != nullptr) {


+ 1
- 1
src/ge/common/math/fp16_math.h View File

@@ -92,5 +92,5 @@ fp16_t max(fp16_t fp1, fp16_t fp2);
/// @brief Calculate the minimum fp16_t of fp1 and fp2
/// @return Returns minimum fp16_t of fp1 and fp2
fp16_t min(fp16_t fp1, fp16_t fp2);
}; // namespace ge
} // namespace ge
#endif // GE_COMMON_MATH_FP16_MATH_H_

+ 0
- 2
src/ge/common/math_util.h View File

@@ -27,7 +27,6 @@
#include "mmpa/mmpa_api.h"

namespace ge {

/**
* @ingroup domi_calibration
* @brief Initializes an input array to a specified value
@@ -67,7 +66,6 @@ Status NnSet(const int32_t n, const Dtype alpha, Dtype *output) {
}
return SUCCESS;
}

} // end namespace ge

#endif // GE_COMMON_MATH_UTIL_H_

+ 3
- 4
src/ge/common/model_saver.cc View File

@@ -60,8 +60,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelSaver::SaveJsonToFi
mode_t mode = S_IRUSR | S_IWUSR;
int32_t fd = mmOpen2(real_path, O_RDWR | O_CREAT | O_TRUNC, mode);
if (fd == EN_ERROR || fd == EN_INVALID_PARAM) {
ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"filepath", "errMsg"}, {file_path, strerror(errno)});
GELOGE(FAILED, "Open file failed. file path : %s, %s", file_path, strerror(errno));
ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"file", "errmsg"}, {file_path, strerror(errno)});
GELOGE(FAILED, "Open file[%s] failed. %s", file_path, strerror(errno));
return FAILED;
}
const char *model_char = model_str.c_str();
@@ -69,8 +69,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelSaver::SaveJsonToFi
// Write data to file
mmSsize_t mmpa_ret = mmWrite(fd, const_cast<void *>((const void *)model_char), len);
if (mmpa_ret == EN_ERROR || mmpa_ret == EN_INVALID_PARAM) {
ErrorManager::GetInstance().ATCReportErrMessage("E19003", {"mmpa_ret", "errMsg"},
{std::to_string(mmpa_ret), strerror(errno)});
ErrorManager::GetInstance().ATCReportErrMessage("E19004", {"file", "errmsg"}, {file_path, strerror(errno)});
// Need to both print the error info of mmWrite and mmClose, so return ret after mmClose
GELOGE(FAILED, "Write to file failed. errno = %d, %s", mmpa_ret, strerror(errno));
ret = FAILED;


+ 13
- 7
src/ge/common/profiling/profiling_manager.cc View File

@@ -336,16 +336,17 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::Profilin

std::string data;
for (const auto &task : task_desc_info) {
std::string model_name = task.model_name;
std::string op_name = task.op_name;
uint32_t block_dim = task.block_dim;
uint32_t task_id = task.task_id;
uint32_t stream_id = task.stream_id;
data = op_name.append(" ").append(std::to_string(block_dim)
.append(" ")
.append(std::to_string(task_id))
.append(" ")
.append(std::to_string(stream_id))
.append("\n"));
data = model_name.append(" ").append(op_name).append(" ").append(std::to_string(block_dim)
.append(" ")
.append(std::to_string(task_id))
.append(" ")
.append(std::to_string(stream_id))
.append("\n"));

Msprof::Engine::ReporterData reporter_data{};
reporter_data.deviceId = device_id;
@@ -376,7 +377,12 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::Profilin

std::string data;
for (const auto &graph : compute_graph_desc_info) {
data.append("op_name:").append(graph.op_name).append(" op_type:").append(graph.op_type);
data.append("model_name:")
.append(graph.model_name)
.append(" op_name:")
.append(graph.op_name)
.append(" op_type:")
.append(graph.op_type);
for (size_t i = 0; i < graph.input_format.size(); ++i) {
data.append(" input_id:")
.append(std::to_string(i))


+ 200
- 120
src/ge/common/properties_manager.cc View File

@@ -20,15 +20,204 @@
#include <cstdio>
#include <fstream>

#include "common/ge/ge_util.h"
#include "common/util.h"
#include "framework/common/debug/ge_log.h"
#include "framework/common/debug/log.h"
#include "framework/common/ge_types.h"
#include "framework/common/types.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/ge_context.h"
#include "graph/utils/attr_utils.h"

namespace ge {
namespace {
const string kEnableFlag = "1";

const uint32_t kAicoreOverflow = (0x1 << 0);
const uint32_t kAtomicOverflow = (0x1 << 1);
const uint32_t kAllOverflow = (kAicoreOverflow | kAtomicOverflow);
} // namespace

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY DumpProperties::DumpProperties(const DumpProperties &other) {
CopyFrom(other);
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY DumpProperties &DumpProperties::operator=(
const DumpProperties &other) {
CopyFrom(other);
return *this;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::InitByOptions() {
enable_dump_.clear();
enable_dump_debug_.clear();
dump_path_.clear();
dump_step_.clear();
dump_mode_.clear();
is_op_debug_ = false;
op_debug_mode_ = 0;

string enable_dump;
(void)GetContext().GetOption(OPTION_EXEC_ENABLE_DUMP, enable_dump);
enable_dump_ = enable_dump;

string enable_dump_debug;
(void)GetContext().GetOption(OPTION_EXEC_ENABLE_DUMP_DEBUG, enable_dump_debug);
enable_dump_debug_ = enable_dump_debug;

if ((enable_dump_ == kEnableFlag) || (enable_dump_debug_ == kEnableFlag)) {
string dump_path;
if (GetContext().GetOption(OPTION_EXEC_DUMP_PATH, dump_path) == GRAPH_SUCCESS) {
if (!dump_path.empty() && dump_path[dump_path.size() - 1] != '/') {
dump_path = dump_path + "/";
}
dump_path = dump_path + CurrentTimeInStr() + "/";
GELOGI("Get dump path %s successfully", dump_path.c_str());
SetDumpPath(dump_path);
} else {
GELOGW("DUMP_PATH is not set");
}
}

if (enable_dump_ == kEnableFlag) {
string dump_step;
if (GetContext().GetOption(OPTION_EXEC_DUMP_STEP, dump_step) == GRAPH_SUCCESS) {
GELOGD("Get dump step %s successfully", dump_step.c_str());
SetDumpStep(dump_step);
}
string dump_mode;
if (GetContext().GetOption(OPTION_EXEC_DUMP_MODE, dump_mode) == GRAPH_SUCCESS) {
GELOGD("Get dump mode %s successfully", dump_mode.c_str());
SetDumpMode(dump_mode);
}
AddPropertyValue(DUMP_ALL_MODEL, {});
}

SetDumpDebugOptions();
}

// The following is the new dump scenario of the fusion operator
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::AddPropertyValue(
const std::string &model, const std::set<std::string> &layers) {
for (const std::string &layer : layers) {
GELOGI("This model %s config to dump layer %s", model.c_str(), layer.c_str());
}

model_dump_properties_map_[model] = layers;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::DeletePropertyValue(const std::string &model) {
auto iter = model_dump_properties_map_.find(model);
if (iter != model_dump_properties_map_.end()) {
model_dump_properties_map_.erase(iter);
}
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::set<std::string> DumpProperties::GetAllDumpModel() const {
std::set<std::string> model_list;
for (auto &iter : model_dump_properties_map_) {
model_list.insert(iter.first);
}

return model_list;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::set<std::string> DumpProperties::GetPropertyValue(
const std::string &model) const {
auto iter = model_dump_properties_map_.find(model);
if (iter != model_dump_properties_map_.end()) {
return iter->second;
}
return {};
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool DumpProperties::IsLayerNeedDump(
const std::string &model, const std::string &om_name, const std::string &op_name) const {
// if dump all
if (model_dump_properties_map_.find(DUMP_ALL_MODEL) != model_dump_properties_map_.end()) {
return true;
}

// if this model need dump
auto om_name_iter = model_dump_properties_map_.find(om_name);
auto model_name_iter = model_dump_properties_map_.find(model);
if (om_name_iter != model_dump_properties_map_.end() || model_name_iter != model_dump_properties_map_.end()) {
// if no dump layer info, dump all layer in this model
auto model_iter = om_name_iter != model_dump_properties_map_.end() ? om_name_iter : model_name_iter;
if (model_iter->second.empty()) {
return true;
}

return model_iter->second.find(op_name) != model_iter->second.end();
}

GELOGD("Model %s is not seated to be dump.", model.c_str());
return false;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::SetDumpPath(const std::string &path) {
dump_path_ = path;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string DumpProperties::GetDumpPath() const { return dump_path_; }

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::SetDumpStep(const std::string &step) {
dump_step_ = step;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string DumpProperties::GetDumpStep() const { return dump_step_; }

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::SetDumpMode(const std::string &mode) {
dump_mode_ = mode;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string DumpProperties::GetDumpMode() const { return dump_mode_; }

void DumpProperties::CopyFrom(const DumpProperties &other) {
if (&other != this) {
enable_dump_ = other.enable_dump_;
enable_dump_debug_ = other.enable_dump_debug_;
dump_path_ = other.dump_path_;
dump_step_ = other.dump_step_;
dump_mode_ = other.dump_mode_;

model_dump_properties_map_ = other.model_dump_properties_map_;
is_op_debug_ = other.is_op_debug_;
op_debug_mode_ = other.op_debug_mode_;
}
}

void DumpProperties::SetDumpDebugOptions() {
if (enable_dump_debug_ == kEnableFlag) {
string dump_debug_mode;
if (GetContext().GetOption(OPTION_EXEC_DUMP_DEBUG_MODE, dump_debug_mode) == GRAPH_SUCCESS) {
GELOGD("Get dump debug mode %s successfully", dump_debug_mode.c_str());
} else {
GELOGW("Dump debug mode is not set.");
return;
}

if (dump_debug_mode == OP_DEBUG_AICORE) {
GELOGD("ge.exec.dumpDebugMode=aicore_overflow, op debug is open.");
is_op_debug_ = true;
op_debug_mode_ = kAicoreOverflow;
} else if (dump_debug_mode == OP_DEBUG_ATOMIC) {
GELOGD("ge.exec.dumpDebugMode=atomic_overflow, op debug is open.");
is_op_debug_ = true;
op_debug_mode_ = kAtomicOverflow;
} else if (dump_debug_mode == OP_DEBUG_ALL) {
GELOGD("ge.exec.dumpDebugMode=all, op debug is open.");
is_op_debug_ = true;
op_debug_mode_ = kAllOverflow;
} else {
GELOGW("ge.exec.dumpDebugMode is invalid.");
}
} else {
GELOGI("ge.exec.enableDumpDebug is false or is not set.");
}
}

PropertiesManager::PropertiesManager() : is_inited_(false), delimiter("=") {}
PropertiesManager::~PropertiesManager() {}

@@ -159,131 +348,22 @@ PropertiesManager::GetPropertyMap() {

// Set separator
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void PropertiesManager::SetPropertyDelimiter(const std::string &de) {
std::lock_guard<std::mutex> lock(mutex_);
delimiter = de;
}

// The following is the new dump scenario of the fusion operator
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void PropertiesManager::AddDumpPropertyValue(
const std::string &model, const std::set<std::string> &layers) {
for (const std::string &layer : layers) {
GELOGI("This model %s config to dump layer %s", model.c_str(), layer.c_str());
}

std::lock_guard<std::mutex> lock(dump_mutex_);
model_dump_properties_map_[model] = layers;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void PropertiesManager::DeleteDumpPropertyValue(
const std::string &model) {
std::lock_guard<std::mutex> lock(dump_mutex_);
auto iter = model_dump_properties_map_.find(model);
if (iter != model_dump_properties_map_.end()) {
model_dump_properties_map_.erase(iter);
}
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void PropertiesManager::ClearDumpPropertyValue() {
std::lock_guard<std::mutex> lock(dump_mutex_);
model_dump_properties_map_.clear();
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::set<std::string> PropertiesManager::GetAllDumpModel() {
std::set<std::string> model_list;
std::lock_guard<std::mutex> lock(dump_mutex_);
for (auto &iter : model_dump_properties_map_) {
model_list.insert(iter.first);
}

return model_list;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::set<std::string> PropertiesManager::GetDumpPropertyValue(
const std::string &model) {
std::lock_guard<std::mutex> lock(dump_mutex_);
auto iter = model_dump_properties_map_.find(model);
if (iter != model_dump_properties_map_.end()) {
return iter->second;
}
return {};
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool PropertiesManager::IsLayerNeedDump(const std::string &model,
const std::string &om_name,
const std::string &op_name) {
std::lock_guard<std::mutex> lock(dump_mutex_);
// if dump all
if (model_dump_properties_map_.find(ge::DUMP_ALL_MODEL) != model_dump_properties_map_.end()) {
return true;
}

// if this model need dump
auto om_name_iter = model_dump_properties_map_.find(om_name);
auto model_name_iter = model_dump_properties_map_.find(model);
if (om_name_iter != model_dump_properties_map_.end() || model_name_iter != model_dump_properties_map_.end()) {
// if no dump layer info, dump all layer in this model
auto model_iter = om_name_iter != model_dump_properties_map_.end() ? om_name_iter : model_name_iter;
if (model_iter->second.empty()) {
return true;
}

return model_iter->second.find(op_name) != model_iter->second.end();
}

GELOGD("Model %s is not seated to be dump.", model.c_str());
return false;
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY DumpProperties &PropertiesManager::GetDumpProperties(
uint64_t session_id) {
std::lock_guard<std::mutex> lock(mutex_);
// If session_id is not found in dump_properties_map_, operator[] will insert one.
return dump_properties_map_[session_id];
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool PropertiesManager::QueryModelDumpStatus(
const std::string &model) {
std::lock_guard<std::mutex> lock(dump_mutex_);
auto iter = model_dump_properties_map_.find(model);
if (iter != model_dump_properties_map_.end()) {
return true;
} else if (model_dump_properties_map_.find(ge::DUMP_ALL_MODEL) != model_dump_properties_map_.end()) {
return true;
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void PropertiesManager::RemoveDumpProperties(uint64_t session_id) {
std::lock_guard<std::mutex> lock(mutex_);
auto iter = dump_properties_map_.find(session_id);
if (iter != dump_properties_map_.end()) {
dump_properties_map_.erase(iter);
}
return false;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void PropertiesManager::SetDumpOutputModel(
const std::string &output_mode) {
std::lock_guard<std::mutex> lock(dump_mutex_);
this->output_mode_ = output_mode;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string PropertiesManager::GetDumpOutputModel() {
std::lock_guard<std::mutex> lock(dump_mutex_);
return this->output_mode_;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void PropertiesManager::SetDumpOutputPath(
const std::string &output_path) {
std::lock_guard<std::mutex> lock(dump_mutex_);
this->output_path_ = output_path;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string PropertiesManager::GetDumpOutputPath() {
std::lock_guard<std::mutex> lock(dump_mutex_);
return this->output_path_;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void PropertiesManager::SetDumpStep(const std::string &dump_step) {
std::lock_guard<std::mutex> lock(dump_mutex_);
this->dump_step_ = dump_step;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string PropertiesManager::GetDumpStep() {
std::lock_guard<std::mutex> lock(dump_mutex_);
return this->dump_step_;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void PropertiesManager::SetDumpMode(const std::string &dump_mode) {
std::lock_guard<std::mutex> lock(dump_mutex_);
this->dump_mode_ = dump_mode;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string PropertiesManager::GetDumpMode() {
std::lock_guard<std::mutex> lock(dump_mutex_);
return this->dump_mode_;
}
} // namespace ge

+ 47
- 21
src/ge/common/properties_manager.h View File

@@ -32,6 +32,50 @@ static const char *USE_FUSION __attribute__((unused)) = "FMK_USE_FUSION";
static const char *TIMESTAT_ENABLE __attribute__((unused)) = "DAVINCI_TIMESTAT_ENABLE";
static const char *ANNDROID_DEBUG __attribute__((unused)) = "ANNDROID_DEBUG";

class DumpProperties {
public:
DumpProperties() = default;
~DumpProperties() = default;
DumpProperties(const DumpProperties &dump);
DumpProperties &operator=(const DumpProperties &dump);

void InitByOptions();

void AddPropertyValue(const std::string &model, const std::set<std::string> &layers);
void DeletePropertyValue(const std::string &model);

std::set<std::string> GetAllDumpModel() const;
std::set<std::string> GetPropertyValue(const std::string &model) const;
bool IsLayerNeedDump(const std::string &model, const std::string &om_name, const std::string &op_name) const;

void SetDumpPath(const std::string &path);
std::string GetDumpPath() const;

void SetDumpStep(const std::string &step);
std::string GetDumpStep() const;

void SetDumpMode(const std::string &mode);
std::string GetDumpMode() const;

bool IsOpDebugOpen() const { return is_op_debug_; }
uint32_t GetOpDebugMode() const { return op_debug_mode_; }

private:
void CopyFrom(const DumpProperties &other);
void SetDumpDebugOptions();

string enable_dump_;
string enable_dump_debug_;

std::string dump_path_;
std::string dump_step_;
std::string dump_mode_;
std::map<std::string, std::set<std::string>> model_dump_properties_map_;

bool is_op_debug_ = false;
uint32_t op_debug_mode_ = 0;
};

class PropertiesManager {
public:
// Singleton
@@ -81,21 +125,8 @@ class PropertiesManager {
*/
void SetPropertyDelimiter(const std::string &de);

void AddDumpPropertyValue(const std::string &model, const std::set<std::string> &layers);
std::set<std::string> GetAllDumpModel();
std::set<std::string> GetDumpPropertyValue(const std::string &model);
bool IsLayerNeedDump(const std::string &model, const std::string &om_name, const std::string &op_name);
void DeleteDumpPropertyValue(const std::string &model);
void ClearDumpPropertyValue();
bool QueryModelDumpStatus(const std::string &model);
void SetDumpOutputModel(const std::string &output_model);
std::string GetDumpOutputModel();
void SetDumpOutputPath(const std::string &output_path);
std::string GetDumpOutputPath();
void SetDumpStep(const std::string &dump_step);
std::string GetDumpStep();
void SetDumpMode(const std::string &dump_mode);
std::string GetDumpMode();
DumpProperties &GetDumpProperties(uint64_t session_id);
void RemoveDumpProperties(uint64_t session_id);

private:
// Private construct, destructor
@@ -119,12 +150,7 @@ class PropertiesManager {
std::map<std::string, std::string> properties_map_;
std::mutex mutex_;

std::string output_mode_;
std::string output_path_;
std::string dump_step_;
std::string dump_mode_;
std::map<std::string, std::set<std::string>> model_dump_properties_map_; // model_dump_layers_map_
std::mutex dump_mutex_;
std::map<uint64_t, DumpProperties> dump_properties_map_;
};
} // namespace ge



+ 0
- 1
src/ge/common/tbe_kernel_store.h View File

@@ -28,7 +28,6 @@
#include "graph/op_kernel_bin.h"

namespace ge {

using TBEKernel = ge::OpKernelBin;
using TBEKernelPtr = std::shared_ptr<ge::OpKernelBin>;



+ 102
- 93
src/ge/common/types.cc View File

@@ -26,6 +26,11 @@ const std::string DUMP_LAYER = "layer";
const std::string DUMP_FILE_PATH = "path";
const std::string DUMP_MODE = "dump_mode";

// op debug mode
const std::string OP_DEBUG_AICORE = "aicore_overflow";
const std::string OP_DEBUG_ATOMIC = "atomic_overflow";
const std::string OP_DEBUG_ALL = "all";

const int DEFAULT_FORMAT = static_cast<const int>(ge::FORMAT_NCHW);
// Supported public property names
const std::string PROP_OME_START_TIME = "ome_start_time"; // start time
@@ -277,8 +282,8 @@ REGISTER_OPTYPE_DEFINE(GETSPAN, "GetSpan");
REGISTER_OPTYPE_DEFINE(STOPGRADIENT, "StopGradient");
REGISTER_OPTYPE_DEFINE(PREVENTGRADIENT, "PreventGradient");
REGISTER_OPTYPE_DEFINE(GUARANTEECONST, "GuaranteeConst");
REGISTER_OPTYPE_DEFINE(BROADCASTGRADIENTARGS, "BroadcastGradientArgs")
REGISTER_OPTYPE_DEFINE(BROADCASTARGS, "BroadcastArgs")
REGISTER_OPTYPE_DEFINE(BROADCASTGRADIENTARGS, "BroadcastGradientArgs");
REGISTER_OPTYPE_DEFINE(BROADCASTARGS, "BroadcastArgs");
REGISTER_OPTYPE_DEFINE(CONFUSIONMATRIX, "ConfusionMatrix");
REGISTER_OPTYPE_DEFINE(RANK, "Rank");
REGISTER_OPTYPE_DEFINE(PLACEHOLDER, "PlaceHolder");
@@ -286,6 +291,7 @@ REGISTER_OPTYPE_DEFINE(END, "End");
REGISTER_OPTYPE_DEFINE(BASICLSTMCELL, "BasicLSTMCell");
REGISTER_OPTYPE_DEFINE(GETNEXT, "GetNext");
REGISTER_OPTYPE_DEFINE(INITDATA, "InitData");
REGISTER_OPTYPE_DEFINE(REFIDENTITY, "RefIdentity");

/***************Ann special operator*************************/
REGISTER_OPTYPE_DEFINE(ANN_MEAN, "AnnMean");
@@ -479,72 +485,72 @@ const uint64_t ALLOC_MEMORY_MAX_SIZE = 536870912; // Max size of 512M.
#endif

///
///@brief Magic number of model file
/// @brief Magic number of model file
///
const uint32_t MODEL_FILE_MAGIC_NUM = 0x444F4D49; // magic number

///
///@brief Model head length
/// @brief Model head length
///
const uint32_t MODEL_FILE_HEAD_LEN = 256;

///
///@ingroup domi_omg
///@brief Input node type
/// @ingroup domi_omg
/// @brief Input node type
///
const std::string INPUT_TYPE = "Input";

///
///@ingroup domi_omg
///@brief AIPP label, label AIPP conv operator
/// @ingroup domi_omg
/// @brief AIPP label, label AIPP conv operator
///
const std::string AIPP_CONV_FLAG = "Aipp_Conv_Flag";

///
///@ingroup domi_omg
///@brief AIPP label, label aipp data operator
/// @ingroup domi_omg
/// @brief AIPP label, label aipp data operator
///
const std::string AIPP_DATA_FLAG = "Aipp_Data_Flag";

///
///@ingroup domi_omg
///@brief Record the w dimension of model input corresponding to dynamic AIPP
/// @ingroup domi_omg
/// @brief Record the w dimension of model input corresponding to dynamic AIPP
///
const std::string AIPP_RELATED_DATA_DIM_W = "aipp_related_data_dim_w";

///
///@ingroup domi_omg
///@brief Record the H dimension of model input corresponding to dynamic AIPP
/// @ingroup domi_omg
/// @brief Record the H dimension of model input corresponding to dynamic AIPP
///
const std::string AIPP_RELATED_DATA_DIM_H = "aipp_related_data_dim_h";

///
///@ingroup domi_omg
///@brief The tag of the data operator. Mark this input to the dynamic AIPP operator
/// @ingroup domi_omg
/// @brief The tag of the data operator. Mark this input to the dynamic AIPP operator
///
const std::string INPUT_TO_DYNAMIC_AIPP = "input_to_dynamic_aipp";

///
///@ingroup domi_omg
///@brief DATA node type
/// @ingroup domi_omg
/// @brief DATA node type
///
const std::string DATA_TYPE = "Data";

///
///@ingroup domi_omg
///@brief DATA node type
/// @ingroup domi_omg
/// @brief DATA node type
///
const std::string AIPP_DATA_TYPE = "AippData";

///
///@ingroup domi_omg
///@brief Frame operator type
/// @ingroup domi_omg
/// @brief Frame operator type
///
const std::string FRAMEWORK_OP_TYPE = "FrameworkOp";

///
///@ingroup domi_omg
///@brief Data node type
/// @ingroup domi_omg
/// @brief Data node type
///
const std::string ANN_DATA_TYPE = "AnnData";
const std::string ANN_NETOUTPUT_TYPE = "AnnNetOutput";
@@ -552,136 +558,139 @@ const std::string ANN_DEPTHCONV_TYPE = "AnnDepthConv";
const std::string ANN_CONV_TYPE = "AnnConvolution";
const std::string ANN_FC_TYPE = "AnnFullConnection";
///
///@ingroup domi_omg
///@brief Convolution node type
/// @ingroup domi_omg
/// @brief Convolution node type
///
const std::string NODE_NAME_NET_OUTPUT = "Node_Output";

const std::string NODE_NAME_END_GRAPH = "Node_EndGraph";

const std::string NODE_NAME_OP_DEBUG = "Node_OpDebug";
const std::string OP_TYPE_OP_DEBUG = "Opdebug";

///
///@ingroup domi_omg
///@brief Convolution node type
/// @ingroup domi_omg
/// @brief Convolution node type
///
const std::string OP_TYPE_CONVOLUTION = "Convolution";
///
///@ingroup domi_omg
///@brief Add convolution node name to AIPP
/// @ingroup domi_omg
/// @brief Add convolution node name to AIPP
///
const std::string AIPP_CONV_OP_NAME = "aipp_conv_op";
///
///@ingroup domi_omg
///@brief Operator configuration item separator
/// @ingroup domi_omg
/// @brief Operator configuration item separator
///
const std::string OP_CONF_DELIMITER = ":";

///
///@ingroup domi_omg
///@brief attr value name
/// @ingroup domi_omg
/// @brief attr value name
///
const std::string ATTR_NAME_VALUE1 = "value1";

///
///@ingroup domi_omg
///@brief attr value name, 6d_2_4d C
/// @ingroup domi_omg
/// @brief attr value name, 6d_2_4d C
///
const std::string ATTR_NAME_INPUT_CVALUE = "input_cvalue";

///
///@ingroup domi_omg
///@brief alpha default value
/// @ingroup domi_omg
/// @brief alpha default value
///
const float ALPHA_DEFAULT_VALUE = 1.0;

///
///@ingroup domi_omg
///@brief beta default value
/// @ingroup domi_omg
/// @brief beta default value
///
const float BETA_DEFAULT_VALUE = 0.0;

///
///@ingroup domi_omg
///@brief coef default value
/// @ingroup domi_omg
/// @brief coef default value
///
const float COEF_DEFAULT_VALUE = 0.0;

///
///@ingroup domi_omg
///@brief Relu6 coef value
/// @ingroup domi_omg
/// @brief Relu6 coef value
///
const float RELU6_COEF = 6.0;

///
///@ingroup domi_omg
///@brief stride default value
/// @ingroup domi_omg
/// @brief stride default value
///
const uint32_t STRIDE_DEFAULT_VALUE = 1;

///
///@ingroup domi_omg
///@brief pad default value
/// @ingroup domi_omg
/// @brief pad default value
///
const uint32_t PAD_DEFAULT_VALUE = 0;

///
///@ingroup domi_omg
///@brief dilation default value
/// @ingroup domi_omg
/// @brief dilation default value
///
const int DILATION_DEFAULT_VALUE = 1;

///
///@ingroup domi_omg
///@brief kernel default value
/// @ingroup domi_omg
/// @brief kernel default value
///
const uint32_t KERNEL_DEFAULT_VALUE = 0;

///
///@ingroup domi_omg
///@brief defaule convolution group size
/// @ingroup domi_omg
/// @brief defaule convolution group size
///
const uint32_t DEFAULT_CONV_GROUP = 1;

///
///@ingroup domi_omg
///@brief Default deconvolution adj
/// @ingroup domi_omg
/// @brief Default deconvolution adj
///
const uint32_t DEFAULT_DECONV_ADJ = 0;

///
///@ingroup domi_omg
///@brief Represents value 1
/// @ingroup domi_omg
/// @brief Represents value 1
///
const uint32_t NUM_ONE = 1;

///
///@ingroup domi_omg
///@brief spatial dim size default value
/// @ingroup domi_omg
/// @brief spatial dim size default value
///
const int32_t SPATIAL_DIM_DEFAULT_SIZE = 2;

///
///@ingroup domi_omg
///@brief dim extended default value
/// @ingroup domi_omg
/// @brief dim extended default value
///
const int32_t DIM_DEFAULT_VALUE = 1;

///
///@ingroup domi_omg
///@brief The first weight list in opdef is filter
/// @ingroup domi_omg
/// @brief The first weight list in opdef is filter
///
const int32_t WEIGHT_FILTER_INDEX = 0;

///
///@ingroup domi_omg
///@brief The second weight list in opdef is bias
/// @ingroup domi_omg
/// @brief The second weight list in opdef is bias
///
const int32_t WEIGHT_BIAS_INDEX = 1;

const int32_t TENSOR_ND_SUPPORT_SIZE = 8;

///
///@ingroup domi_omg
///@brief NCHW index default value
/// @ingroup domi_omg
/// @brief NCHW index default value
///
const uint32_t NCHW_DIM_N = 0;
const uint32_t NCHW_DIM_C = 1;
@@ -689,8 +698,8 @@ const uint32_t NCHW_DIM_H = 2;
const uint32_t NCHW_DIM_W = 3;

///
///@ingroup domi_omg
///@brief KCHW index default value
/// @ingroup domi_omg
/// @brief KCHW index default value
///
const uint32_t KCHW_DIM_K = 0;
const uint32_t KCHW_DIM_C = 1;
@@ -698,8 +707,8 @@ const uint32_t KCHW_DIM_H = 2;
const uint32_t KCHW_DIM_W = 3;

///
///@ingroup domi_omg
///@brief HWCK index default value
/// @ingroup domi_omg
/// @brief HWCK index default value
///
const uint32_t HWCK_DIM_H = 0;
const uint32_t HWCK_DIM_W = 1;
@@ -707,8 +716,8 @@ const uint32_t HWCK_DIM_C = 2;
const uint32_t HWCK_DIM_K = 3;

///
///@ingroup domi_omg
///@brief NHWC index default value
/// @ingroup domi_omg
/// @brief NHWC index default value
///
const uint32_t NHWC_DIM_N = 0;
const uint32_t NHWC_DIM_H = 1;
@@ -716,8 +725,8 @@ const uint32_t NHWC_DIM_W = 2;
const uint32_t NHWC_DIM_C = 3;

///
///@ingroup domi_omg
///@brief CHWN index default value
/// @ingroup domi_omg
/// @brief CHWN index default value
///
const uint32_t CHWN_DIM_N = 3;
const uint32_t CHWN_DIM_C = 0;
@@ -725,23 +734,23 @@ const uint32_t CHWN_DIM_H = 1;
const uint32_t CHWN_DIM_W = 2;

///
///@ingroup domi_omg
///@brief CHW index default value
/// @ingroup domi_omg
/// @brief CHW index default value
///
const uint32_t CHW_DIM_C = 0;
const uint32_t CHW_DIM_H = 1;
const uint32_t CHW_DIM_W = 2;

///
///@ingroup domi_omg
///@brief HWC index default value
/// @ingroup domi_omg
/// @brief HWC index default value
///
const uint32_t HWC_DIM_H = 0;
const uint32_t HWC_DIM_W = 1;
const uint32_t HWC_DIM_C = 2;
///
///@ingroup domi_omg
///@brief Pad index default value
/// @ingroup domi_omg
/// @brief Pad index default value
///
const uint32_t PAD_H_HEAD = 0;
const uint32_t PAD_H_TAIL = 1;
@@ -749,35 +758,35 @@ const uint32_t PAD_W_HEAD = 2;
const uint32_t PAD_W_TAIL = 3;

///
///@ingroup domi_omg
///@brief window index default value
/// @ingroup domi_omg
/// @brief window index default value
///
const uint32_t WINDOW_H = 0;
const uint32_t WINDOW_W = 1;

///
///@ingroup domi_omg
///@brief stride index default value
/// @ingroup domi_omg
/// @brief stride index default value
///
const uint32_t STRIDE_H = 0;
const uint32_t STRIDE_W = 1;

///
///@ingroup domi_omg
///@brief dilation index default value
/// @ingroup domi_omg
/// @brief dilation index default value
///
const uint32_t DILATION_H = 0;
const uint32_t DILATION_W = 1;

///
///@ingroup domi_omg
///@brief the num of XRBG channel
/// @ingroup domi_omg
/// @brief the num of XRBG channel
///
const uint32_t XRGB_CHN_NUM = 4;

///
///@ingroup domi_omg
///@brief global pooling default value
/// @ingroup domi_omg
/// @brief global pooling default value
///
const bool DEFAULT_GLOBAL_POOLING = false;

@@ -801,4 +810,4 @@ const uint32_t STREAM_SWITCH_INPUT_NUM = 2;

const std::string NODE_NAME_GLOBAL_STEP = "ge_global_step";
const std::string NODE_NAME_GLOBAL_STEP_ASSIGNADD = "global_step_assignadd";
}; // namespace ge
} // namespace ge

+ 31
- 35
src/ge/common/util.cc View File

@@ -56,6 +56,7 @@ const int kWarningThreshold = 536870912 * 2; // 536870912 represent 512M
/// The maximum length of the file.
/// Based on the security coding specification and the current actual (protobuf) model size, it is determined as 2G-1
const int kMaxFileSizeLimit = INT_MAX;
const char *const kPathValidReason = "The path can only contains 'a-z' 'A-Z' '0-9' '-' '.' '_' and chinese character";
} // namespace

namespace ge {
@@ -77,7 +78,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromBinaryFile(co

std::ifstream fs(real_path, std::ifstream::in | std::ifstream::binary);
if (!fs.is_open()) {
ErrorManager::GetInstance().ATCReportErrMessage("E19004", {"realpath"}, {file});
ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"file", "errmsg"}, {file, "ifstream is_open failed"});
GELOGE(ge::FAILED, "Open real path[%s] failed.", file);
return false;
}
@@ -90,7 +91,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromBinaryFile(co
fs.close();

if (!ret) {
ErrorManager::GetInstance().ATCReportErrMessage("E19005", {"filepath"}, {file});
ErrorManager::GetInstance().ATCReportErrMessage("E19005", {"file"}, {file});
GELOGE(ge::FAILED, "Parse file[%s] failed.", file);
return ret;
}
@@ -114,17 +115,18 @@ long GetFileLength(const std::string &input_file) {

GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), return -1, "input_file path '%s' not valid", input_file.c_str());
unsigned long long file_length = 0;
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(mmGetFileSize(input_file.c_str(), &file_length) != EN_OK,
ErrorManager::GetInstance().ATCReportErrMessage("E10037", {"filepath"}, {input_file});
return -1, "Open file[%s] failed", input_file.c_str());
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
mmGetFileSize(input_file.c_str(), &file_length) != EN_OK,
ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"file", "errmsg"}, {input_file, strerror(errno)});
return -1, "Open file[%s] failed. %s", input_file.c_str(), strerror(errno));

GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file_length == 0),
ErrorManager::GetInstance().ATCReportErrMessage("E10038", {"filepath"}, {input_file});
ErrorManager::GetInstance().ATCReportErrMessage("E19015", {"filepath"}, {input_file});
return -1, "File[%s] size is 0, not valid.", input_file.c_str());

GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
file_length > kMaxFileSizeLimit, ErrorManager::GetInstance().ATCReportErrMessage(
"E10039", {"filepath", "filesize", "maxlen"},
"E19016", {"filepath", "filesize", "maxlen"},
{input_file, std::to_string(file_length), std::to_string(kMaxFileSizeLimit)});
return -1, "File[%s] size %lld is out of limit: %d.", input_file.c_str(), file_length, kMaxFileSizeLimit);
return static_cast<long>(file_length);
@@ -219,7 +221,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int CreateDirectory(const std::
if (ret != 0) {
if (errno != EEXIST) {
ErrorManager::GetInstance().ATCReportErrMessage("E19006", {"path"}, {directory_path});
GELOGW("Cannot create directory %s. Make sure the directory exists and writable.", directory_path.c_str());
GELOGW("Can not create directory %s. Make sure the directory exists and writable.", directory_path.c_str());
return ret;
}
}
@@ -230,7 +232,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int CreateDirectory(const std::
if (ret != 0) {
if (errno != EEXIST) {
ErrorManager::GetInstance().ATCReportErrMessage("E19006", {"path"}, {directory_path});
GELOGW("Cannot create directory %s. Make sure the directory exists and writable.", directory_path.c_str());
GELOGW("Can not create directory %s. Make sure the directory exists and writable.", directory_path.c_str());
return ret;
}
}
@@ -258,16 +260,16 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromText(const ch
"incorrect parameter. nullptr == file || nullptr == message");

std::string real_path = RealPath(file);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(),
ErrorManager::GetInstance().ATCReportErrMessage("E10036", {"filepath"}, {file});
return false, "Get path[%s]'s real path failed", file);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), ErrorManager::GetInstance().ATCReportErrMessage(
"E19000", {"path", "errmsg"}, {file, strerror(errno)});
return false, "Path[%s]'s realpath is empty, errmsg[%s]", file, strerror(errno));

GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetFileLength(real_path) == -1, return false, "file size not valid.");

std::ifstream fs(real_path.c_str(), std::ifstream::in);

if (!fs.is_open()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10040", {"realpth", "protofile"}, {real_path, file});
ErrorManager::GetInstance().ATCReportErrMessage("E19017", {"realpth", "protofile"}, {real_path, file});
GELOGE(ge::FAILED, "Fail to open proto file real path is '%s' when orginal file path is '%s'.", real_path.c_str(),
file);
return false;
@@ -275,7 +277,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromText(const ch

google::protobuf::io::IstreamInputStream input(&fs);
bool ret = google::protobuf::TextFormat::Parse(&input, message);
GE_IF_BOOL_EXEC(!ret, ErrorManager::GetInstance().ATCReportErrMessage("E10041", {"protofile"}, {file});
GE_IF_BOOL_EXEC(!ret, ErrorManager::GetInstance().ATCReportErrMessage("E19018", {"protofile"}, {file});
GELOGE(ret,
"Parse file[%s] through [google::protobuf::TextFormat::Parse] failed, "
"please check whether the file is a valid protobuf format file.",
@@ -360,14 +362,14 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInputPathValid(const
// The specified path is empty
std::map<std::string, std::string> args_map;
if (file_path.empty()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {atc_param});
ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {atc_param});
GELOGW("Input parameter's value is empty.");
return false;
}
std::string real_path = RealPath(file_path.c_str());
// Unable to get absolute path (does not exist or does not have permission to access)
if (real_path.empty()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"path", "errmsg"}, {file_path.c_str(), strerror(errno)});
ErrorManager::GetInstance().ATCReportErrMessage("E19000", {"path", "errmsg"}, {file_path, strerror(errno)});
GELOGW("Path[%s]'s realpath is empty, errmsg[%s]", file_path.c_str(), strerror(errno));
return false;
}
@@ -380,16 +382,14 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInputPathValid(const

GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
!ValidateStr(real_path, mode),
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "path"}, {atc_param, real_path});
return false,
"Input parameter[--%s]'s value[%s] is illegal. The path can only contains 'a-z' 'A-Z' '0-9' '-' '.' '_' "
"and chinese character.",
atc_param.c_str(), real_path.c_str());
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
{atc_param, real_path, kPathValidReason});
return false, "Invalid value for %s[%s], %s.", atc_param.c_str(), real_path.c_str(), kPathValidReason);

// The absolute path points to a file that is not readable
if (access(real_path.c_str(), R_OK) != 0) {
ErrorManager::GetInstance().ATCReportErrMessage("E10003", {"path", "errmsg"}, {file_path.c_str(), strerror(errno)});
GELOGW("Read path[%s] failed, errmsg[%s]", file_path.c_str(), strerror(errno));
ErrorManager::GetInstance().ATCReportErrMessage("E19003", {"file", "errmsg"}, {file_path.c_str(), strerror(errno)});
GELOGW("Read file[%s] failed, errmsg[%s]", file_path.c_str(), strerror(errno));
return false;
}

@@ -400,7 +400,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckOutputPathValid(const
const std::string &atc_param) {
// The specified path is empty
if (file_path.empty()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {atc_param});
ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {atc_param});
GELOGW("Input parameter's value is empty.");
return false;
}
@@ -416,18 +416,14 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckOutputPathValid(const

GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
!ValidateStr(real_path, mode),
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "path"}, {atc_param, real_path});
return false,
"Input parameter[--%s]'s value[%s] is illegal. The path can only contains 'a-z' 'A-Z' '0-9' '-' '.' '_' "
"and chinese character.",
atc_param.c_str(), real_path.c_str());
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
{atc_param, real_path, kPathValidReason});
return false, "Invalid value for %s[%s], %s.", atc_param.c_str(), real_path.c_str(), kPathValidReason);

// File is not readable or writable
if (access(real_path.c_str(), W_OK | F_OK) != 0) {
ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"realpath", "path", "errmsg"},
{real_path, file_path, strerror(errno)});
GELOGW("Write file[%s] failed, input path is %s, errmsg[%s]", real_path.c_str(), file_path.c_str(),
strerror(errno));
ErrorManager::GetInstance().ATCReportErrMessage("E19004", {"file", "errmsg"}, {real_path, strerror(errno)});
GELOGW("Write file[%s] failed, errmsg[%s]", real_path.c_str(), strerror(errno));
return false;
}
} else {
@@ -445,8 +441,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckOutputPathValid(const
std::string prefix_path = std::string(file_path).substr(0, static_cast<size_t>(path_split_pos));
// Determine whether the specified path is valid by creating the path
if (CreateDirectory(prefix_path) != 0) {
ErrorManager::GetInstance().ATCReportErrMessage("E10005", {"path"}, {file_path});
GELOGW("Can not create prefix path for path[%s].", file_path.c_str());
ErrorManager::GetInstance().ATCReportErrMessage("E19006", {"path"}, {file_path});
GELOGW("Can not create directory[%s].", file_path.c_str());
return false;
}
}


+ 21
- 1
src/ge/engine_manager/dnnengine_manager.cc View File

@@ -24,6 +24,7 @@

#include "common/debug/log.h"
#include "common/ge/ge_util.h"
#include "common/util/error_manager/error_manager.h"
#include "framework/common/debug/ge_log.h"
#include "graph/ge_context.h"
#include "init/gelib.h"
@@ -161,6 +162,10 @@ bool DNNEngineManager::IsEngineRegistered(const std::string &name) {
return false;
}

void DNNEngineManager::InitPerformanceStaistic() { checksupport_cost_.clear(); }

const map<string, uint64_t> &DNNEngineManager::GetCheckSupportCost() const { return checksupport_cost_; }

std::string DNNEngineManager::GetDNNEngineName(const OpDescPtr &op_desc) {
GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGE(GE_CLI_GE_NOT_INITIALIZED, "DNNEngineManager: op_desc is nullptr");
return "");
@@ -194,15 +199,20 @@ std::string DNNEngineManager::GetDNNEngineName(const OpDescPtr &op_desc) {
if (kernel_info_store != kernel_map.end()) {
std::string unsupported_reason;
// It will be replaced by engine' checksupport
uint64_t start_time = GetCurrentTimestap();
if (kernel_info_store->second->CheckSupported(op_desc, unsupported_reason)) {
checksupport_cost_[kernel_name] += GetCurrentTimestap() - start_time;
op_desc->SetOpEngineName(it.engine);
op_desc->SetOpKernelLibName(kernel_name);
GELOGD("DNNEngineManager:Set OpKernelLibName %s and engine name %s into op_desc %s", kernel_name.c_str(),
it.engine.c_str(), op_desc->GetName().c_str());
return it.engine;
} else {
checksupport_cost_[kernel_name] += GetCurrentTimestap() - start_time;
bool is_custom_op = false;
if ((ge::AttrUtils::GetBool(op_desc, kCustomOpFlag, is_custom_op)) && is_custom_op) {
ErrorManager::GetInstance().ATCReportErrMessage("E13001", {"kernelname", "optype", "opname"},
{kernel_name, op_desc->GetType(), op_desc->GetName()});
GELOGE(FAILED,
"The custom operator registered by the user does not support the logic function delivered by this "
"network. Check support failed, kernel_name is %s, op type is %s, op name is %s",
@@ -221,9 +231,13 @@ std::string DNNEngineManager::GetDNNEngineName(const OpDescPtr &op_desc) {
}
}
for (const auto &it : unsupported_reasons) {
ErrorManager::GetInstance().ATCReportErrMessage("E13002", {"optype", "opskernel", "reason"},
{op_desc->GetType(), it.first, it.second});
GELOGE(GE_GRAPH_ASSIGN_ENGINE_FAILED, "GetDNNEngineName:Op type %s of ops kernel %s is unsupported, reason:%s",
op_desc->GetType().c_str(), it.first.c_str(), it.second.c_str());
}
ErrorManager::GetInstance().ATCReportErrMessage("E13003", {"opname", "optype"},
{op_desc->GetName(), op_desc->GetType()});
GELOGE(GE_GRAPH_ASSIGN_ENGINE_FAILED, "Can't find any supported ops kernel and engine of %s, type is %s",
op_desc->GetName().c_str(), op_desc->GetType().c_str());
return "";
@@ -384,7 +398,13 @@ Status DNNEngineManager::ReadJsonFile(const std::string &file_path, JsonHandle h
return FAILED;
}

ifs >> *json_file;
try {
ifs >> *json_file;
} catch (const json::exception &e) {
GELOGE(FAILED, "Read json file failed");
ifs.close();
return FAILED;
}
ifs.close();
GELOGI("Read json file success");
return SUCCESS;


+ 3
- 0
src/ge/engine_manager/dnnengine_manager.h View File

@@ -63,6 +63,8 @@ class DNNEngineManager {
// If can't find appropriate engine name, return "", report error
string GetDNNEngineName(const OpDescPtr &op_desc);
const map<string, SchedulerConf> &GetSchedulers() const;
const map<string, uint64_t> &GetCheckSupportCost() const;
void InitPerformanceStaistic();

private:
DNNEngineManager();
@@ -78,6 +80,7 @@ class DNNEngineManager {
std::map<std::string, DNNEnginePtr> engines_map_;
std::map<std::string, ge::DNNEngineAttribute> engines_attrs_map_;
std::map<string, SchedulerConf> schedulers_;
std::map<string, uint64_t> checksupport_cost_;
bool init_flag_;
};
} // namespace ge


+ 1
- 1
src/ge/executor/CMakeLists.txt View File

@@ -26,6 +26,7 @@ file(GLOB PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR}

file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR}
"ge_executor.cc"
"../common/ge/op_tiling_manager.cc"
"../common/ge/plugin_manager.cc"
"../common/profiling/profiling_manager.cc"
"../graph/execute/graph_execute.cc"
@@ -59,7 +60,6 @@ file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR}
"../graph/load/new_model_manager/task_info/task_info.cc"
"../graph/load/new_model_manager/tbe_handle_store.cc"
"../graph/load/new_model_manager/zero_copy_task.cc"
"../graph/load/output/output.cc"
"../graph/manager/graph_caching_allocator.cc"
"../graph/manager/graph_manager_utils.cc"
"../graph/manager/graph_mem_allocator.cc"


+ 3
- 4
src/ge/executor/ge_executor.cc View File

@@ -452,7 +452,7 @@ Status GeExecutor::RunModel(const ge::RunModelData &input_data, ge::RunModelData

// Get input and output descriptor
Status GeExecutor::GetModelDescInfo(uint32_t model_id, std::vector<ge::TensorDesc> &input_desc,
std::vector<ge::TensorDesc> &output_desc) {
std::vector<ge::TensorDesc> &output_desc, bool new_model_desc) {
GELOGI("get model desc info begin.");
if (!isInit_) {
GELOGE(GE_EXEC_NOT_INIT, "GeExecutor has not been initialized!");
@@ -464,8 +464,8 @@ Status GeExecutor::GetModelDescInfo(uint32_t model_id, std::vector<ge::TensorDes
std::vector<uint32_t> input_formats;
std::vector<uint32_t> output_formats;

Status ret =
GraphExecutor::GetInputOutputDescInfo(model_id, input_desc_infos, output_desc_infos, input_formats, output_formats);
Status ret = GraphExecutor::GetInputOutputDescInfo(model_id, input_desc_infos, output_desc_infos, input_formats,
output_formats, new_model_desc);
if (ret != domi::SUCCESS) {
GELOGE(ret, "GetInputOutputDescInfo failed. ret = %u", ret);
return TransferDomiErrorCode(ret);
@@ -854,5 +854,4 @@ Status GeExecutor::GetAllAippInputOutputDims(uint32_t model_id, uint32_t index,
GELOGI("GetAllAippInputOutputDims succ.");
return SUCCESS;
}

} // namespace ge

+ 2
- 1
src/ge/executor/module.mk View File

@@ -4,6 +4,7 @@ local_ge_executor_src_files := \
ge_executor.cc \
../common/profiling/profiling_manager.cc \
../common/ge/plugin_manager.cc \
../common/ge/op_tiling_manager.cc \
../graph/load/graph_loader.cc \
../graph/execute/graph_execute.cc \
../omm/csa_interact.cc \
@@ -44,7 +45,6 @@ local_ge_executor_src_files := \
../graph/load/new_model_manager/task_info/end_graph_task_info.cc \
../graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc \
../graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc \
../graph/load/output/output.cc \
../single_op/single_op_manager.cc \
../single_op/single_op_model.cc \
../single_op/single_op.cc \
@@ -53,6 +53,7 @@ local_ge_executor_src_files := \
../single_op/task/build_task_utils.cc \
../single_op/task/tbe_task_builder.cc \
../single_op/task/aicpu_task_builder.cc \
../single_op/task/aicpu_kernel_task_builder.cc \
../hybrid/hybrid_davinci_model_stub.cc\

local_ge_executor_c_include := \


+ 30
- 4
src/ge/ge_inference.mk View File

@@ -1,5 +1,5 @@
LOCAL_PATH := $(call my-dir)
include $(LOCAL_PATH)/stub/Makefile
COMMON_LOCAL_SRC_FILES := \
proto/fusion_model.proto \
proto/optimizer_priority.proto \
@@ -32,6 +32,7 @@ COMMON_LOCAL_SRC_FILES := \

GRAPH_MANAGER_LOCAL_SRC_FILES := \
common/ge/plugin_manager.cc\
common/ge/op_tiling_manager.cc\
init/gelib.cc \
session/inner_session.cc \
session/session_manager.cc \
@@ -91,6 +92,7 @@ OMG_HOST_SRC_FILES := \
graph/passes/no_use_reshape_remove_pass.cc \
graph/passes/iterator_op_pass.cc \
graph/passes/atomic_addr_clean_pass.cc \
graph/passes/mark_same_addr_pass.cc \
graph/common/omg_util.cc \
graph/common/bcast.cc \
graph/passes/dimension_compute_pass.cc \
@@ -145,6 +147,7 @@ OMG_HOST_SRC_FILES := \
graph/passes/stop_gradient_pass.cc \
graph/passes/prevent_gradient_pass.cc \
graph/passes/identity_pass.cc \
graph/passes/ref_identity_delete_op_pass.cc \
graph/passes/placeholder_with_default_pass.cc \
graph/passes/snapshot_pass.cc \
graph/passes/guarantee_const_pass.cc \
@@ -153,7 +156,9 @@ OMG_HOST_SRC_FILES := \
graph/passes/folding_pass.cc \
graph/passes/cast_translate_pass.cc \
graph/passes/prune_pass.cc \
graph/passes/switch_op_pass.cc \
graph/passes/merge_to_stream_merge_pass.cc \
graph/passes/switch_to_stream_switch_pass.cc \
graph/passes/attach_stream_label_pass.cc \
graph/passes/multi_batch_pass.cc \
graph/passes/next_iteration_pass.cc \
graph/passes/control_trigger_pass.cc \
@@ -173,7 +178,6 @@ OMG_HOST_SRC_FILES := \
graph/passes/variable_op_pass.cc \
graph/passes/cast_remove_pass.cc \
graph/passes/transpose_transdata_pass.cc \
graph/passes/identify_reference_pass.cc \
graph/passes/hccl_memcpy_pass.cc \
graph/passes/flow_ctrl_pass.cc \
graph/passes/link_gen_mask_nodes_pass.cc \
@@ -199,7 +203,6 @@ OME_HOST_SRC_FILES := \
graph/load/new_model_manager/tbe_handle_store.cc \
graph/load/new_model_manager/cpu_queue_schedule.cc \
graph/load/new_model_manager/zero_copy_task.cc \
graph/load/output/output.cc \
graph/load/new_model_manager/data_dumper.cc \
graph/load/new_model_manager/task_info/task_info.cc \
graph/load/new_model_manager/task_info/event_record_task_info.cc \
@@ -224,6 +227,7 @@ OME_HOST_SRC_FILES := \
single_op/task/build_task_utils.cc \
single_op/task/tbe_task_builder.cc \
single_op/task/aicpu_task_builder.cc \
single_op/task/aicpu_kernel_task_builder.cc \
single_op/single_op.cc \
single_op/single_op_model.cc \
single_op/stream_resource.cc \
@@ -353,6 +357,28 @@ LOCAL_SHARED_LIBRARIES := \
LOCAL_LDFLAGS := -lrt -ldl


include $(BUILD_HOST_SHARED_LIBRARY)

#compiler for host infer
include $(CLEAR_VARS)

LOCAL_MODULE := stub/libge_compiler

LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DREUSE_MEMORY=1 -O2
LOCAL_CFLAGS += -DFMK_HOST_INFER -DFMK_SUPPORT_DUMP
ifeq ($(DEBUG), 1)
LOCAL_CFLAGS += -g -O0
endif

LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES)

LOCAL_SRC_FILES := ../../out/ge/lib64/stub/ge_ir_build.cc


LOCAL_SHARED_LIBRARIES :=

LOCAL_LDFLAGS := -lrt -ldl

include $(BUILD_HOST_SHARED_LIBRARY)

#compiler for device


+ 36
- 3
src/ge/ge_runner.mk View File

@@ -23,6 +23,7 @@ LIBGE_LOCAL_SRC_FILES := \
common/formats/utils/formats_trans_utils.cc \
common/fp16_t.cc \
common/ge/plugin_manager.cc\
common/ge/op_tiling_manager.cc\
common/helper/model_cache_helper.cc \
common/profiling/profiling_manager.cc \
engine_manager/dnnengine_manager.cc \
@@ -77,7 +78,6 @@ LIBGE_LOCAL_SRC_FILES := \
graph/load/new_model_manager/task_info/task_info.cc \
graph/load/new_model_manager/tbe_handle_store.cc \
graph/load/new_model_manager/zero_copy_task.cc \
graph/load/output/output.cc \
graph/manager/graph_context.cc \
graph/manager/graph_manager.cc \
graph/manager/graph_manager_utils.cc \
@@ -99,6 +99,7 @@ LIBGE_LOCAL_SRC_FILES := \
graph/passes/aicpu_constant_folding_pass.cc \
graph/passes/assert_pass.cc \
graph/passes/atomic_addr_clean_pass.cc \
graph/passes/mark_same_addr_pass.cc \
graph/partition/dynamic_shape_partition.cc \
graph/passes/base_pass.cc \
graph/passes/cast_remove_pass.cc \
@@ -158,8 +159,8 @@ LIBGE_LOCAL_SRC_FILES := \
graph/passes/get_original_format_pass.cc \
graph/passes/guarantee_const_pass.cc \
graph/passes/hccl_memcpy_pass.cc \
graph/passes/identify_reference_pass.cc \
graph/passes/identity_pass.cc \
graph/passes/ref_identity_delete_op_pass.cc \
graph/passes/infershape_pass.cc \
graph/passes/isolated_op_remove_pass.cc \
graph/passes/iterator_op_pass.cc \
@@ -191,7 +192,9 @@ LIBGE_LOCAL_SRC_FILES := \
graph/passes/data_pass.cc \
graph/passes/switch_data_edges_bypass.cc \
graph/passes/switch_logic_remove_pass.cc \
graph/passes/switch_op_pass.cc \
graph/passes/merge_to_stream_merge_pass.cc \
graph/passes/switch_to_stream_switch_pass.cc \
graph/passes/attach_stream_label_pass.cc \
graph/passes/switch_dead_branch_elimination.cc \
graph/passes/replace_transshape_pass.cc \
graph/passes/transop_breadth_fusion_pass.cc \
@@ -230,6 +233,7 @@ LIBGE_LOCAL_SRC_FILES := \
single_op/task/op_task.cc \
single_op/task/tbe_task_builder.cc \
single_op/task/aicpu_task_builder.cc \
single_op/task/aicpu_kernel_task_builder.cc \
hybrid/common/tensor_value.cc \
hybrid/common/npu_memory_allocator.cc \
hybrid/executor/rt_callback_manager.cc \
@@ -239,12 +243,15 @@ LIBGE_LOCAL_SRC_FILES := \
hybrid/executor/hybrid_model_executor.cc \
hybrid/executor/hybrid_model_async_executor.cc \
hybrid/executor/hybrid_execution_context.cc \
hybrid/executor/subgraph_context.cc \
hybrid/executor/subgraph_executor.cc \
hybrid/executor/worker/task_compile_engine.cc \
hybrid/executor/worker/shape_inference_engine.cc \
hybrid/executor/worker/execution_engine.cc \
hybrid/model/hybrid_model.cc \
hybrid/model/hybrid_model_builder.cc \
hybrid/model/node_item.cc \
hybrid/model/graph_item.cc \
hybrid/node_executor/aicore/aicore_node_executor.cc \
hybrid/node_executor/aicore/aicore_op_task.cc \
hybrid/node_executor/aicore/aicore_task_builder.cc \
@@ -253,6 +260,9 @@ LIBGE_LOCAL_SRC_FILES := \
hybrid/node_executor/aicpu/aicpu_node_executor.cc \
hybrid/node_executor/compiledsubgraph/known_node_executor.cc \
hybrid/node_executor/hostcpu/ge_local_node_executor.cc \
hybrid/node_executor/controlop/control_op_executor.cc \
hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc \
hybrid/node_executor/hccl/hccl_node_executor.cc \
hybrid/node_executor/node_executor.cc \
hybrid/node_executor/task_context.cc \
hybrid/hybrid_davinci_model.cc \
@@ -338,6 +348,28 @@ LOCAL_SHARED_LIBRARIES += \

include $(BUILD_HOST_SHARED_LIBRARY)

#compiler for GeRunner
include $(CLEAR_VARS)

LOCAL_MODULE := stub/libge_runner

LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DREUSE_MEMORY=1 -O2
LOCAL_CFLAGS += -DFMK_SUPPORT_DUMP -DDAVINCI_SUPPORT_PROFILING -DDAVINCI_CLOUD
ifeq ($(DEBUG), 1)
LOCAL_CFLAGS += -g -O0
endif


LOCAL_C_INCLUDES := $(RUNNER_LOCAL_C_INCLUDES)

LOCAL_SRC_FILES := ../../out/ge/lib64/stub/ge_api.cc


LOCAL_SHARED_LIBRARIES :=

LOCAL_LDFLAGS := -lrt -ldl

include $(BUILD_HOST_SHARED_LIBRARY)

# add engine_conf.json to host
include $(CLEAR_VARS)
@@ -407,6 +439,7 @@ LOCAL_CFLAGS += -DFMK_SUPPORT_DUMP -DDAVINCI_SUPPORT_PROFILING -DDAVINCI_CLOUD
LOCAL_CFLAGS += -g -O0

LOCAL_C_INCLUDES := $(RUNNER_LOCAL_C_INCLUDES)

LOCAL_SRC_FILES := $(LIBGE_LOCAL_SRC_FILES)
LOCAL_SRC_FILES += $(LIBCLIENT_LOCAL_SRC_FILES)



+ 31
- 0
src/ge/ge_runtime/model_runner.cc View File

@@ -49,6 +49,15 @@ bool ModelRunner::LoadDavinciModel(uint32_t device_id, uint64_t session_id, uint
return true;
}

bool ModelRunner::LoadModelComplete(uint32_t model_id) {
auto model_iter = runtime_models_.find(model_id);
if (model_iter == runtime_models_.end()) {
GELOGE(PARAM_INVALID, "Model id %u not found.", model_id);
return false;
}
return model_iter->second->LoadComplete();
}

const std::vector<uint32_t> &ModelRunner::GetTaskIdList(uint32_t model_id) const {
auto model_iter = runtime_models_.find(model_id);
if (model_iter == runtime_models_.end()) {
@@ -60,6 +69,28 @@ const std::vector<uint32_t> &ModelRunner::GetTaskIdList(uint32_t model_id) const
return model_iter->second->GetTaskIdList();
}

const std::vector<uint32_t> &ModelRunner::GetStreamIdList(uint32_t model_id) const {
auto model_iter = runtime_models_.find(model_id);
if (model_iter == runtime_models_.end()) {
GELOGE(PARAM_INVALID, "Model id %u not found.", model_id);
static const std::vector<uint32_t> empty_ret;
return empty_ret;
}

return model_iter->second->GetStreamIdList();
}

const std::map<std::string, std::shared_ptr<RuntimeInfo>> &ModelRunner::GetRuntimeInfoMap(uint32_t model_id) const {
auto model_iter = runtime_models_.find(model_id);
if (model_iter == runtime_models_.end()) {
GELOGW("Model id %u not found.", model_id);
static const std::map<std::string, std::shared_ptr<RuntimeInfo>> empty_ret;
return empty_ret;
}

return model_iter->second->GetRuntimeInfoMap();
}

bool ModelRunner::UnloadModel(uint32_t model_id) {
auto iter = runtime_models_.find(model_id);
if (iter != runtime_models_.end()) {


+ 1
- 1
src/ge/ge_runtime/output.cc View File

@@ -76,7 +76,7 @@ bool Output::CopyRslt(OutputData *rslt, uint32_t data_begin, uint32_t &data_inde
DataBuffer data_buf = rslt->blobs[data_begin + data_count];
bool ret = SetDataBuf(data_buf, data_begin, data_count, i, support_mem_share);
if (!ret) {
GELOGE(FAILED, "Copy data to host failed. index: %lu, addr: %p", i, v_input_data_addr_[i]);
GELOGE(FAILED, "Copy data to host error. index: %lu, addr: %p", i, v_input_data_addr_[i]);
return ret;
}
data_index = data_begin + data_count;


+ 57
- 24
src/ge/ge_runtime/runtime_model.cc View File

@@ -28,7 +28,6 @@

namespace ge {
namespace model_runner {

RuntimeModel::~RuntimeModel() {
GELOGI("RuntimeModel destructor start");

@@ -116,23 +115,34 @@ bool RuntimeModel::InitEvent(uint32_t event_num) {
return true;
}

bool RuntimeModel::InitLabel(uint32_t batch_num) {
GELOGI("batch number:%u.", batch_num);
for (uint32_t i = 0; (batch_num != 0 && i <= batch_num); ++i) {
rtLabel_t rt_lLabel = nullptr;
rtError_t rt_ret = rtLabelCreate(&rt_lLabel);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api rtLabelCreate failed, i; %u; ret: 0x%X", i, rt_ret);
return false;
bool RuntimeModel::InitLabel(std::shared_ptr<DavinciModel> &davinci_model) {
GELOGI("batch number:%u.", davinci_model->GetBatchNum());
label_list_.resize(davinci_model->GetBatchNum());
for (auto &task_info : davinci_model->GetTaskInfoList()) {
if (task_info == nullptr) {
GELOGE(PARAM_INVALID, "task_info is null.");
continue;
}

if (task_info->type() != TaskInfoType::LABEL_SET) {
continue;
}
auto label_set_task_info = std::static_pointer_cast<LabelSetTaskInfo>(task_info);

if (rt_lLabel == nullptr) {
GELOGE(RT_FAILED, "rtLabel is nullptr!");
if (label_set_task_info->stream_id() >= stream_list_.size()) {
GELOGE(PARAM_INVALID, "Invalid stream id.");
return false;
}

label_list_.emplace_back(rt_lLabel);
rtLabel_t rt_label = nullptr;
rtError_t rt_ret = rtLabelCreateEx(&rt_label, stream_list_[label_set_task_info->stream_id()]);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api rtLabelCreate failed, ret: 0x%X", rt_ret);
return false;
}
label_list_[label_set_task_info->label_id()] = rt_label;
}

return true;
}

@@ -164,7 +174,7 @@ bool RuntimeModel::InitResource(std::shared_ptr<DavinciModel> &davinci_model) {
return false;
}

if (!InitLabel(davinci_model->GetBatchNum())) {
if (!InitLabel(davinci_model)) {
return false;
}

@@ -209,20 +219,41 @@ bool RuntimeModel::LoadTask() {
return false;
}
task_id_list_.push_back(task_id);
stream_id_list_.push_back(stream_id);
if (task->Args() != nullptr) {
std::shared_ptr<RuntimeInfo> runtime_tuple = nullptr;
GE_MAKE_SHARED(runtime_tuple = std::make_shared<RuntimeInfo>(task_id, stream_id, task->Args()), return false);
auto emplace_ret = runtime_info_map_.emplace(task->task_name(), runtime_tuple);
if (!emplace_ret.second) {
GELOGW("Task name exist:%s", task->task_name().c_str());
}
}
}
if (task_list_.empty()) {
GELOGE(FAILED, "Task list is empty");
return false;
}
GELOGI("Distribute task succ.");

auto rt_ret = rtModelLoadComplete(rt_model_handle_);
GELOGI("LoadTask succ.");
return true;
}

bool RuntimeModel::LoadComplete() {
uint32_t task_id = 0;
uint32_t stream_id = 0;
auto rt_ret = rtModelGetTaskId(rt_model_handle_, &task_id, &stream_id);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rtModelGetTaskId failed, ret:0x%X", rt_ret);
return RT_FAILED;
}
task_id_list_.push_back(task_id);
stream_id_list_.push_back(stream_id);

rt_ret = rtModelLoadComplete(rt_model_handle_);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api rtModelLoadComplete failed, ret: 0x%X.", rt_ret);
return false;
}

GELOGI("LoadTask succ.");
return true;
}

@@ -270,10 +301,14 @@ bool RuntimeModel::Run() {
return false;
}

GELOGI("Run rtModelExecute success");
GELOGI("Run rtModelExecute success, ret = 0x%X", ret);

ret = rtStreamSynchronize(rt_model_stream_);
if (ret != RT_ERROR_NONE) {
if (ret == RT_ERROR_END_OF_SEQUENCE) {
GELOGI("Model stream RT_ERROR_END_OF_SEQUENCE signal received, ret = 0x%X", ret);
return true;
}
GELOGE(RT_FAILED, "Model stream sync failed, ret = 0x%X", ret);
return false;
}
@@ -433,7 +468,7 @@ bool RuntimeModel::InitConstantInfo(std::shared_ptr<DavinciModel> &davinci_model
}

if (constant->output_tensors[0].size < constant->weight_data.size()) {
GELOGE(PARAM_INVALID, "Output size:%u is less than weight data size:%zu", constant->output_tensors[0].size,
GELOGE(PARAM_INVALID, "Output size:%u less than weight data size:%zu", constant->output_tensors[0].size,
constant->weight_data.size());
return false;
}
@@ -448,11 +483,8 @@ bool RuntimeModel::InitConstantInfo(std::shared_ptr<DavinciModel> &davinci_model
/// The logic of GetShapeSize is wrong, the scaler tensor's GetShapeSize is zero
/// and that of unknown shape is zero too.
/// Unknown shape will not appear here, so we can use zero judge a tensor is scaler or not.
int64_t elem_num = constant->weight_tensors[0].GetShapeSize();
if (elem_num == 0 && constant->weight_tensors[0].size == 0) {
elem_num = 1;
}

int64_t elem_num =
(constant->weight_tensors[0].GetShapeSize() == 0) ? 1 : constant->weight_tensors[0].GetShapeSize();
if (constant->weight_data.size() < sizeof(uint64_t)) {
GELOGE(FAILED, "weight_data size is smaller than sizeof(uint64_t)");
return false;
@@ -495,5 +527,6 @@ void RuntimeModel::CreateOutput(uint32_t index, const OpInfo &op_info, InputOutp

const std::vector<uint32_t> &RuntimeModel::GetTaskIdList() const { return task_id_list_; }

const std::vector<uint32_t> &RuntimeModel::GetStreamIdList() const { return stream_id_list_; }
} // namespace model_runner
} // namespace ge

+ 7
- 2
src/ge/ge_runtime/runtime_model.h View File

@@ -27,7 +27,7 @@

namespace ge {
namespace model_runner {
using RuntimeInfo = std::tuple<uint32_t, uint32_t, void *>;
class Task;
class RuntimeModel {
public:
@@ -35,7 +35,10 @@ class RuntimeModel {
~RuntimeModel();

bool Load(uint32_t device_id, uint64_t session_id, std::shared_ptr<DavinciModel> &davinci_model);
bool LoadComplete();
const std::vector<uint32_t> &GetTaskIdList() const;
const std::vector<uint32_t> &GetStreamIdList() const;
const std::map<std::string, std::shared_ptr<RuntimeInfo>> &GetRuntimeInfoMap() const { return runtime_info_map_; }
bool Run();
bool CopyInputData(const InputData &input_data);
bool GetInputOutputDescInfo(bool zero_copy, std::vector<InputOutputDescInfo> *input_desc,
@@ -48,7 +51,7 @@ class RuntimeModel {
bool LoadTask();
bool InitStream(std::shared_ptr<DavinciModel> &davinci_model);
bool InitEvent(uint32_t event_num);
bool InitLabel(uint32_t batch_num);
bool InitLabel(std::shared_ptr<DavinciModel> &davinci_model);
bool InitDataInfo(std::shared_ptr<DavinciModel> &davinci_model);
bool InitOutputInfo(std::shared_ptr<DavinciModel> &davinci_model);
bool InitConstantInfo(std::shared_ptr<DavinciModel> &davinci_model);
@@ -77,6 +80,8 @@ class RuntimeModel {
std::vector<std::shared_ptr<OpInfo>> constant_info_list_{};

std::vector<uint32_t> task_id_list_{};
std::vector<uint32_t> stream_id_list_{};
std::map<std::string, std::shared_ptr<RuntimeInfo>> runtime_info_map_;
};

} // namespace model_runner


+ 9
- 5
src/ge/ge_runtime/task/aicpu_task.cc View File

@@ -85,11 +85,15 @@ bool AicpuTask::Distribute() {
return false;
}

GELOGI("Distribute AicpuTask start, args_size = %u, io_addrs_num = %u, so_name = %s, kernel_name = %s.", args_size,
io_addrs_num, task_info_->so_name().data(), task_info_->kernel_name().data());
rt_ret = rtCpuKernelLaunch(reinterpret_cast<const void *>(task_info_->so_name().data()),
reinterpret_cast<const void *>(task_info_->kernel_name().data()), 1, args_, args_size,
nullptr, stream_);
input_output_addr_ = reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(args_) + io_addr_offset);

auto dump_flag = task_info_->dump_flag() ? RT_KERNEL_DUMPFLAG : RT_KERNEL_DEFAULT;
GELOGI(
"Distribute AicpuTask start, args_size = %u, io_addrs_num = %u, so_name = %s, kernel_name = %s, dump_flag = %d.",
args_size, io_addrs_num, task_info_->so_name().data(), task_info_->kernel_name().data(), dump_flag);
rt_ret = rtCpuKernelLaunchWithFlag(reinterpret_cast<const void *>(task_info_->so_name().data()),
reinterpret_cast<const void *>(task_info_->kernel_name().data()), 1, args_,
args_size, nullptr, stream_, dump_flag);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return false;


+ 6
- 0
src/ge/ge_runtime/task/aicpu_task.h View File

@@ -18,6 +18,7 @@
#define GE_GE_RUNTIME_TASK_AICPU_TASK_H_

#include <memory>
#include <string>
#include "ge_runtime/task/task.h"

namespace ge {
@@ -30,12 +31,17 @@ class AicpuTask : public TaskRepeater<AicpuTaskInfo> {

bool Distribute() override;

void *Args() override { return input_output_addr_; }

std::string task_name() const override { return task_info_->op_name(); }

private:
static void ReleaseRtMem(void **ptr) noexcept;

std::shared_ptr<AicpuTaskInfo> task_info_;
void *stream_;
void *args_;
void *input_output_addr_;
};
} // namespace model_runner
} // namespace ge


+ 0
- 1
src/ge/ge_runtime/task/hccl_task.cc View File

@@ -115,7 +115,6 @@ bool HcclTask::Distribute() {
rt_ret = rtModelBindStream(rt_model_handle_, stream, RT_HEAD_STREAM);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
(void)rtStreamDestroy(stream);
return false;
}



+ 70
- 0
src/ge/ge_runtime/task/label_goto_task.cc View File

@@ -0,0 +1,70 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ge_runtime/task/label_goto_task.h"
#include "ge_runtime/task/task_factory.h"

namespace ge {
namespace model_runner {
LabelGotoTask::LabelGotoTask(const ModelContext &model_context, const std::shared_ptr<LabelGotoTaskInfo> &task_info)
: TaskRepeater<LabelGotoTaskInfo>(model_context, task_info),
task_info_(task_info),
stream_(nullptr),
label_(nullptr) {
if (task_info_ == nullptr) {
GELOGW("task_info_ is null!");
return;
}
auto stream_list = model_context.stream_list();
auto label_list = model_context.label_list();
uint32_t stream_id = task_info->stream_id();
uint32_t label_id = task_info->label_id();
GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id);
GELOGI("Label list size:%zu, label id:%u.", label_list.size(), label_id);
if (stream_id >= stream_list.size() || label_id >= label_list.size()) {
GELOGW("Stream/Label id invalid.");
return;
}
stream_ = stream_list[stream_id];
label_ = label_list[label_id];
}

LabelGotoTask::~LabelGotoTask() {}

bool LabelGotoTask::Distribute() {
GELOGI("LabelGotoTask Distribute start.");
if (stream_ == nullptr) {
GELOGE(PARAM_INVALID, "stream is null!");
return false;
}
if (label_ == nullptr) {
GELOGE(PARAM_INVALID, "label is null!");
return false;
}
rtError_t rt_ret = rtLabelGotoEx(label_, stream_);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return false;
}

GELOGI("DistributeTask end.");
return true;
}

REGISTER_TASK(TaskInfoType::LABEL_GOTO, LabelGotoTask, LabelGotoTaskInfo);

} // namespace model_runner
} // namespace ge

+ 41
- 0
src/ge/ge_runtime/task/label_goto_task.h View File

@@ -0,0 +1,41 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_
#define GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_

#include <memory>
#include "ge_runtime/task/task.h"

namespace ge {
namespace model_runner {
class LabelGotoTask : public TaskRepeater<LabelGotoTaskInfo> {
public:
LabelGotoTask(const ModelContext &model_context, const std::shared_ptr<LabelGotoTaskInfo> &task_info);

~LabelGotoTask() override;

bool Distribute() override;

private:
std::shared_ptr<LabelGotoTaskInfo> task_info_;
void *stream_;
void *label_;
};
} // namespace model_runner
} // namespace ge

#endif // GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_

+ 70
- 0
src/ge/ge_runtime/task/label_set_task.cc View File

@@ -0,0 +1,70 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ge_runtime/task/label_set_task.h"
#include "ge_runtime/task/task_factory.h"

namespace ge {
namespace model_runner {
LabelSetTask::LabelSetTask(const ModelContext &model_context, const std::shared_ptr<LabelSetTaskInfo> &task_info)
: TaskRepeater<LabelSetTaskInfo>(model_context, task_info),
task_info_(task_info),
stream_(nullptr),
label_(nullptr) {
if (task_info_ == nullptr) {
GELOGW("task_info_ is null!");
return;
}
auto stream_list = model_context.stream_list();
auto label_list = model_context.label_list();
uint32_t stream_id = task_info->stream_id();
uint32_t label_id = task_info->label_id();
GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id);
GELOGI("Label list size:%zu, label id:%u.", label_list.size(), label_id);
if (stream_id >= stream_list.size() || label_id >= label_list.size()) {
GELOGW("Stream/Label id invalid.");
return;
}
stream_ = stream_list[stream_id];
label_ = label_list[label_id];
}

LabelSetTask::~LabelSetTask() {}

bool LabelSetTask::Distribute() {
GELOGI("LabelSetTask Distribute start.");
if (stream_ == nullptr) {
GELOGE(PARAM_INVALID, "stream is null!");
return false;
}
if (label_ == nullptr) {
GELOGE(PARAM_INVALID, "label is null!");
return false;
}
rtError_t rt_ret = rtLabelSet(label_, stream_);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return false;
}

GELOGI("DistributeTask end.");
return true;
}

REGISTER_TASK(TaskInfoType::LABEL_SET, LabelSetTask, LabelSetTaskInfo);

} // namespace model_runner
} // namespace ge

+ 41
- 0
src/ge/ge_runtime/task/label_set_task.h View File

@@ -0,0 +1,41 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_
#define GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_

#include <memory>
#include "ge_runtime/task/task.h"

namespace ge {
namespace model_runner {
class LabelSetTask : public TaskRepeater<LabelSetTaskInfo> {
public:
LabelSetTask(const ModelContext &model_context, const std::shared_ptr<LabelSetTaskInfo> &task_info);

~LabelSetTask() override;

bool Distribute() override;

private:
std::shared_ptr<LabelSetTaskInfo> task_info_;
void *stream_;
void *label_;
};
} // namespace model_runner
} // namespace ge

#endif // GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_

+ 131
- 0
src/ge/ge_runtime/task/label_switch_task.cc View File

@@ -0,0 +1,131 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ge_runtime/task/label_switch_task.h"
#include "ge_runtime/task/task_factory.h"

namespace ge {
namespace model_runner {
LabelSwitchTask::LabelSwitchTask(const ModelContext &model_context,
const std::shared_ptr<LabelSwitchTaskInfo> &task_info)
: TaskRepeater<LabelSwitchTaskInfo>(model_context, task_info),
task_info_(task_info),
stream_(nullptr),
all_label_resource_(),
label_info_(nullptr) {
if (task_info_ == nullptr) {
GELOGW("task_info_ is null!");
return;
}

all_label_resource_ = model_context.label_list();
auto stream_list = model_context.stream_list();
uint32_t stream_id = task_info->stream_id();
GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id);
if (stream_id >= stream_list.size()) {
GELOGW("Stream id invalid.");
return;
}
stream_ = stream_list[stream_id];
}

LabelSwitchTask::~LabelSwitchTask() {
if (label_info_ != nullptr) {
rtError_t rt_ret = rtFree(label_info_);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "rtFree fwkOpBuf failed! ret: 0x%X.", rt_ret);
}
label_info_ = nullptr;
}
}

bool LabelSwitchTask::Distribute() {
GELOGI("LabelSwitchTask Distribute start.");
if (!CheckParamValid()) {
return false;
}

const std::vector<uint32_t> &label_index_list = task_info_->label_list();
std::vector<void *> label_list(task_info_->label_size(), nullptr);

for (size_t i = 0; i < task_info_->label_size(); ++i) {
uint32_t label_index = label_index_list[i];
if (label_index >= all_label_resource_.size()) {
GELOGE(PARAM_INVALID, "label %zu index is %u, but there are %zu labels in total.", i, label_index,
all_label_resource_.size());
return false;
}
label_list[i] = all_label_resource_[label_index];
GELOGI("Case %zu: label id %zu.", i, label_index);
}

uint32_t label_info_size = sizeof(rtLabelDevInfo) * task_info_->label_size();
rtError_t rt_ret = rtMalloc(&label_info_, label_info_size, RT_MEMORY_HBM);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return false;
}

rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info_, label_info_size);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return false;
}

rt_ret = rtLabelSwitchByIndex(task_info_->cond(), label_list.size(), label_info_, stream_);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return false;
}

GELOGI("DistributeTask end.");
return true;
}

bool LabelSwitchTask::CheckParamValid() {
if (stream_ == nullptr) {
GELOGE(PARAM_INVALID, "stream is null!");
return false;
}

if (task_info_->label_list().empty()) {
GELOGE(PARAM_INVALID, "label_list is empty.");
return false;
}

if (task_info_->label_size() != task_info_->label_list().size()) {
GELOGE(PARAM_INVALID, "label_list size %zu but label_size is %u.", task_info_->label_list().size(),
task_info_->label_size());
return false;
}

if (task_info_->label_size() >= UINT32_MAX / sizeof(rtLabelDevInfo)) {
GELOGE(PARAM_INVALID, "label_size %u will overflow.", task_info_->label_size());
return false;
}

if (label_info_ != nullptr) {
GELOGE(PARAM_INVALID, "label_info_ has dirty data.");
return false;
}

return true;
}

REGISTER_TASK(TaskInfoType::LABEL_SWITCH, LabelSwitchTask, LabelSwitchTaskInfo);

} // namespace model_runner
} // namespace ge

+ 44
- 0
src/ge/ge_runtime/task/label_switch_task.h View File

@@ -0,0 +1,44 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_
#define GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_

#include <memory>
#include "ge_runtime/task/task.h"

namespace ge {
namespace model_runner {
class LabelSwitchTask : public TaskRepeater<LabelSwitchTaskInfo> {
public:
LabelSwitchTask(const ModelContext &model_context, const std::shared_ptr<LabelSwitchTaskInfo> &task_info);

~LabelSwitchTask() override;

bool Distribute() override;

private:
bool CheckParamValid();

std::shared_ptr<LabelSwitchTaskInfo> task_info_;
void *stream_;
std::vector<void *> all_label_resource_;
void *label_info_;
};
} // namespace model_runner
} // namespace ge

#endif // GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_

+ 1
- 1
src/ge/ge_runtime/task/stream_switch_task.cc View File

@@ -51,7 +51,7 @@ bool StreamSwitchTask::Distribute() {
}

if (static_cast<uint64_t>(task_info_->true_stream_id()) >= stream_list_.size()) {
GELOGE(PARAM_INVALID, "true_stream_id %ld must be less than stream_list_ size %zu!", task_info_->true_stream_id(),
GELOGE(PARAM_INVALID, "true_stream_id %ld must less than stream_list_ size %zu!", task_info_->true_stream_id(),
stream_list_.size());
return false;
}


+ 6
- 0
src/ge/ge_runtime/task/task.h View File

@@ -18,7 +18,9 @@
#define GE_GE_RUNTIME_TASK_TASK_H_

#include <memory>
#include <utility>
#include <vector>
#include <string>
#include "runtime/rt_model.h"
#include "ge_runtime/model_context.h"
#include "ge_runtime/task_info.h"
@@ -32,6 +34,10 @@ class Task {
virtual ~Task() {}

virtual bool Distribute() = 0;

virtual void *Args() { return nullptr; }

virtual std::string task_name() const { return ""; }
};

template <class T>


+ 3
- 4
src/ge/ge_runtime/task/tbe_task.cc View File

@@ -95,15 +95,14 @@ bool TbeTask::Distribute() {
return false;
}

GELOGI("InitTbeTask end.");
GELOGI("DistributeTbeTask start.");
rt_ret = rtKernelLaunch(stub_func_, task_info_->block_dim(), args_, args_size, nullptr, stream_);
auto dump_flag = task_info_->dump_flag() ? RT_KERNEL_DUMPFLAG : RT_KERNEL_DEFAULT;
rt_ret = rtKernelLaunchWithFlag(stub_func_, task_info_->block_dim(), args_, args_size, nullptr, stream_, dump_flag);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api rtKernelLaunch failed, ret: 0x%X", rt_ret);
return false;
}

GELOGI("DistributeTbeTask end.");
GELOGI("[DataDump] task name:%s, dump_flag:%d", task_info_->op_name().c_str(), dump_flag);
return true;
}



+ 4
- 0
src/ge/ge_runtime/task/tbe_task.h View File

@@ -30,6 +30,10 @@ class TbeTask : public TaskRepeater<TbeTaskInfo> {

bool Distribute() override;

void *Args() override { return args_; }

std::string task_name() const override { return task_info_->op_name(); }

private:
std::shared_ptr<TbeTaskInfo> task_info_;
void *stream_;


Some files were not shown because too many files changed in this diff

Loading…
Cancel
Save