Browse Source

!53 code sync for C75B100 master

Merge pull request !53 from HW_KK/master
tags/v0.7.0-beta
mindspore-ci-bot Gitee 4 years ago
parent
commit
622af6c1c5
100 changed files with 2808 additions and 1453 deletions
  1. +2
    -0
      build.sh
  2. +18
    -0
      inc/common/opskernel/ge_task_info.h
  3. +5
    -4
      inc/common/opskernel/ops_kernel_info_store.h
  4. +5
    -1
      inc/common/opskernel/ops_kernel_info_types.h
  5. +2
    -0
      inc/common/optimizer/graph_optimizer.h
  6. +1
    -0
      inc/common/util/compress/compress.h
  7. +33
    -0
      inc/common/util/compress/compress_weight.h
  8. +18
    -7
      inc/common/util/error_manager/error_manager.h
  9. +4
    -2
      inc/common/util/platform_info.h
  10. +14
    -0
      inc/common/util/platform_info_def.h
  11. +1
    -1
      inc/external/ge/ge_api_error_codes.h
  12. +16
    -4
      inc/external/ge/ge_api_types.h
  13. +2
    -0
      inc/external/graph/attr_value.h
  14. +13
    -10
      inc/external/graph/operator.h
  15. +1
    -0
      inc/external/graph/operator_reg.h
  16. +1
    -0
      inc/external/graph/tensor.h
  17. +2
    -1
      inc/external/graph/types.h
  18. +4
    -0
      inc/external/register/register.h
  19. +0
    -24
      inc/framework/common/debug/ge_log.h
  20. +20
    -37
      inc/framework/common/debug/log.h
  21. +21
    -2
      inc/framework/common/ge_inner_error_codes.h
  22. +5
    -3
      inc/framework/common/ge_types.h
  23. +0
    -2
      inc/framework/common/helper/model_helper.h
  24. +3
    -3
      inc/framework/common/string_util.h
  25. +20
    -0
      inc/framework/common/types.h
  26. +56
    -10
      inc/framework/executor/ge_executor.h
  27. +11
    -1
      inc/framework/ge_runtime/model_runner.h
  28. +80
    -55
      inc/framework/ge_runtime/task_info.h
  29. +1
    -0
      inc/framework/generator/ge_generator.h
  30. +56
    -0
      inc/framework/memory/memory_api.h
  31. +0
    -5
      inc/framework/omg/omg.h
  32. +1
    -0
      inc/framework/omg/omg_inner_types.h
  33. +3
    -3
      inc/graph/buffer.h
  34. +19
    -2
      inc/graph/compute_graph.h
  35. +49
    -1
      inc/graph/debug/ge_attr_define.h
  36. +3
    -1
      inc/graph/detail/any_map.h
  37. +15
    -3
      inc/graph/detail/attributes_holder.h
  38. +3
    -0
      inc/graph/detail/model_serialize_imp.h
  39. +3
    -3
      inc/graph/ge_attr_value.h
  40. +1
    -0
      inc/graph/ge_context.h
  41. +5
    -2
      inc/graph/ge_tensor.h
  42. +0
    -1
      inc/graph/model_serialize.h
  43. +1
    -1
      inc/graph/node.h
  44. +9
    -4
      inc/graph/op_desc.h
  45. +1
    -0
      inc/graph/shape_refiner.h
  46. +40
    -4
      inc/graph/utils/graph_utils.h
  47. +28
    -0
      inc/graph/utils/node_utils.h
  48. +1
    -0
      inc/graph/utils/tensor_adapter.h
  49. +1
    -0
      inc/graph/utils/tensor_utils.h
  50. +1
    -0
      src/common/graph/CMakeLists.txt
  51. +109
    -25
      src/common/graph/compute_graph.cc
  52. +4
    -0
      src/common/graph/debug/ge_op_types.h
  53. +69
    -21
      src/common/graph/format_refiner.cc
  54. +4
    -4
      src/common/graph/format_refiner.h
  55. +49
    -1
      src/common/graph/ge_attr_define.cc
  56. +26
    -29
      src/common/graph/ge_attr_value.cc
  57. +11
    -0
      src/common/graph/ge_tensor.cc
  58. +1
    -1
      src/common/graph/graph.cc
  59. +68
    -8
      src/common/graph/graph.mk
  60. +97
    -27
      src/common/graph/model_serialize.cc
  61. +38
    -19
      src/common/graph/node.cc
  62. +89
    -112
      src/common/graph/op_desc.cc
  63. +46
    -48
      src/common/graph/operator.cc
  64. +2
    -0
      src/common/graph/opsproto/opsproto_manager.cc
  65. +2
    -0
      src/common/graph/option/ge_context.cc
  66. +47
    -4
      src/common/graph/ref_relation.cc
  67. +1
    -0
      src/common/graph/runtime_inference_context.cc
  68. +257
    -36
      src/common/graph/shape_refiner.cc
  69. +0
    -6
      src/common/graph/stub/Makefile
  70. +0
    -573
      src/common/graph/stub/gen_stubapi.py
  71. +14
    -10
      src/common/graph/tensor.cc
  72. +202
    -24
      src/common/graph/utils/graph_utils.cc
  73. +217
    -18
      src/common/graph/utils/node_utils.cc
  74. +52
    -24
      src/common/graph/utils/op_desc_utils.cc
  75. +6
    -2
      src/common/graph/utils/tensor_utils.cc
  76. +2
    -1
      src/common/graph/utils/type_utils.cc
  77. +26
    -4
      src/ge/CMakeLists.txt
  78. +14
    -43
      src/ge/client/ge_api.cc
  79. +6
    -5
      src/ge/common/convert/pb2json.cc
  80. +0
    -1
      src/ge/common/formats/format_transfers/datatype_transfer.cc
  81. +0
    -1
      src/ge/common/formats/format_transfers/datatype_transfer.h
  82. +0
    -1
      src/ge/common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc
  83. +0
    -1
      src/ge/common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc
  84. +1
    -1
      src/ge/common/formats/format_transfers/format_transfer_fractal_nz.cc
  85. +55
    -40
      src/ge/common/formats/format_transfers/format_transfer_fractal_z.cc
  86. +1
    -3
      src/ge/common/formats/format_transfers/format_transfer_nchw_fz_c04.cc
  87. +0
    -2
      src/ge/common/formats/utils/formats_definitions.h
  88. +0
    -2
      src/ge/common/formats/utils/formats_trans_utils.h
  89. +1
    -1
      src/ge/common/fp16_t.h
  90. +81
    -0
      src/ge/common/ge/op_tiling_manager.cc
  91. +38
    -0
      src/ge/common/ge/op_tiling_manager.h
  92. +5
    -2
      src/ge/common/ge/tbe_plugin_manager.cc
  93. +1
    -1
      src/ge/common/ge/tbe_plugin_manager.h
  94. +1
    -0
      src/ge/common/ge_common.mk
  95. +31
    -125
      src/ge/common/helper/model_helper.cc
  96. +17
    -14
      src/ge/common/helper/om_file_helper.cc
  97. +1
    -1
      src/ge/common/math/fp16_math.h
  98. +0
    -2
      src/ge/common/math_util.h
  99. +16
    -13
      src/ge/common/model_parser/base.cc
  100. +501
    -0
      src/ge/common/model_parser/graph_parser_util.cc

+ 2
- 0
build.sh View File

@@ -174,9 +174,11 @@ echo "---------------- GraphEngine output generated ----------------"
# generate output package in tar form, including ut/st libraries/executables
cd ${BASEPATH}
mkdir -p output/plugin/nnengine/ge_config/
mkdir -p output/plugin/opskernel/
find output/ -name graphengine_lib.tar -exec rm {} \;
cp src/ge/engine_manager/engine_conf.json output/plugin/nnengine/ge_config/
find output/ -maxdepth 1 -name libengine.so -exec mv -f {} output/plugin/nnengine/ \;
find output/ -maxdepth 1 -name libge_local_engine.so -exec mv -f {} output/plugin/opskernel/ \;
tar -cf graphengine_lib.tar output/*
mv -f graphengine_lib.tar output
echo "---------------- GraphEngine package archive generated ----------------"

+ 18
- 0
inc/common/opskernel/ge_task_info.h View File

@@ -52,5 +52,23 @@ struct GETaskInfo {

std::vector<GETaskKernelHcclInfo> kernelHcclInfo;
};

struct HcomOpertion {
std::string hcclType;
void *inputPtr;
void *outputPtr;
uint64_t count;
int32_t dataType;
int32_t opType;
int32_t root;
};

struct HcomRemoteAccessAddrInfo {
uint32_t remotetRankID;
uint64_t remoteAddr; // host embedding table address
uint64_t localAddr; // device HBM address
uint64_t length; // memory Length in Bytes
};

} // namespace ge
#endif // INC_COMMON_OPSKERNEL_GE_TASK_INFO_H_

+ 5
- 4
inc/common/opskernel/ops_kernel_info_store.h View File

@@ -43,10 +43,10 @@ class OpsKernelInfoStore {
virtual ~OpsKernelInfoStore() {}

// initialize opsKernelInfoStore
virtual Status Initialize(const map<string, string> &options) = 0;
virtual Status Initialize(const map<string, string> &options) = 0; /*lint -e148*/

// close opsKernelInfoStore
virtual Status Finalize() = 0;
virtual Status Finalize() = 0; /*lint -e148*/

virtual Status CreateSession(const std::map<std::string, std::string> &session_options) { return SUCCESS; }

@@ -66,10 +66,11 @@ class OpsKernelInfoStore {
virtual void opsFlagCheck(const ge::Node &node, std::string &opsFlag){};

// memory allocation requirement
virtual Status CalcOpRunningParam(Node &node) = 0;
virtual Status CalcOpRunningParam(Node &node) = 0; /*lint -e148*/

// generate task for op。
virtual Status GenerateTask(const Node &node, RunContext &context, std::vector<domi::TaskDef> &tasks) = 0;
virtual Status GenerateTask(const Node &node, RunContext &context,
std::vector<domi::TaskDef> &tasks) = 0; /*lint -e148*/

// only call fe engine interface to compile single op
virtual Status CompileOp(vector<ge::NodePtr> &node_vec) { return SUCCESS; }


+ 5
- 1
inc/common/opskernel/ops_kernel_info_types.h View File

@@ -26,6 +26,7 @@
using std::string;

namespace ge {
/*lint -e148*/
struct RunContext {
rtModel_t model;
rtStream_t stream;
@@ -40,6 +41,8 @@ struct RunContext {
std::vector<rtLabel_t> graphLabelList; // all labels of graph, order by ge label id(0,1,...)
};

/*lint +e148*/

struct Task {
uint32_t id;
uint16_t type;
@@ -48,7 +51,8 @@ struct Task {
};

struct OpInfo {
string engine; // which engin
string engine; // which engin
/*lint -e148*/
string opKernelLib; // which opsKernelStore
int computeCost; // compute cost
bool flagPartial; // whether to support is related to shape


+ 2
- 0
inc/common/optimizer/graph_optimizer.h View File

@@ -27,6 +27,7 @@
using std::map;
using std::string;

/*lint -e148*/
namespace ge {
class GraphOptimizer {
public:
@@ -60,4 +61,5 @@ class GraphOptimizer {
virtual Status OptimizeStreamGraph(ComputeGraph &graph, const RunContext &context) { return SUCCESS; }
};
} // namespace ge
/*lint +e148*/
#endif // INC_COMMON_OPTIMIZER_GRAPH_OPTIMIZER_H_

+ 1
- 0
inc/common/util/compress/compress.h View File

@@ -28,6 +28,7 @@ struct CompressConfig {
size_t channel; // channels of L2 or DDR. For load balance
size_t fractalSize; // size of compressing block
bool isTight; // whether compose compressed data tightly
size_t init_offset;
};

CmpStatus CompressWeights(char* input, const CompressConfig& compressConfig, char* indexs, char* output,


+ 33
- 0
inc/common/util/compress/compress_weight.h View File

@@ -0,0 +1,33 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef COMPRESS_WEIGHT_H
#define COMPRESS_WEIGHT_H

#include "compress.h"

const int SHAPE_SIZE_WEIGHT = 4;

struct CompressOpConfig {
int64_t wShape[SHAPE_SIZE_WEIGHT];
size_t compressTilingK;
size_t compressTilingN;
struct CompressConfig compressConfig;
};

extern "C" CmpStatus CompressWeightsConv2D(const char *const input, char *const zipBuffer, char *const infoBuffer,
CompressOpConfig *const param);
#endif // COMPRESS_WEIGHT_H

+ 18
- 7
inc/common/util/error_manager/error_manager.h View File

@@ -31,27 +31,37 @@ class ErrorManager {

///
/// @brief init
/// @param [in] path current so path
/// @param [in] path: current so path
/// @return int 0(success) -1(fail)
///
int Init(std::string path);

///
/// @brief Report error message
/// @param [in] errCode error code
/// @param [in] mapArgs parameter map
/// @param [in] error_code: error code
/// @param [in] args_map: parameter map
/// @return int 0(success) -1(fail)
///
int ReportErrMessage(std::string error_code, const std::map<std::string, std::string> &args_map);

///
/// @brief output error message
/// @param [in] handle print handle
/// @param [in] handle: print handle
/// @return int 0(success) -1(fail)
///
int OutputErrMessage(int handle);

///
/// @brief output message
/// @param [in] handle: print handle
/// @return int 0(success) -1(fail)
///
int OutputMessage(int handle);

///
/// @brief Report error message
/// @param [in] vector parameter key, vector parameter value
/// @param [in] key: vector parameter key
/// @param [in] value: vector parameter value
///
void ATCReportErrMessage(std::string error_code, const std::vector<std::string> &key = {},
const std::vector<std::string> &value = {});
@@ -60,7 +70,7 @@ class ErrorManager {
struct ErrorInfo {
std::string error_id;
std::string error_message;
std::vector<std::string> arglist;
std::vector<std::string> arg_list;
};

ErrorManager() {}
@@ -77,7 +87,8 @@ class ErrorManager {

bool is_init_ = false;
std::map<std::string, ErrorInfo> error_map_;
std::vector<std::string> error_message_evc_;
std::vector<std::string> error_messages_;
std::vector<std::string> warning_messages_;
};

#endif // ERROR_MANAGER_H_

+ 4
- 2
inc/common/util/platform_info.h View File

@@ -27,7 +27,6 @@ using std::string;
using std::vector;

namespace fe {

class PlatformInfoManager {
public:
PlatformInfoManager(const PlatformInfoManager &) = delete;
@@ -39,6 +38,8 @@ class PlatformInfoManager {

uint32_t GetPlatformInfo(const string SoCVersion, PlatformInfo &platformInfo, OptionalInfo &optiCompilationInfo);

uint32_t GetPlatformInfoWithOutSocVersion(PlatformInfo &platformInfo, OptionalInfo &optiCompilationInfo);

void SetOptionalCompilationInfo(OptionalInfo &optiCompilationInfo);

private:
@@ -81,6 +82,8 @@ class PlatformInfoManager {

void ParseVectorCoreMemoryRates(map<string, string> &vectorCoreMemoryRatesMap, PlatformInfo &platformInfoTemp);

void ParseCPUCache(map<string, string> &CPUCacheMap, PlatformInfo &platformInfoTemp);

void ParseVectorCoreintrinsicDtypeMap(map<string, string> &vectorCoreintrinsicDtypeMap,
PlatformInfo &platformInfoTemp);

@@ -94,6 +97,5 @@ class PlatformInfoManager {
map<string, PlatformInfo> platformInfoMap_;
OptionalInfo optiCompilationInfo_;
};

} // namespace fe
#endif

+ 14
- 0
inc/common/util/platform_info_def.h View File

@@ -73,6 +73,8 @@ typedef struct tagAiCoreSpec {

typedef struct tagAiCoreMemoryRates {
double ddrRate;
double ddrReadRate;
double ddrWriteRate;
double l2Rate;
double l2ReadRate;
double l2WriteRate;
@@ -86,6 +88,7 @@ typedef struct tagAiCoreMemoryRates {
} AiCoreMemoryRates;

typedef struct tagVectorCoreSpec {
double vecFreq;
uint64_t vecCalcSize;
uint64_t smaskBuffer;
uint64_t ubSize;
@@ -94,10 +97,15 @@ typedef struct tagVectorCoreSpec {
uint64_t ubbankNum;
uint64_t ubburstInOneBlock;
uint64_t ubbankGroupNum;
uint64_t vectorRegSize;
uint64_t predicateRegSize;
uint64_t addressRegSize;
} VectorCoreSpec;

typedef struct tagVectorCoreMemoryRates {
double ddrRate;
double ddrReadRate;
double ddrWriteRate;
double l2Rate;
double l2ReadRate;
double l2WriteRate;
@@ -105,6 +113,11 @@ typedef struct tagVectorCoreMemoryRates {
double ubToDdrRate;
} VectorCoreMemoryRates;

typedef struct tagCPUCache {
uint32_t AICPUSyncBySW;
uint32_t TSCPUSyncBySW;
} CPUCache;

typedef struct tagPlatformInfo {
StrInfo strInfo;
SoCInfo socInfo;
@@ -113,6 +126,7 @@ typedef struct tagPlatformInfo {
map<string, vector<string>> aiCoreIntrinsicDtypeMap;
VectorCoreSpec vectorCoreSpec;
VectorCoreMemoryRates vectorCoreMemoryRates;
CPUCache cpucache;
map<string, vector<string>> vectorCoreIntrinsicDtypeMap;
} PlatformInfo;



+ 1
- 1
inc/external/ge/ge_api_error_codes.h View File

@@ -70,7 +70,7 @@ using Status = uint32_t;

// General error code
GE_ERRORNO(0, 0, 0, 0, 0, SUCCESS, 0, "success");
GE_ERRORNO(0b11, 0b11, 0b111, 0xFF, 0b11111, FAILED, 0xFFF, "failed");
GE_ERRORNO(0b11, 0b11, 0b111, 0xFF, 0b11111, FAILED, 0xFFF, "failed"); /*lint !e401*/
} // namespace ge

#endif // INC_EXTERNAL_GE_GE_API_ERROR_CODES_H_

+ 16
- 4
inc/external/ge/ge_api_types.h View File

@@ -44,8 +44,11 @@ const char *const OPTION_EXEC_ENABLE_DUMP = "ge.exec.enableDump";
const char *const OPTION_EXEC_DUMP_PATH = "ge.exec.dumpPath";
const char *const OPTION_EXEC_DUMP_STEP = "ge.exec.dumpStep";
const char *const OPTION_EXEC_DUMP_MODE = "ge.exec.dumpMode";
const char *const OPTION_EXEC_ENABLE_DUMP_DEBUG = "ge.exec.enableDumpDebug";
const char *const OPTION_EXEC_DUMP_DEBUG_MODE = "ge.exec.dumpDebugMode";
const char *const OPTION_EXEC_ENABLE_INCRE_BUILD = "ge.exec.enableIncreBuild";
const char *const OPTION_EXEC_INCRE_BUILD_CACHE_PATH = "ge.exec.increBuildCachePath";
const char *const OPTION_EXEC_ENABLE_SCOPE_FUSION_PASSES = "ge.exec.enableScopeFusionPasses";
// profiling flag
const char *const OPTION_EXEC_PROFILING_MODE = "ge.exec.profilingMode";
const char *const OPTION_EXEC_PROFILING_OPTIONS = "ge.exec.profilingOptions";
@@ -170,6 +173,9 @@ const char *const kDynamicBatchSize = "ge.dynamicBatchSize";
// configure whether to use dynamic image size
const char *const kDynamicImageSize = "ge.dynamicImageSize";

// Configure whether to use dynamic dims
const char *const kDynamicDims = "ge.dynamicDims";

// Configure auto tune mode, this option only take effect while AUTO_TUNE_FLAG is Y,
// example: GA|RL, support configure multiple, split by |
const std::string AUTO_TUNE_MODE = "ge.autoTuneMode";
@@ -219,6 +225,10 @@ const char *const ENABLE_SINGLE_STREAM = "ge.enableSingleStream";
// Configure input fp16 nodes
const std::string INPUT_FP16_NODES = "ge.INPUT_NODES_SET_FP16";

// Configure debug level, its value should be 0(default), 1 or 2.
// 0: close debug; 1: open TBE compiler; 2: open ccec compiler
const std::string OP_DEBUG_LEVEL = "ge.opDebugLevel";

// Graph run mode
enum GraphRunMode { PREDICTION = 0, TRAIN };

@@ -261,6 +271,7 @@ static const char *const INPUT_SHAPE = "input_shape";
static const char *const OP_NAME_MAP = "op_name_map";
static const char *const DYNAMIC_BATCH_SIZE = kDynamicBatchSize;
static const char *const DYNAMIC_IMAGE_SIZE = kDynamicImageSize;
static const char *const DYNAMIC_DIMS = kDynamicDims;
static const char *const INSERT_OP_FILE = ge::INSERT_OP_FILE.c_str();
static const char *const PRECISION_MODE = ge::PRECISION_MODE.c_str();
static const char *const EXEC_DISABLE_REUSED_MEMORY = ge::OPTION_EXEC_DISABLE_REUSED_MEMORY;
@@ -283,10 +294,11 @@ static const char *const OPTYPELIST_FOR_IMPLMODE = ge::OPTYPELIST_FOR_IMPLMODE.c

// for interface: aclgrphBuildModel
const std::set<std::string> ir_builder_suppported_options = {
INPUT_FORMAT, INPUT_SHAPE, OP_NAME_MAP, DYNAMIC_BATCH_SIZE,
DYNAMIC_IMAGE_SIZE, INSERT_OP_FILE, PRECISION_MODE, EXEC_DISABLE_REUSED_MEMORY,
AUTO_TUNE_MODE, OUTPUT_TYPE, OUT_NODES, INPUT_FP16_NODES,
LOG_LEVEL};
INPUT_FORMAT, INPUT_SHAPE, OP_NAME_MAP,
DYNAMIC_BATCH_SIZE, DYNAMIC_IMAGE_SIZE, DYNAMIC_DIMS,
INSERT_OP_FILE, PRECISION_MODE, EXEC_DISABLE_REUSED_MEMORY,
AUTO_TUNE_MODE, OUTPUT_TYPE, OUT_NODES,
INPUT_FP16_NODES, LOG_LEVEL};
// for interface: aclgrphBuildInitialize
const std::set<std::string> global_options = {CORE_TYPE,
SOC_VERSION,


+ 2
- 0
inc/external/graph/attr_value.h View File

@@ -34,6 +34,7 @@ using std::vector;

namespace ge {
class AttrValueImpl;
/*lint -e148*/
class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrValue {
public:
using INT = int64_t;
@@ -69,5 +70,6 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrValue {
VALUE_SET_GET_DEC(AttrValue::FLOAT)
#undef VALUE_SET_GET_DEC
};
/*lint +e148*/
} // namespace ge
#endif // INC_EXTERNAL_GRAPH_ATTR_VALUE_H_

+ 13
- 10
inc/external/graph/operator.h View File

@@ -61,6 +61,7 @@ using std::function;
using std::shared_ptr;
using std::string;

/*lint -e148*/
class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator {
public:
friend class OperatorImpl;
@@ -88,7 +89,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator {

explicit Operator(const string &type);

Operator(const string &name, const string &type);
Operator(const string &name, const string &type); // lint !e148

virtual ~Operator() = default;

@@ -101,7 +102,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator {
// Only has one output index = 0
Operator &SetInput(const string &dst_name, const Operator &src_oprt);

Operator &SetInput(const string &dst_name, const Operator &src_oprt, const string &name);
Operator &SetInput(const string &dst_name, const Operator &src_oprt, const string &name); // lint !e148

Operator &AddControlInput(const Operator &src_oprt);

@@ -123,22 +124,22 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator {

TensorDesc GetOutputDesc(uint32_t index) const;

graphStatus UpdateOutputDesc(const string &name, const TensorDesc &tensor_desc);
graphStatus UpdateOutputDesc(const string &name, const TensorDesc &tensor_desc); // lint !e148

TensorDesc GetDynamicInputDesc(const string &name, uint32_t index) const;

graphStatus UpdateDynamicInputDesc(const string &name, uint32_t index, const TensorDesc &tensor_desc);
graphStatus UpdateDynamicInputDesc(const string &name, uint32_t index, const TensorDesc &tensor_desc); // lint !e148

TensorDesc GetDynamicOutputDesc(const string &name, uint32_t index) const;

graphStatus UpdateDynamicOutputDesc(const string &name, uint32_t index, const TensorDesc &tensor_desc);
graphStatus UpdateDynamicOutputDesc(const string &name, uint32_t index, const TensorDesc &tensor_desc); // lint !e148

graphStatus InferShapeAndType();
graphStatus InferShapeAndType(); // lint !e148

void SetInferenceContext(const InferenceContextPtr &inference_context);
InferenceContextPtr GetInferenceContext() const;

graphStatus VerifyAllAttr(bool disable_common_verifier = false);
graphStatus VerifyAllAttr(bool disable_common_verifier = false); // lint !e148

size_t GetInputsSize() const;

@@ -251,19 +252,20 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator {

void RequiredAttrRegister(const string &name);

graphStatus VerifyAll();
graphStatus VerifyAll(); // lint !e148

// Only has one output index = 0
Operator &SetInput(const string &dst_name, uint32_t dst_index, const Operator &src_oprt);

Operator &SetInput(const string &dst_name, uint32_t dst_index, const Operator &src_oprt, const string &name);
Operator &SetInput(const string &dst_name, uint32_t dst_index, const Operator &src_oprt,
const string &name); // lint !e148

void SubgraphRegister(const string &ir_name, bool dynamic);
void SubgraphCountRegister(const string &ir_name, uint32_t count);
void SetSubgraphBuilder(const string &ir_name, uint32_t index, const SubgraphBuilder &builder);

private:
Operator &SetInput(const string &dst_name, const OutHandler &out_handler);
Operator &SetInput(const string &dst_name, const OutHandler &out_handler); // lint !e148

OutHandler GetOutput(const string &name) const;

@@ -273,6 +275,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator {

graphStatus GetInputConstDataOut(const string &dst_name, Tensor &data) const;
};
/*lint +e148*/
} // namespace ge

#endif // INC_EXTERNAL_GRAPH_OPERATOR_H_

+ 1
- 0
inc/external/graph/operator_reg.h View File

@@ -343,6 +343,7 @@ class OpReg {
auto x_type = op.GetInputDesc(in_name).GetDataType(); \
TensorDesc op_output_desc = op.GetOutputDesc(out_name); \
op_output_desc.SetShape(ge::Shape(x_shape)); \
op_output_desc.SetOriginShape(ge::Shape(x_shape)); \
op_output_desc.SetDataType(x_type); \
return op.UpdateOutputDesc(out_name, op_output_desc); \
}


+ 1
- 0
inc/external/graph/tensor.h View File

@@ -126,5 +126,6 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Tensor {
friend class TensorAdapter;
};
} // namespace ge
/*lint +e148*/

#endif // INC_EXTERNAL_GRAPH_TENSOR_H_

+ 2
- 1
inc/external/graph/types.h View File

@@ -145,7 +145,8 @@ enum Format {
FORMAT_FRACTAL_ZN_LSTM,
FORMAT_FRACTAL_Z_G,
FORMAT_RESERVED,
FORMAT_ALL
FORMAT_ALL,
FORMAT_NULL
};

// for unknown shape op type


+ 4
- 0
inc/external/register/register.h View File

@@ -40,6 +40,7 @@ using std::to_string;
using std::unique_ptr;
using std::vector;

/*lint -e148*/
namespace ge {
class Operator;
class TensorDesc;
@@ -98,6 +99,8 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData {

OpRegistrationData &DelInputWithOriginalType(int input_idx, const std::string &ori_type);

OpRegistrationData &InputReorderVector(const vector<int> &input_order);

domi::ImplyType GetImplyType() const;
std::string GetOmOptype() const;
std::set<std::string> GetOriginOpTypeSet() const;
@@ -130,4 +133,5 @@ namespace ge {
using OpRegistrationData = domi::OpRegistrationData;
using OpReceiver = domi::OpReceiver;
} // namespace ge
/*lint +e148*/
#endif // INC_EXTERNAL_REGISTER_REGISTER_H_

+ 0
- 24
inc/framework/common/debug/ge_log.h View File

@@ -51,30 +51,6 @@ inline pid_t GetTid() {
return tid;
}

#define GE_TIMESTAMP_START(stage) uint64_t startUsec_##stage = ge::GetCurrentTimestap()

#define GE_TIMESTAMP_END(stage, stage_name) \
do { \
uint64_t endUsec_##stage = ge::GetCurrentTimestap(); \
GEEVENT("[GEPERFTRACE] The time cost of %s is [%lu] micro second.", (stage_name), \
(endUsec_##stage - startUsec_##stage)); \
} while (0);

#define GE_TIMESTAMP_CALLNUM_START(stage) \
uint64_t startUsec_##stage = ge::GetCurrentTimestap(); \
uint64_t call_num_of##stage = 0; \
uint64_t time_of##stage = 0

#define GE_TIMESTAMP_RESTART(stage) (startUsec_##stage = ge::GetCurrentTimestap())

#define GE_TIMESTAMP_ADD(stage) \
time_of##stage += ge::GetCurrentTimestap() - startUsec_##stage; \
call_num_of##stage++

#define GE_TIMESTAMP_CALLNUM_END(stage, stage_name) \
GEEVENT("[GEPERFTRACE] The time cost of %s is [%lu] micro second, call num is %lu", (stage_name), time_of##stage, \
call_num_of##stage)

#define GE_LOG_ERROR(MOD_NAME, ERROR_CODE, fmt, ...) \
dlog_error(MOD_NAME, "%lu %s: ErrorNo: %d(%s) " fmt, GetTid(), __FUNCTION__, ERROR_CODE, \
((GE_GET_ERRORNO_STR(ERROR_CODE)).c_str()), ##__VA_ARGS__)


+ 20
- 37
inc/framework/common/debug/log.h View File

@@ -19,15 +19,12 @@

#include <string>

#include "cce/cce_def.hpp"
#include "runtime/rt.h"
#include "common/string_util.h"
#include "common/util.h"
#include "framework/common/debug/ge_log.h"
#include "ge/ge_api_error_codes.h"

using cce::CC_STATUS_SUCCESS;
using cce::ccStatus_t;

#if !defined(__ANDROID__) && !defined(ANDROID)
#define DOMI_LOGE(...) GE_LOG_ERROR(GE_MODULE_NAME, ge::FAILED, __VA_ARGS__)
#else
@@ -102,17 +99,13 @@ using cce::ccStatus_t;
} while (0);

// If expr is not true, print the log and return the specified status
#define GE_CHK_BOOL_RET_STATUS(expr, _status, ...) \
do { \
bool b = (expr); \
if (!b) { \
std::string msg; \
(void)msg.append(ge::StringUtils::FormatString(__VA_ARGS__)); \
(void)msg.append( \
ge::StringUtils::FormatString(" Error Code:0x%X(%s)", _status, GET_ERRORNO_STR(_status).c_str())); \
DOMI_LOGE("%s", msg.c_str()); \
return _status; \
} \
#define GE_CHK_BOOL_RET_STATUS(expr, _status, ...) \
do { \
bool b = (expr); \
if (!b) { \
GELOGE(_status, __VA_ARGS__); \
return _status; \
} \
} while (0);

// If expr is not true, print the log and return the specified status
@@ -132,7 +125,7 @@ using cce::ccStatus_t;
DOMI_LOGE(__VA_ARGS__); \
exec_expr; \
} \
};
}

// If expr is not true, print the log and execute a custom statement
#define GE_CHK_BOOL_EXEC_WARN(expr, exec_expr, ...) \
@@ -142,7 +135,7 @@ using cce::ccStatus_t;
GELOGW(__VA_ARGS__); \
exec_expr; \
} \
};
}
// If expr is not true, print the log and execute a custom statement
#define GE_CHK_BOOL_EXEC_INFO(expr, exec_expr, ...) \
{ \
@@ -151,7 +144,7 @@ using cce::ccStatus_t;
GELOGI(__VA_ARGS__); \
exec_expr; \
} \
};
}

// If expr is not true, print the log and execute a custom statement
#define GE_CHK_BOOL_TRUE_EXEC_INFO(expr, exec_expr, ...) \
@@ -161,7 +154,7 @@ using cce::ccStatus_t;
GELOGI(__VA_ARGS__); \
exec_expr; \
} \
};
}

// If expr is true, print logs and execute custom statements
#define GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(expr, exec_expr, ...) \
@@ -171,7 +164,7 @@ using cce::ccStatus_t;
DOMI_LOGE(__VA_ARGS__); \
exec_expr; \
} \
};
}
// If expr is true, print the Information log and execute a custom statement
#define GE_CHK_TRUE_EXEC_INFO(expr, exec_expr, ...) \
{ \
@@ -180,7 +173,7 @@ using cce::ccStatus_t;
GELOGI(__VA_ARGS__); \
exec_expr; \
} \
};
}

// If expr is not SUCCESS, print the log and execute the expression + return
#define GE_CHK_BOOL_TRUE_RET_VOID(expr, exec_expr, ...) \
@@ -191,7 +184,7 @@ using cce::ccStatus_t;
exec_expr; \
return; \
} \
};
}

// If expr is not SUCCESS, print the log and execute the expression + return _status
#define GE_CHK_BOOL_TRUE_EXEC_RET_STATUS(expr, _status, exec_expr, ...) \
@@ -202,7 +195,7 @@ using cce::ccStatus_t;
exec_expr; \
return _status; \
} \
};
}

// If expr is not true, execute a custom statement
#define GE_CHK_BOOL_EXEC_NOLOG(expr, exec_expr) \
@@ -211,7 +204,7 @@ using cce::ccStatus_t;
if (!b) { \
exec_expr; \
} \
};
}

// -----------------runtime related macro definitions-------------------------------
// If expr is not RT_ERROR_NONE, print the log
@@ -231,7 +224,7 @@ using cce::ccStatus_t;
DOMI_LOGE("Call rt api failed, ret: 0x%X", _rt_ret); \
exec_expr; \
} \
};
}

// If expr is not RT_ERROR_NONE, print the log and return
#define GE_CHK_RT_RET(expr) \
@@ -239,27 +232,17 @@ using cce::ccStatus_t;
rtError_t _rt_ret = (expr); \
if (_rt_ret != RT_ERROR_NONE) { \
DOMI_LOGE("Call rt api failed, ret: 0x%X", _rt_ret); \
return ge::RT_FAILED; \
return RT_ERROR_TO_GE_STATUS(_rt_ret); \
} \
} while (0);

// ------------------------cce related macro definitions----------------------------
// If expr is not CC_STATUS_SUCCESS, print the log
#define GE_CHK_CCE(expr) \
do { \
ccStatus_t _cc_ret = (expr); \
if (_cc_ret != CC_STATUS_SUCCESS) { \
DOMI_LOGE("Call cce api failed, ret: 0x%X", _cc_ret); \
} \
} while (0);

// If expr is true, execute exec_expr without printing logs
#define GE_IF_BOOL_EXEC(expr, exec_expr) \
{ \
if (expr) { \
exec_expr; \
} \
};
}

// If make_shared is abnormal, print the log and execute the statement
#define GE_MAKE_SHARED(exec_expr0, exec_expr1) \


+ 21
- 2
inc/framework/common/ge_inner_error_codes.h View File

@@ -14,6 +14,7 @@
* limitations under the License.
*/

/*lint -e* */
#ifndef INC_FRAMEWORK_COMMON_GE_INNER_ERROR_CODES_H_
#define INC_FRAMEWORK_COMMON_GE_INNER_ERROR_CODES_H_

@@ -280,8 +281,24 @@ GE_ERRORNO_RUNTIME(GE_RTI_CALL_HCCL_REDUCE_SCATTER_FAILED, 47, "call hccl hcom r

// Executor module error code definition
GE_ERRORNO_EXECUTOR(GE_EXEC_NOT_INIT, 1, "GE Executor is not yet initialized.");
GE_ERRORNO_EXECUTOR(GE_AIPP_NOT_EXIST, 2, "GE AIPP is not exist.");
GE_ERRORNO_EXECUTOR(GE_DYNAMIC_AIPP_NOT_SUPPORT_QUERY, 3, "GE Dynamic AIPP is not support to query temporarily.");
GE_ERRORNO_EXECUTOR(GE_EXEC_MODEL_PATH_INVALID, 2, "Model file path is invalid.");
GE_ERRORNO_EXECUTOR(GE_EXEC_MODEL_KEY_PATH_INVALID, 3, "Key file path of model is invalid.");
GE_ERRORNO_EXECUTOR(GE_EXEC_MODEL_ID_INVALID, 4, "Model id is invalid.");
GE_ERRORNO_EXECUTOR(GE_EXEC_MODEL_DATA_SIZE_INVALID, 5, "Data size of model is invalid.");
GE_ERRORNO_EXECUTOR(GE_EXEC_MODEL_PARTITION_NUM_INVALID, 6, "Partition number of model is invalid.");
GE_ERRORNO_EXECUTOR(GE_EXEC_MODEL_QUEUE_ID_INVALID, 7, "Queue id of model is invalid.");
GE_ERRORNO_EXECUTOR(GE_EXEC_MODEL_NOT_SUPPORT_ENCRYPTION, 8, "Model does not support encryption.");
GE_ERRORNO_EXECUTOR(GE_EXEC_READ_MODEL_FILE_FAILED, 9, "Failed to read model file.");
GE_ERRORNO_EXECUTOR(GE_EXEC_LOAD_MODEL_REPEATED, 10, "The model is loaded repeatedly.");
GE_ERRORNO_EXECUTOR(GE_EXEC_LOAD_MODEL_PARTITION_FAILED, 11, "Failed to load model partition.");
GE_ERRORNO_EXECUTOR(GE_EXEC_LOAD_WEIGHT_PARTITION_FAILED, 12, "Failed to load weight partition.");
GE_ERRORNO_EXECUTOR(GE_EXEC_LOAD_TASK_PARTITION_FAILED, 13, "Failed to load task partition.");
GE_ERRORNO_EXECUTOR(GE_EXEC_LOAD_KERNEL_PARTITION_FAILED, 14, "Failed to load kernel partition.");
GE_ERRORNO_EXECUTOR(GE_EXEC_ALLOC_FEATURE_MAP_MEM_FAILED, 15, "Failed to allocate feature map memory.");
GE_ERRORNO_EXECUTOR(GE_EXEC_ALLOC_WEIGHT_MEM_FAILED, 16, "Failed to allocate weight memory.");
GE_ERRORNO_EXECUTOR(GE_EXEC_ALLOC_VAR_MEM_FAILED, 17, "Failed to allocate variable memory.");
GE_ERRORNO_EXECUTOR(GE_AIPP_NOT_EXIST, 18, "GE AIPP is not exist.");
GE_ERRORNO_EXECUTOR(GE_DYNAMIC_AIPP_NOT_SUPPORT_QUERY, 19, "GE Dynamic AIPP is not support to query temporarily.");

// Generator module error code definition
GE_ERRORNO_GENERATOR(GE_GENERATOR_GRAPH_MANAGER_INIT_FAILED, 1, "Graph manager initialize failed.");
@@ -289,6 +306,8 @@ GE_ERRORNO_GENERATOR(GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED, 2, "Graph mana
GE_ERRORNO_GENERATOR(GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED, 3, "Graph manager build graph failed.");
GE_ERRORNO_GENERATOR(GE_GENERATOR_GRAPH_MANAGER_FINALIZE_FAILED, 4, "Graph manager finalize failed.");
GE_ERRORNO_GENERATOR(GE_GENERATOR_GRAPH_MANAGER_SAVE_MODEL_FAILED, 5, "Graph manager save model failed.");

#define RT_ERROR_TO_GE_STATUS(RT_ERROR) static_cast<Status>(RT_ERROR)
} // namespace ge

#endif // INC_FRAMEWORK_COMMON_GE_INNER_ERROR_CODES_H_

+ 5
- 3
inc/framework/common/ge_types.h View File

@@ -54,9 +54,9 @@ const char *const GE_ENGINE_ATTR_MEM_TYPE_HBM = "HBM";
struct DataBuffer {
public:
void *data; // Data address
uint32_t length; // Data length
uint64_t length; // Data length
bool isDataSupportMemShare = false;
DataBuffer(void *dataIn, uint32_t len, bool isSupportMemShare)
DataBuffer(void *dataIn, uint64_t len, bool isSupportMemShare)
: data(dataIn), length(len), isDataSupportMemShare(isSupportMemShare) {}

DataBuffer() : data(nullptr), length(0), isDataSupportMemShare(false) {}
@@ -106,7 +106,7 @@ struct ShapeDescription {
// Definition of input and output description information
struct InputOutputDescInfo {
std::string name;
uint32_t size;
uint64_t size;
uint32_t data_type;
ShapeDescription shape_info;
};
@@ -231,6 +231,7 @@ struct Options {

// Profiling info of task
struct TaskDescInfo {
std::string model_name;
std::string op_name;
uint32_t block_dim;
uint32_t task_id;
@@ -239,6 +240,7 @@ struct TaskDescInfo {

// Profiling info of graph
struct ComputeGraphDescInfo {
std::string model_name;
std::string op_name;
std::string op_type;
std::vector<Format> input_format;


+ 0
- 2
inc/framework/common/helper/model_helper.h View File

@@ -44,8 +44,6 @@ class ModelHelper {
void SetSaveMode(bool val) { is_offline_ = val; }
bool GetSaveMode(void) const { return is_offline_; }

static Status TransModelToGeModel(const ModelPtr& model, GeModelPtr& ge_model);
static Status TransGeModelToModel(const GeModelPtr& geModelPtr, ModelPtr& modelPtr);
Status GetBaseNameFromFileName(const std::string& file_name, std::string& base_name);
Status GetModelNameFromMergedGraphName(const std::string& graph_name, std::string& model_name);



+ 3
- 3
inc/framework/common/string_util.h View File

@@ -36,8 +36,8 @@ class StringUtils {
#endif
return s;
}
static std::string &Rtrim(std::string &s) {
// lint -esym(551,*)
static std::string &Rtrim(std::string &s) { /*lint !e618*/
#if __cplusplus >= 201103L
(void)s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](int c) { return !std::isspace(c); }));
#else
@@ -45,7 +45,7 @@ class StringUtils {
#endif
return s;
}
// lint -esym(551,*)
///
/// @ingroup domi_common
/// @brief delete spaces at the beginning and end of a string


+ 20
- 0
inc/framework/common/types.h View File

@@ -48,6 +48,9 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_S
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_LAYER;
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_FILE_PATH;
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_MODE;
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string OP_DEBUG_AICORE;
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string OP_DEBUG_ATOMIC;
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string OP_DEBUG_ALL;

// Supported public properties name
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROP_OME_START_TIME; // Start time
@@ -335,6 +338,8 @@ REGISTER_OPTYPE_DECLARE(BASICLSTMCELL, "BasicLSTMCell");
REGISTER_OPTYPE_DECLARE(GETNEXT, "GetNext");
REGISTER_OPTYPE_DECLARE(INITDATA, "InitData");
REGISTER_OPTYPE_DECLARE(TRANSSHAPE, "TransShape")
REGISTER_OPTYPE_DECLARE(REFIDENTITY, "RefIdentity");
REGISTER_OPTYPE_DECLARE(BITCAST, "Bitcast");

// ANN dedicated operator
REGISTER_OPTYPE_DECLARE(ANN_MEAN, "AnnMean");
@@ -428,6 +433,8 @@ REGISTER_OPTYPE_DECLARE(HCOMALLREDUCE, "HcomAllReduce");
REGISTER_OPTYPE_DECLARE(HCOMREDUCESCATTER, "HcomReduceScatter");
REGISTER_OPTYPE_DECLARE(HCOMSEND, "HcomSend");
REGISTER_OPTYPE_DECLARE(HCOMRECEIVE, "HcomReceive");
REGISTER_OPTYPE_DECLARE(HCOMREMOTEREAD, "HcomRemoteRead");
REGISTER_OPTYPE_DECLARE(HCOMREMOTEWRITE, "HcomRemoteWrite");

REGISTER_OPTYPE_DECLARE(VARASSIGN, "VarAssign");
REGISTER_OPTYPE_DECLARE(VARISINITIALIZEDOP, "VarIsInitializedOp");
@@ -554,6 +561,16 @@ enum ModelCheckType {
UNCHECK // no verification
};

///
/// @brief dynamic input type
///
enum DynamicInputType {
FIXED = 0, // default mode
DYNAMIC_BATCH = 1,
DYNAMIC_IMAGE = 2,
DYNAMIC_DIMS = 3
};

///
/// @brief magic number of the model file
///
@@ -631,6 +648,9 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string NODE_N

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string NODE_NAME_END_GRAPH;

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string NODE_NAME_OP_DEBUG;
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string OP_TYPE_OP_DEBUG;

// convolution node type
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string OP_TYPE_CONVOLUTION;
// adds a convolutional node name for the hard AIPP


+ 56
- 10
inc/framework/executor/ge_executor.h View File

@@ -21,28 +21,31 @@
#include <string>
#include <vector>

#include "common/dynamic_aipp.h"
#include "common/ge_inner_error_codes.h"
#include "common/ge_types.h"
#include "common/types.h"
#include "graph/tensor.h"
#include "graph/ge_tensor.h"
#include "runtime/base.h"
#include "common/dynamic_aipp.h"

namespace ge {
class ModelListenerAdapter;

class SingleOp;
class DynamicSingleOp;

struct RunModelData {
uint32_t index; // Data index
uint32_t modelId;
std::vector<DataBuffer> blobs; // All input/output data buffer
uint32_t timestamp; // Data creation time
uint32_t timeout; // Processing timeout
uint64_t request_id = 0; // Request ID
uint64_t dynamic_batch_size = 0; // Dynamic batch size scene, set dynamic size, not supported by default:0
uint64_t dynamic_image_height = 0; // Dynamic image size scene, set image height, not supported by default:0
uint64_t dynamic_image_width = 0; // Dynamic image size scene, set image width, not supported by default:0
std::vector<DataBuffer> blobs; // All input/output data buffer
uint32_t timestamp; // Data creation time
uint32_t timeout; // Processing timeout
uint64_t request_id = 0; // Request ID
uint64_t dynamic_batch_size = 0; // Dynamic batch size scene, set dynamic size, not supported by default:0
uint64_t dynamic_image_height = 0; // Dynamic image size scene, set image height, not supported by default:0
uint64_t dynamic_image_width = 0; // Dynamic image size scene, set image width, not supported by default:0
std::vector<uint64_t> dynamic_dims; // Dynamic dims scene, set dynamic dims, not supported by default:empty
};

class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeExecutor {
@@ -87,16 +90,52 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeExecutor {
///
ge::Status SetDynamicImageSize(uint32_t model_id, void *dynamic_input_addr, uint64_t length, uint64_t image_height,
uint64_t image_width);

///
/// @ingroup ge
/// @brief Set dynamic dims info
/// @param [in] model_id: model id allocate from manager
/// @param [in] dynamic_input_addr: dynamic input addr created by user
/// @param [in] length: length of dynamic input addr
/// @param [in] dynamic_dim_num: number of dynamic dimension
/// @param [in] dynamic_dims: array of dynamic dimensions
/// @return execute result
///
ge::Status SetDynamicDims(uint32_t model_id, void *dynamic_input_addr, uint64_t length,
const std::vector<uint64_t> &dynamic_dims);

///
/// @ingroup ge
/// @brief Get current dynamic dims info by combined dims
/// @param [in] model_id: model id allocate from manager
/// @param [in] combined_dims: array of combined dimensions
/// @param [out] cur_dynamic_dims: current dynamic dims
/// @return execute result
///
ge::Status GetCurDynamicDims(uint32_t model_id, const std::vector<uint64_t> &combined_dims,
std::vector<uint64_t> &cur_dynamic_dims);

///
/// @ingroup ge
/// @brief Get dynamic batch_info
/// @param [in] model_id
/// @param [out] batch_info
/// @param [out] dynamic_type
/// @return execute result
///
ge::Status GetDynamicBatchInfo(uint32_t model_id, std::vector<std::vector<int64_t>> &batch_info);
ge::Status GetDynamicBatchInfo(uint32_t model_id, std::vector<std::vector<int64_t>> &batch_info,
int32_t &dynamic_type);

ge::Status GetCurShape(const uint32_t model_id, std::vector<int64_t> &batch_info);
///
/// @ingroup ge
/// @brief Get combined dynamic dims info
/// @param [in] model_id
/// @param [out] batch_info
/// @return execute result
///
ge::Status GetCombinedDynamicDims(uint32_t model_id, std::vector<std::vector<int64_t>> &batch_info);

ge::Status GetCurShape(const uint32_t model_id, std::vector<int64_t> &batch_info, int32_t &dynamic_type);

///
/// @ingroup ge
@@ -209,6 +248,13 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeExecutor {
static ge::Status ExecuteAsync(SingleOp *executor, const std::vector<DataBuffer> &inputs,
std::vector<DataBuffer> &outputs);

static ge::Status LoadDynamicSingleOp(const std::string &model_name, const ge::ModelData &modelData, void *stream,
DynamicSingleOp **single_op);

static ge::Status ExecuteAsync(DynamicSingleOp *executor, const std::vector<GeTensorDesc> &input_desc,
const std::vector<DataBuffer> &inputs, std::vector<GeTensorDesc> &output_desc,
std::vector<DataBuffer> &outputs);

static ge::Status ReleaseSingleOpResource(void *stream);

ge::Status GetBatchInfoSize(uint32_t model_id, size_t &shape_count);


+ 11
- 1
inc/framework/ge_runtime/model_runner.h View File

@@ -28,7 +28,7 @@
namespace ge {
namespace model_runner {
class RuntimeModel;
using RuntimeInfo = std::tuple<uint32_t, uint32_t, void *>;
class ModelRunner {
public:
static ModelRunner &Instance();
@@ -36,8 +36,18 @@ class ModelRunner {
bool LoadDavinciModel(uint32_t device_id, uint64_t session_id, uint32_t model_id,
std::shared_ptr<DavinciModel> davinci_model, std::shared_ptr<ModelListener> listener);

bool DistributeTask(uint32_t model_id);

bool LoadModelComplete(uint32_t model_id);

const std::vector<uint32_t> &GetTaskIdList(uint32_t model_id) const;

const std::vector<uint32_t> &GetStreamIdList(uint32_t model_id) const;

const std::map<std::string, std::shared_ptr<RuntimeInfo>> &GetRuntimeInfoMap(uint32_t model_id) const;

void *GetModelHandle(uint32_t model_id) const;

bool UnloadModel(uint32_t model_id);

bool RunModel(uint32_t model_id, const InputData &input_data, OutputData *output_data);


+ 80
- 55
inc/framework/ge_runtime/task_info.h View File

@@ -21,6 +21,7 @@
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "cce/taskdown_api.h"
@@ -52,21 +53,27 @@ class TaskInfo {
virtual ~TaskInfo() {}
uint32_t stream_id() const { return stream_id_; }
TaskInfoType type() const { return type_; }
std::string op_name() const { return op_name_; }
bool dump_flag() const { return dump_flag_; }

protected:
TaskInfo(uint32_t stream_id, TaskInfoType type) : stream_id_(stream_id), type_(type) {}
TaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, bool dump_flag)
: op_name_(op_name), stream_id_(stream_id), type_(type), dump_flag_(dump_flag) {}

private:
std::string op_name_;
uint32_t stream_id_;
TaskInfoType type_;
bool dump_flag_;
};

class CceTaskInfo : public TaskInfo {
public:
CceTaskInfo(uint32_t stream_id, const cce::ccOpContext &ctx, const std::string &stub_func, uint32_t block_dim,
const std::vector<uint8_t> &args, uint32_t args_size, const std::vector<uint8_t> &sm_desc,
const std::vector<uint8_t> &flow_table, const std::vector<uint8_t> &args_offset, bool is_flowtable)
: TaskInfo(stream_id, TaskInfoType::CCE),
CceTaskInfo(const std::string &op_name, uint32_t stream_id, const cce::ccOpContext &ctx, const std::string &stub_func,
uint32_t block_dim, const std::vector<uint8_t> &args, uint32_t args_size,
const std::vector<uint8_t> &sm_desc, const std::vector<uint8_t> &flow_table,
const std::vector<uint8_t> &args_offset, bool is_flowtable)
: TaskInfo(op_name, stream_id, TaskInfoType::CCE, false),
ctx_(ctx),
stub_func_(stub_func),
block_dim_(block_dim),
@@ -102,11 +109,11 @@ class CceTaskInfo : public TaskInfo {

class TbeTaskInfo : public TaskInfo {
public:
TbeTaskInfo(uint32_t stream_id, const std::string &stub_func, uint32_t block_dim, const std::vector<uint8_t> &args,
uint32_t args_size, const std::vector<uint8_t> &sm_desc, void *binary, uint32_t binary_size,
const std::vector<uint8_t> &meta_data, const std::vector<void *> &input_data_addrs,
const std::vector<void *> &output_data_addrs, const std::vector<void *> &workspace_addrs)
: TaskInfo(stream_id, TaskInfoType::TBE),
TbeTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string &stub_func, uint32_t block_dim,
const std::vector<uint8_t> &args, uint32_t args_size, const std::vector<uint8_t> &sm_desc, void *binary,
uint32_t binary_size, const std::vector<uint8_t> &meta_data, const std::vector<void *> &input_data_addrs,
const std::vector<void *> &output_data_addrs, const std::vector<void *> &workspace_addrs, bool dump_flag)
: TaskInfo(op_name, stream_id, TaskInfoType::TBE, dump_flag),
stub_func_(stub_func),
block_dim_(block_dim),
args_(args),
@@ -153,9 +160,10 @@ class TbeTaskInfo : public TaskInfo {

class AicpuTaskInfo : public TaskInfo {
public:
AicpuTaskInfo(uint32_t stream_id, const string &so_name, const std::string &kernel_name, const std::string &node_def,
const std::vector<void *> &input_data_addrs, const std::vector<void *> &output_data_addrs)
: TaskInfo(stream_id, TaskInfoType::AICPU),
AicpuTaskInfo(const std::string &op_name, uint32_t stream_id, const string &so_name, const std::string &kernel_name,
const std::string &node_def, const std::vector<void *> &input_data_addrs,
const std::vector<void *> &output_data_addrs, bool dump_flag)
: TaskInfo(op_name, stream_id, TaskInfoType::AICPU, dump_flag),
so_name_(so_name),
kernel_name_(kernel_name),
node_def_(node_def),
@@ -177,37 +185,45 @@ class AicpuTaskInfo : public TaskInfo {
std::vector<void *> output_data_addrs_;
};

class LabelTaskInfo : public TaskInfo {
class LabelSetTaskInfo : public TaskInfo {
public:
LabelSetTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_id)
: TaskInfo(op_name, stream_id, TaskInfoType::LABEL_SET, false), label_id_(label_id) {}
~LabelSetTaskInfo() override {}
uint32_t label_id() const { return label_id_; }

protected:
LabelTaskInfo(uint32_t stream_id, TaskInfoType type, uint32_t label_id)
: TaskInfo(stream_id, type), label_id_(label_id) {}
virtual ~LabelTaskInfo() override {}

private:
uint32_t label_id_;
};

class LabelSetTaskInfo : public LabelTaskInfo {
class LabelGotoTaskInfo : public TaskInfo {
public:
LabelSetTaskInfo(uint32_t stream_id, uint32_t label_id)
: LabelTaskInfo(stream_id, TaskInfoType::LABEL_SET, label_id) {}
~LabelSetTaskInfo() override {}
LabelGotoTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_id)
: TaskInfo(op_name, stream_id, TaskInfoType::LABEL_GOTO, false), label_id_(label_id) {}
~LabelGotoTaskInfo() override {}
uint32_t label_id() const { return label_id_; }

private:
uint32_t label_id_;
};

class LabelSwitchTaskInfo : public LabelTaskInfo {
class LabelSwitchTaskInfo : public TaskInfo {
public:
LabelSwitchTaskInfo(uint32_t stream_id, uint32_t label_id)
: LabelTaskInfo(stream_id, TaskInfoType::LABEL_SWITCH, label_id) {}
LabelSwitchTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_size,
const std::vector<uint32_t> &label_list, void *cond)
: TaskInfo(op_name, stream_id, TaskInfoType::LABEL_SWITCH, false),
label_size_(label_size),
label_list_(label_list),
cond_(cond) {}
~LabelSwitchTaskInfo() override {}
};
uint32_t label_size() { return label_size_; };
const std::vector<uint32_t> &label_list() { return label_list_; };
void *cond() { return cond_; };

class LabelGotoTaskInfo : public LabelTaskInfo {
public:
LabelGotoTaskInfo(uint32_t stream_id, uint32_t label_id)
: LabelTaskInfo(stream_id, TaskInfoType::LABEL_GOTO, label_id) {}
~LabelGotoTaskInfo() override {}
private:
uint32_t label_size_;
std::vector<uint32_t> label_list_;
void *cond_;
};

class EventTaskInfo : public TaskInfo {
@@ -215,8 +231,8 @@ class EventTaskInfo : public TaskInfo {
uint32_t event_id() const { return event_id_; }

protected:
EventTaskInfo(uint32_t stream_id, TaskInfoType type, uint32_t event_id)
: TaskInfo(stream_id, type), event_id_(event_id) {}
EventTaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, uint32_t event_id)
: TaskInfo(op_name, stream_id, type, false), event_id_(event_id) {}
virtual ~EventTaskInfo() override {}

uint32_t event_id_;
@@ -224,39 +240,41 @@ class EventTaskInfo : public TaskInfo {

class EventRecordTaskInfo : public EventTaskInfo {
public:
EventRecordTaskInfo(uint32_t stream_id, uint32_t event_id)
: EventTaskInfo(stream_id, TaskInfoType::EVENT_RECORD, event_id) {}
EventRecordTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t event_id)
: EventTaskInfo(op_name, stream_id, TaskInfoType::EVENT_RECORD, event_id) {}
~EventRecordTaskInfo() override {}
};

class EventWaitTaskInfo : public EventTaskInfo {
public:
EventWaitTaskInfo(uint32_t stream_id, uint32_t event_id)
: EventTaskInfo(stream_id, TaskInfoType::EVENT_WAIT, event_id) {}
EventWaitTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t event_id)
: EventTaskInfo(op_name, stream_id, TaskInfoType::EVENT_WAIT, event_id) {}
~EventWaitTaskInfo() override {}
};

class FusionStartTaskInfo : public TaskInfo {
public:
explicit FusionStartTaskInfo(uint32_t stream_id) : TaskInfo(stream_id, TaskInfoType::FUSION_START) {}
explicit FusionStartTaskInfo(const std::string &op_name, uint32_t stream_id)
: TaskInfo(op_name, stream_id, TaskInfoType::FUSION_START, false) {}
~FusionStartTaskInfo() override {}
};

class FusionEndTaskInfo : public TaskInfo {
public:
explicit FusionEndTaskInfo(uint32_t stream_id) : TaskInfo(stream_id, TaskInfoType::FUSION_END) {}
explicit FusionEndTaskInfo(const std::string &op_name, uint32_t stream_id)
: TaskInfo(op_name, stream_id, TaskInfoType::FUSION_END, false) {}
~FusionEndTaskInfo() override {}
};

class HcclTaskInfo : public TaskInfo {
public:
HcclTaskInfo(uint32_t stream_id, const std::string hccl_type, void *input_data_addr, void *output_data_addr,
void *workspace_addr, int64_t workspace_size, int64_t hccl_stream_num,
HcclTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string hccl_type, void *input_data_addr,
void *output_data_addr, void *workspace_addr, int64_t workspace_size, int64_t hccl_stream_num,
const std::vector<uint8_t> &private_def, void *ops_kernel_store, int32_t count, int64_t root_id,
int64_t op_type, int64_t data_type, std::function<bool(void *, void *)> hcom_bind_model,
std::function<bool(void *)> hcom_unbind_model,
std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task)
: TaskInfo(stream_id, TaskInfoType::HCCL),
int64_t op_type, int64_t data_type, const std::string &group,
std::function<bool(void *, void *)> hcom_bind_model, std::function<bool(void *)> hcom_unbind_model,
std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task, bool dump_flag)
: TaskInfo(op_name, stream_id, TaskInfoType::HCCL, dump_flag),
hccl_type_(hccl_type),
input_data_addr_(input_data_addr),
output_data_addr_(output_data_addr),
@@ -269,6 +287,7 @@ class HcclTaskInfo : public TaskInfo {
root_id_(root_id),
op_type_(op_type),
data_type_(data_type),
group_(group),
hcom_bind_model_(hcom_bind_model),
hcom_unbind_model_(hcom_unbind_model),
hcom_distribute_task_(hcom_distribute_task) {}
@@ -286,6 +305,7 @@ class HcclTaskInfo : public TaskInfo {
int64_t root_id() const { return root_id_; }
int64_t op_type() const { return op_type_; }
int64_t data_type() const { return data_type_; }
const std::string &group() const { return group_; }
std::function<bool(void *, void *)> hcom_bind_model() const { return hcom_bind_model_; }
std::function<bool(void *)> hcom_unbind_model() const { return hcom_unbind_model_; }
std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task() const {
@@ -305,6 +325,7 @@ class HcclTaskInfo : public TaskInfo {
int64_t root_id_;
int64_t op_type_;
int64_t data_type_;
std::string group_;
std::function<bool(void *, void *)> hcom_bind_model_;
std::function<bool(void *)> hcom_unbind_model_;
std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task_;
@@ -312,8 +333,11 @@ class HcclTaskInfo : public TaskInfo {

class ProfilerTraceTaskInfo : public TaskInfo {
public:
ProfilerTraceTaskInfo(uint32_t stream_id, uint64_t log_id, bool notify, uint32_t flat)
: TaskInfo(stream_id, TaskInfoType::PROFILER_TRACE), log_id_(log_id), notify_(notify), flat_(flat) {}
ProfilerTraceTaskInfo(const std::string &op_name, uint32_t stream_id, uint64_t log_id, bool notify, uint32_t flat)
: TaskInfo(op_name, stream_id, TaskInfoType::PROFILER_TRACE, false),
log_id_(log_id),
notify_(notify),
flat_(flat) {}
~ProfilerTraceTaskInfo() override {}

uint64_t log_id() const { return log_id_; }
@@ -328,8 +352,9 @@ class ProfilerTraceTaskInfo : public TaskInfo {

class MemcpyAsyncTaskInfo : public TaskInfo {
public:
MemcpyAsyncTaskInfo(uint32_t stream_id, void *dst, uint64_t dst_max, void *src, uint64_t count, uint32_t kind)
: TaskInfo(stream_id, TaskInfoType::MEMCPY_ASYNC),
MemcpyAsyncTaskInfo(const std::string &op_name, uint32_t stream_id, void *dst, uint64_t dst_max, void *src,
uint64_t count, uint32_t kind, bool dump_flag)
: TaskInfo(op_name, stream_id, TaskInfoType::MEMCPY_ASYNC, dump_flag),
dst_(dst),
dst_max_(dst_max),
src_(src),
@@ -353,9 +378,9 @@ class MemcpyAsyncTaskInfo : public TaskInfo {

class StreamSwitchTaskInfo : public TaskInfo {
public:
StreamSwitchTaskInfo(uint32_t stream_id, int64_t true_stream_id, void *input_addr, void *value_addr, int64_t cond,
int64_t data_type)
: TaskInfo(stream_id, TaskInfoType::STREAM_SWITCH),
StreamSwitchTaskInfo(const std::string &op_name, uint32_t stream_id, int64_t true_stream_id, void *input_addr,
void *value_addr, int64_t cond, int64_t data_type)
: TaskInfo(op_name, stream_id, TaskInfoType::STREAM_SWITCH, false),
true_stream_id_(true_stream_id),
input_addr_(input_addr),
value_addr_(value_addr),
@@ -379,8 +404,8 @@ class StreamSwitchTaskInfo : public TaskInfo {

class StreamActiveTaskInfo : public TaskInfo {
public:
StreamActiveTaskInfo(uint32_t stream_id, uint32_t active_stream_id)
: TaskInfo(stream_id, TaskInfoType::STREAM_ACTIVE), active_stream_id_(active_stream_id) {}
StreamActiveTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t active_stream_id)
: TaskInfo(op_name, stream_id, TaskInfoType::STREAM_ACTIVE, false), active_stream_id_(active_stream_id) {}
~StreamActiveTaskInfo() override {}

uint32_t active_stream_id() const { return active_stream_id_; }


+ 1
- 0
inc/framework/generator/ge_generator.h View File

@@ -27,6 +27,7 @@
#include "graph/ge_tensor.h"
#include "graph/graph.h"
#include "graph/op_desc.h"
#include "graph/detail/attributes_holder.h"

namespace ge {
class GeGenerator {


+ 56
- 0
inc/framework/memory/memory_api.h View File

@@ -0,0 +1,56 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_FRAMEWORK_MEMORY_MEMORY_API_H_
#define INC_FRAMEWORK_MEMORY_MEMORY_API_H_

#include <string>
#include <vector>

#include "ge/ge_api_error_codes.h"
#include "runtime/mem.h"

namespace ge {
enum MemStorageType {
HBM = 0,
RDMA_HBM,
};

struct HostVarInfo {
uint64_t base_addr;
uint64_t var_size;
};

///
/// \param size [in] rdma pool memory size to be allocated.
/// \param mem_type [in] memory type for rdma pool.
/// \return Status result of function
Status InitRdmaPool(size_t size, rtMemType_t mem_type = RT_MEMORY_HBM);

///
/// \param var_info [in] host variable addr infos.
/// \param mem_type [in] memory type for rdma pool.
/// \return Status result of function
Status RdmaRemoteRegister(const std::vector<HostVarInfo> &var_info, rtMemType_t mem_type = RT_MEMORY_HBM);

///
/// \param var_name [in] var_name name of host variable.
/// \param base_addr [out] base_addr vase addr of host variable.
/// \param var_size [out] var_size memory_size of host variable.
/// \return Status result of function
Status GetVarBaseAddrAndSize(const std::string &var_name, uint64_t &base_addr, uint64_t &var_size);
} // namespace ge
#endif // INC_FRAMEWORK_MEMORY_MEMORY_API_H_

+ 0
- 5
inc/framework/omg/omg.h View File

@@ -96,17 +96,12 @@ Status CheckCustomAiCpuOpLib();

Status DumpInfershapeJson(const ge::Graph &graph, const char *json_file);

Status SetOutputNodeInfo(ge::Graph &graph, const std::string &output_type, const std::string &output_format);

Status GetOutputLeaf(ge::NodePtr node, std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info);

void GetOutputNodesNameAndIndex(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info,
std::vector<std::string> &output_nodes_name);

void UpdateOmgCtxWithParserCtx();

void UpdateParserCtxWithOmgCtx();

} // namespace ge

namespace domi {


+ 1
- 0
inc/framework/omg/omg_inner_types.h View File

@@ -120,6 +120,7 @@ struct OmgContext {
bool is_dynamic_input = false;
std::string dynamic_batch_size;
std::string dynamic_image_size;
std::string dynamic_dims;
};
} // namespace ge



+ 3
- 3
inc/graph/buffer.h View File

@@ -57,11 +57,11 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Buffer {

// For compatibility
inline const std::uint8_t *data() const { return GetData(); }
inline std::uint8_t *data() { return GetData(); }
inline std::uint8_t *data() { return GetData(); } // lint !e659
inline std::size_t size() const { return GetSize(); }
inline void clear() { return ClearBuffer(); }
uint8_t operator[](size_t index) const {
if (buffer_ != nullptr && index < buffer_->size()) {
uint8_t operator[](size_t index) const { // lint !e1022 !e1042
if (buffer_ != nullptr && index < buffer_->size()) { // lint !e574
return (uint8_t)(*buffer_)[index];
}
return 0xff;


+ 19
- 2
inc/graph/compute_graph.h View File

@@ -74,6 +74,9 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A

size_t GetAllNodesSize() const;
Vistor<NodePtr> GetAllNodes() const;
// is_unknown_shape: false, same with GetAllNodes func
// is_unknown_shape: true, same with GetDirectNodes func
Vistor<NodePtr> GetNodes(bool is_unknown_shape) const;
size_t GetDirectNodesSize() const;
Vistor<NodePtr> GetDirectNode() const;
Vistor<NodePtr> GetInputNodes() const;
@@ -81,14 +84,18 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A

NodePtr FindNode(const std::string &name) const;
NodePtr FindFirstNodeMatchType(const std::string &name) const;
/*lint -e504*/
// AddNode with NodePtr
NodePtr AddNode(NodePtr node);
NodePtr AddNode(OpDescPtr op);
NodePtr AddNode(OpDescPtr op, int64_t id); // for unserialize.
NodePtr AddNode(OpDescPtr op, int64_t id); // for unserialize
NodePtr AddNodeFront(NodePtr node);
NodePtr AddNodeFront(const OpDescPtr &op);
NodePtr AddInputNode(NodePtr node);
NodePtr AddOutputNode(NodePtr node);
// insert node with specific pre_node
NodePtr AddNodeAfter(OpDescPtr &op, const NodePtr &pre_node);
NodePtr AddNodeAfter(NodePtr node, const NodePtr &pre_node);

graphStatus RemoveNode(const NodePtr &node);
graphStatus RemoveInputNode(const NodePtr &node);
@@ -133,6 +140,8 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A
bool IsValid() const;
void Dump() const;

void Swap(ComputeGraph &graph);

graphStatus IsolateNode(const NodePtr &node);
graphStatus Verify();
graphStatus InferShape();
@@ -141,6 +150,7 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A
graphStatus InsertEventNodes();
bool operator==(const ComputeGraph &r_compute_graph) const;

/*lint +e504*/
const std::map<std::vector<std::string>, std::vector<std::string>> &GetShareParamLayer() const {
return params_share_map_;
}
@@ -174,6 +184,10 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A
void SetInputSize(uint32_t size) { input_size_ = size; }
uint32_t GetInputSize() const { return input_size_; }

// false: known shape true: unknow shape
bool GetGraphUnknownFlag() const { return is_unknown_shape_graph_; }
void SetGraphUnknownFlag(bool flag) { is_unknown_shape_graph_ = flag; }

///
/// Set is need train iteration.
/// If set true, it means this graph need to be run iteration some
@@ -249,6 +263,8 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A
bool VectorInputNodePtrIsEqual(const std::vector<NodePtr> &r_node_ptr_vector,
const std::vector<NodePtr> &l_node_ptr_vector) const;

void SetNodesOwner();

friend class ModelSerializeImp;
friend class GraphDebugImp;
friend class OnnxUtils;
@@ -282,7 +298,8 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A
std::map<uint32_t, std::string> op_name_map_;
uint64_t session_id_ = 0;
ge::Format data_format_ = ge::FORMAT_ND;
// unknown graph indicator, default is false, mean known shape
bool is_unknown_shape_graph_ = false;
};
} // namespace ge

#endif // INC_GRAPH_COMPUTE_GRAPH_H_

+ 49
- 1
inc/graph/debug/ge_attr_define.h View File

@@ -14,6 +14,7 @@
* limitations under the License.
*/

/*lint -e618*/
#ifndef INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_
#define INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_

@@ -185,6 +186,9 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_ORIGIN_SIZE;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NODE_CONNECT_INPUT;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NODE_CONNECT_OUTPUT;

// to be deleted
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_TO_BE_DELETED;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_RESHAPE_FUSION;
@@ -778,6 +782,10 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MOD

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_CORE_TYPE;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_ATC_VERSION;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_OPP_VERSION;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_MODE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_VALUE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_OFFSET;
@@ -930,12 +938,14 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PRED_VALUE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BATCH_NUM;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BATCH_LABEL;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_COMBINED_BATCH;

// Control flow
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_SWITCH_COND;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ACTIVE_STREAM_LIST;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCHN_PRED_VALUE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SUBGRAPH_FIRST_ACTIVE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_COMBINED_DYNAMIC_DIMS;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_BRANCH_NODE_LABEL;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG;
@@ -979,6 +989,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NEE
// For multi-batch
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INSERT_BY_MBATCH;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MBATCH_ORIGIN_INPUT_DIMS;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DYNAMIC_TYPE;

// For inserted op
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INSERTED_BY_GE;
@@ -996,7 +1007,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE;

// used for l1 fusion and other fusion in future
// used for lX fusion
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_GROUP_ID;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_GROUP_KEY;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_GROUP_KEY;
@@ -1010,9 +1021,21 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_FOR_L1_FUSION;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_N_BATCH_SPILT;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NO_TASK_AND_DUMP_NEEDED;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DATA_DUMP_REF;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_BUFFER_FUSION;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L2_FUSION_GROUP_ID;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_FOR_L2_FUSION;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_INPUT_L1_FLAG;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_INPUT_L1_ADDR;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_INPUT_L1_VALID_SIZE;

// for unregistered op
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_UNREGST_OPPATH;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_UNREGST_ATTRLIST;

// op overflow dump
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_OP_DEBUG_FLAG;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_OP_DEBUG_MODE;

// functional ops attr
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IF_THEN_BRANCH;
@@ -1058,6 +1081,31 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_HOR
// for gradient group
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_HCCL_FUSED_GROUP;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_HCCL_FUSED_FLAG;

// dynamic shape attrs
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DYNAMIC_SHAPE_FIXED_ADDR;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DYNAMIC_SHAPE_FIXED_ADDR_INDEX;

// atc user def dtype&format
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_ATC_USER_DEFINE_DATATYPE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_ATC_USER_DEFINE_FORMAT;

// for fusion op plugin
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSIONOP_ORIGINAL_TYPE;

// graph partition for aicpu
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PLD_FRONT_NODE_ENGINE_NAME;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_END_REAR_NODE_ENGINE_NAME;

// input and output memory type
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_VARIABLE_PLACEMENT;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INPUT_MEMORY_TYPE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_OUTPUT_MEMORY_TYPE;

// input_output_offset
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_ZERO_COPY_BASIC_OFFSET;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_ZERO_COPY_RELATIVE_OFFSET;
} // namespace ge

#endif // INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_
/*lint +e618*/

+ 3
- 1
inc/graph/detail/any_map.h View File

@@ -38,7 +38,7 @@ class TypeID {
bool operator==(const TypeID &__arg) const { return type_ == __arg.type_; }

private:
explicit TypeID(string type) : type_(std::move(type)) {}
explicit TypeID(string type) : type_(std::move(type)) {} // lint !e30 !e32

string type_;
};
@@ -53,6 +53,8 @@ class AnyMap {

bool Has(const string &name) const { return anyValues_.find(name) != anyValues_.end(); }

// Exchanges the contents of the two maps; never throws.
void Swap(AnyMap &other) { anyValues_.swap(other.anyValues_); }

private:
class Placeholder {
public:


+ 15
- 3
inc/graph/detail/attributes_holder.h View File

@@ -50,7 +50,7 @@ class OpDef;
class GraphDef;
} // namespace proto

using ProtoAttrMap = ::google::protobuf::Map<::std::string, ::ge::proto::AttrDef>;
using ProtoAttrMap = ::google::protobuf::Map<::std::string, ::ge::proto::AttrDef>; // lint !e1073
using ProtoMsgOwner = std::shared_ptr<::google::protobuf::Message>;

template <class ProtoType>
@@ -95,6 +95,14 @@ class GeIrProtoHelper {
}
}

// Exchanges both the owning proto message and the raw message pointer with
// `other`, so each helper's owner/message pairing stays consistent.
void Swap(GeIrProtoHelper<ProtoType> &other) {
  protoOwner_.swap(other.protoOwner_);

  // Swap the raw pointers as well: protoMsg_ points into protoOwner_'s
  // storage, so it must travel together with the owner.
  ProtoType *temp = protoMsg_;
  protoMsg_ = other.protoMsg_;
  other.protoMsg_ = temp;
}

// protoMsg_ is part of protoOwner_, they have the same runtime
ProtoMsgOwner protoOwner_ = nullptr;
ProtoType *protoMsg_ = nullptr;
@@ -120,6 +128,11 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrHolder {

void CopyAttrsFrom(const AttrHolder &holder);

// Exchanges this holder's attribute state with `holder`: the set of required
// attribute names and the extended (non-proto) attributes.
void Swap(AttrHolder &holder) {
  requiredAttrs_.swap(holder.requiredAttrs_);
  extAttrs_.Swap(holder.extAttrs_);
}

template <class T>
bool SetExtAttr(const string &name, const T &value) {
return extAttrs_.Set(name, value);
@@ -134,7 +147,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrHolder {
protected:
graphStatus AddRequiredAttr(const std::string &name);
const std::unordered_set<string> GetAllAttrNames() const;
const std::map<string, GeAttrValue> GetAllAttrs() const;
const std::map<string, GeAttrValue> GetAllAttrs() const; // lint !e1073

virtual ProtoAttrMapHelper MutableAttrMap() = 0;
virtual ConstProtoAttrMapHelper GetAttrMap() const = 0;
@@ -149,5 +162,4 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrHolder {
AnyMap extAttrs_;
};
} // namespace ge

#endif // INC_GRAPH_DETAIL_ATTRIBUTES_HOLDER_H_

+ 3
- 0
inc/graph/detail/model_serialize_imp.h View File

@@ -67,6 +67,9 @@ class ModelSerializeImp {
bool HandleNodeNameRef();

bool UnserializeOpDesc(OpDescPtr &opDesc, proto::OpDef &opDefProto);
void AttrDefToOpDesc(OpDescPtr &op_desc, std::vector<string> &key_in, std::vector<string> &key_out,
std::vector<uint32_t> &value_in, std::vector<uint32_t> &value_out, std::vector<string> &opt);
void OpDescToAttrDef(const ConstOpDescPtr &op_desc, proto::OpDef *op_def_proto);

bool UnserializeNode(ComputeGraphPtr &graph, proto::OpDef &opDefProto);



+ 3
- 3
inc/graph/ge_attr_value.h View File

@@ -310,7 +310,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeAttrValue {
VALUE_SET_GET_DEC(GeAttrValue::GRAPH)
VALUE_SET_GET_DEC(BYTES)
VALUE_SET_GET_DEC(NamedAttrs)
VALUE_SET_GET_DEC(ge::DataType)
VALUE_SET_GET_DEC(ge::DataType) // lint !e665
VALUE_SET_GET_DEC(vector<GeAttrValue::STR>)
VALUE_SET_GET_DEC(vector<GeAttrValue::INT>)
VALUE_SET_GET_DEC(vector<GeAttrValue::FLOAT>)
@@ -320,8 +320,8 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeAttrValue {
VALUE_SET_GET_DEC(vector<GeAttrValue::GRAPH>)
VALUE_SET_GET_DEC(vector<GeAttrValue::BYTES>)
VALUE_SET_GET_DEC(vector<NamedAttrs>)
VALUE_SET_GET_DEC(vector<vector<int64_t>>)
VALUE_SET_GET_DEC(vector<ge::DataType>)
VALUE_SET_GET_DEC(vector<vector<int64_t>>) // lint !e665
VALUE_SET_GET_DEC(vector<ge::DataType>) // lint !e665
#undef VALUE_SET_GET_DEC

GeIrProtoHelper<proto::AttrDef> value_;


+ 1
- 0
inc/graph/ge_context.h View File

@@ -28,6 +28,7 @@ class GEContext {
uint32_t DeviceId();
uint64_t TraceId();
void Init();
void SetSessionId(uint64_t session_id);
void SetCtxDeviceId(uint32_t device_id);

private:


+ 5
- 2
inc/graph/ge_tensor.h View File

@@ -25,6 +25,7 @@
#include "graph/buffer.h"
#include "graph/ge_error_codes.h"
#include "graph/types.h"

namespace ge {
class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeShape {
public:
@@ -108,8 +109,11 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc : public AttrH
DataType GetDataType() const;
void SetDataType(DataType dt);

void SetOriginDataType(DataType originDataType);
DataType GetOriginDataType() const;
void SetOriginDataType(DataType originDataType);

std::vector<uint32_t> GetRefPortIndex() const;
void SetRefPortByIndex(const std::vector<uint32_t> &index);

GeTensorDesc Clone() const;
GeTensorDesc &operator=(const GeTensorDesc &desc);
@@ -186,5 +190,4 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensor {
GeTensorDesc &DescReference() const;
};
} // namespace ge

#endif // INC_GRAPH_GE_TENSOR_H_

+ 0
- 1
inc/graph/model_serialize.h View File

@@ -49,5 +49,4 @@ class ModelSerialize {
friend class GraphDebugImp;
};
} // namespace ge

#endif // INC_GRAPH_MODEL_SERIALIZE_H_

+ 1
- 1
inc/graph/node.h View File

@@ -190,7 +190,7 @@ class Node : public std::enable_shared_from_this<Node> {
vector<OutDataAnchorPtr> out_data_anchors_;
InControlAnchorPtr in_control_anchor_;
OutControlAnchorPtr out_control_anchor_;
map<string, GeAttrValue> attrs_;
map<string, GeAttrValue> attrs_; // lint !e1073
bool has_init_{false};
bool anchor_status_updated_{false};
std::vector<uint32_t> send_event_id_list_;


+ 9
- 4
inc/graph/op_desc.h View File

@@ -105,6 +105,8 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder {

GeTensorDescPtr MutableInputDesc(uint32_t index) const;

GeTensorDescPtr MutableInputDesc(const string &name) const;

Vistor<GeTensorDesc> GetAllInputsDesc() const;

Vistor<GeTensorDescPtr> GetAllInputsDescPtr() const;
@@ -127,6 +129,8 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder {

GeTensorDescPtr MutableOutputDesc(uint32_t index) const;

GeTensorDescPtr MutableOutputDesc(const string &name) const;

uint32_t GetAllOutputsDescSize() const;

Vistor<GeTensorDesc> GetAllOutputsDesc() const;
@@ -149,16 +153,15 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder {

graphStatus AddDynamicOutputDesc(const string &name, const unsigned int num, bool isPushBack = true);

void RemoveInputDesc(uint32_t index);
void RemoveOutputDesc(uint32_t index);

bool IsOptionalInput(const string &name) const;

bool IsOptionalInput(uint32_t index) const;

std::map<string, uint32_t> GetAllInputName() const;

void SetAllInputName(const std::map<string, uint32_t> &input_name_idx);

std::vector<string> GetAllOptionalInputName() const;

std::map<string, uint32_t> GetAllOutputName();

bool UpdateInputName(std::map<string, uint32_t> inputNameIdx);
@@ -296,6 +299,8 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder {
std::map<std::string, SubgraphType> subgraph_ir_names_to_type_;

vector<GeTensorDescPtr> inputs_desc_{};
map<string, uint32_t> input_name_idx_{};
std::unordered_set<string> optional_input_names_{};
vector<GeTensorDescPtr> outputs_desc_{};
map<string, uint32_t> output_name_idx_{};
std::function<graphStatus(Operator &)> infer_func_ = nullptr;


+ 1
- 0
inc/graph/shape_refiner.h View File

@@ -31,6 +31,7 @@ class ShapeRefiner {
static graphStatus InferShapeAndType(const NodePtr &node, bool before_subgraph);
static graphStatus InferShapeAndType(const NodePtr &node);
static graphStatus InferShapeAndType(const ConstNodePtr &node, Operator &op);
static void ClearContextMap();

private:
static void PrintInOutTensorShape(const ge::NodePtr &node, const std::string &phase);


+ 40
- 4
inc/graph/utils/graph_utils.h View File

@@ -23,6 +23,8 @@
#include <string>
#include <vector>
#include <list>
#include <unordered_map>

#include "graph/anchor.h"
#include "graph/node.h"
#include "graph/compute_graph.h"
@@ -130,7 +132,7 @@ struct NodeIndexIO {
IOType io_type_ = kOut;
std::string value_;

std::string ToString() const { return value_; }
const std::string &ToString() const { return value_; }
};

class GraphUtils {
@@ -188,8 +190,8 @@ class GraphUtils {
/// @param [in] output_index
/// @return graphStatus
///
static graphStatus InsertNodeBefore(const OutDataAnchorPtr &src, const std::vector<InDataAnchorPtr> &dsts,
const NodePtr &insert_node, uint32_t input_index = 0, uint32_t output_index = 0);
static graphStatus InsertNodeAfter(const OutDataAnchorPtr &src, const std::vector<InDataAnchorPtr> &dsts,
const NodePtr &insert_node, uint32_t input_index = 0, uint32_t output_index = 0);

static graphStatus RemoveJustNode(ComputeGraphPtr compute_graph, const NodePtr &node);

@@ -303,8 +305,33 @@ class GraphUtils {
///
static graphStatus MoveOutCtrlEdges(NodePtr &src_node, NodePtr &dst_node);

///
/// Copy all in-data edges from `src_node` to `dst_node`
/// @param src_node
/// @param dst_node
/// @return
///
static graphStatus CopyInDataEdges(const NodePtr &src_node, NodePtr &dst_node);

static ComputeGraphPtr FindRootGraph(ComputeGraphPtr graph);

///
/// Make a copy of ComputeGraph.
/// @param graph: original graph.
/// @param prefix: node name prefix of new graph.
/// @return ComputeGraphPtr
///
static ComputeGraphPtr CloneGraph(const ComputeGraphPtr &graph, const string &prefix,
std::vector<NodePtr> &input_nodes, std::vector<NodePtr> &output_nodes);

///
/// Copy tensor attribute to new node.
/// @param [in] dst_desc: cloned node.
/// @param [in] src_node: original node.
/// @return success: GRAPH_SUCCESS
///
static graphStatus CopyTensorAttrs(const OpDescPtr &dst_desc, const NodePtr &src_node);

static graphStatus TopologicalSortingByName(const ge::ComputeGraphPtr &compute_graph, vector<NodePtr> &node_vec);

///
@@ -392,6 +419,16 @@ class GraphUtils {
std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors,
std::map<std::string, std::string> &anchor_to_symbol);

///
/// Relink all edges for cloned ComputeGraph.
/// @param [in] node: original node.
/// @param [in] prefix: node name prefix of new node.
/// @param [in] all_nodes: all nodes in new graph.
/// @return success: GRAPH_SUCCESS
///
static graphStatus RelinkGraphEdges(const NodePtr &node, const string &prefix,
const std::unordered_map<string, NodePtr> &all_nodes);

///
/// Union ref-mapping
/// @param [in] exist_node_info1
@@ -728,5 +765,4 @@ class PartialGraphBuilder : public ComputeGraphBuilder {
std::vector<NodePtr> exist_nodes_;
};
} // namespace ge

#endif // INC_GRAPH_UTILS_GRAPH_UTILS_H_

+ 28
- 0
inc/graph/utils/node_utils.h View File

@@ -63,6 +63,9 @@ class NodeUtils {
static void UnlinkAll(const Node &node);
static graphStatus UpdatePeerNodeInputDesc(const NodePtr &node_ptr);

static graphStatus AppendInputAnchor(const NodePtr &node, uint32_t index);
static graphStatus RemoveInputAnchor(const NodePtr &node, uint32_t index);

static bool IsInNodesEmpty(const Node &node);
static GeTensorDesc GetOutputDesc(const Node &node, uint32_t index);
static GeTensorDesc GetInputDesc(const Node &node, uint32_t index);
@@ -99,6 +102,13 @@ class NodeUtils {
///
static NodePtr GetParentInput(const NodePtr &node);

///
/// @brief Check whether the node is a varying input of a While node
/// @param [in] node: Data node for subgraph
/// @return bool
///
static bool IsWhileVaryingInput(const ge::NodePtr &node);

///
/// @brief Get subgraph input is constant.
/// @param [in] node
@@ -114,6 +124,24 @@ class NodeUtils {
///
static graphStatus RemoveSubgraphsOnNode(const NodePtr &node);

///
/// @brief Get subgraph input data node by index.
/// @param [in] node
/// @return Node
///
static vector<NodePtr> GetSubgraphDataNodesByIndex(const Node &node, int index);

///
/// @brief Get subgraph output nodes.
/// @param [in] node
/// @return Node
///
static vector<NodePtr> GetSubgraphOutputNodes(const Node &node);

static NodePtr GetInDataNodeByIndex(const Node &node, int index);

static vector<NodePtr> GetOutDataNodesByIndex(const Node &node, int index);

private:
static std::map<NodePtr, std::vector<uint32_t>> map_send_info_;
static std::map<NodePtr, std::vector<uint32_t>> map_recv_info_;


+ 1
- 0
inc/graph/utils/tensor_adapter.h View File

@@ -20,6 +20,7 @@
#include <memory>
#include "graph/ge_tensor.h"
#include "graph/tensor.h"

namespace ge {
using GeTensorPtr = std::shared_ptr<GeTensor>;
using ConstGeTensorPtr = std::shared_ptr<const GeTensor>;


+ 1
- 0
inc/graph/utils/tensor_utils.h View File

@@ -21,6 +21,7 @@
#include "graph/def_types.h"
#include "graph/ge_error_codes.h"
#include "graph/ge_tensor.h"

namespace ge {
class TensorUtils {
public:


+ 1
- 0
src/common/graph/CMakeLists.txt View File

@@ -71,5 +71,6 @@ target_link_libraries(graph PRIVATE
${PROTOBUF_LIBRARY}
${c_sec}
${slog}
${error_manager}
rt
dl)

+ 109
- 25
src/common/graph/compute_graph.cc View File

@@ -62,18 +62,10 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY string ComputeGraph::GetName() co
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetName(const string &name) { name_ = name; }

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t ComputeGraph::GetAllNodesSize() const {
size_t s = nodes_.size();
for (const auto &sub_graph : sub_graph_) {
s += sub_graph->GetAllNodesSize();
}
return s;
return GetAllNodes().size();
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor<NodePtr> ComputeGraph::GetAllNodes() const {
if (sub_graph_.empty()) {
return Vistor<NodePtr>(shared_from_this(), nodes_);
}

std::vector<std::shared_ptr<ComputeGraph>> subgraphs;
return AllGraphNodes(subgraphs);
}
@@ -106,6 +98,15 @@ ComputeGraph::Vistor<NodePtr> ComputeGraph::AllGraphNodes(std::vector<std::share
return Vistor<NodePtr>(shared_from_this(), all_nodes);
}

// Selects the node set to visit: for an unknown-shape graph only the direct
// (first-level) nodes are returned; otherwise all nodes, including those
// inside subgraphs, are returned.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor<NodePtr> ComputeGraph::GetNodes(
    bool is_unknown_shape) const {
  return is_unknown_shape ? GetDirectNode() : GetAllNodes();
}

size_t ComputeGraph::GetDirectNodesSize() const { return nodes_.size(); }

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor<NodePtr> ComputeGraph::GetDirectNode() const {
@@ -268,7 +269,7 @@ NodePtr ComputeGraph::AddNodeFront(NodePtr node) {

NodePtr ComputeGraph::AddNodeFront(const OpDescPtr &op) {
if (op == nullptr) {
GELOGE(GRAPH_FAILED, "The OpDesc ptr should be not null.");
GELOGE(GRAPH_FAILED, "The OpDesc ptr should not be null.");
return nullptr;
}
op->SetId(nodes_.size());
@@ -278,9 +279,38 @@ NodePtr ComputeGraph::AddNodeFront(const OpDescPtr &op) {
return AddNodeFront(node_ptr);
}

// Inserts `node` into this graph's direct node list immediately after
// `pre_node`.
// @param node: node to insert; must be non-null and carry a valid OpDesc.
// @param pre_node: an existing node of this graph.
// @return the inserted node on success; nullptr if any argument is null or
//         `pre_node` is not found in nodes_.
NodePtr ComputeGraph::AddNodeAfter(NodePtr node, const NodePtr &pre_node) {
  if (node == nullptr || node->GetOpDesc() == nullptr || pre_node == nullptr) {
    GELOGE(GRAPH_FAILED, "The node ptr or op desc should not be null.");
    return nullptr;
  }
  // NOTE(review): the id is set to the current node count, not the insert
  // position — presumably ids are renumbered later (e.g. during topological
  // sorting); confirm against callers.
  node->GetOpDesc()->SetId(nodes_.size());
  auto node_iter = std::find(nodes_.begin(), nodes_.end(), pre_node);
  if (node_iter != nodes_.end()) {
    nodes_.insert(node_iter + 1, node);
  } else {
    GELOGE(GRAPH_FAILED, "Cannot find pre_node in nodes_.");
    return nullptr;
  }

  return node;
}

// Builds a Node from `op`, initializes it, and places it immediately after
// `pre_node` in this graph (delegates to the NodePtr overload).
// @return the created node, or nullptr on null op / allocation failure /
//         node-init failure.
NodePtr ComputeGraph::AddNodeAfter(OpDescPtr &op, const NodePtr &pre_node) {
  if (op == nullptr) {
    GELOGE(GRAPH_FAILED, "The OpDesc ptr should not be null.");
    return nullptr;
  }
  op->SetId(nodes_.size());
  NodePtr node_ptr = shared_ptr<Node>(new (std::nothrow) Node(op, shared_from_this()));
  if (node_ptr == nullptr) {
    GELOGE(GRAPH_FAILED, "node_ptr is NULL!!!");
    return nullptr;
  }
  if (node_ptr->Init() != GRAPH_SUCCESS) {
    GELOGE(GRAPH_FAILED, "node init failed.");
    return nullptr;
  }
  return AddNodeAfter(node_ptr, pre_node);
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr ComputeGraph::AddNode(NodePtr node) {
if (node == nullptr || node->GetOpDesc() == nullptr) {
GELOGE(GRAPH_FAILED, "The node ptr should be not null.");
GELOGE(GRAPH_FAILED, "The node ptr should not be null.");
return nullptr;
}
node->GetOpDesc()->SetId((int64_t)GetDirectNodesSize());
@@ -290,7 +320,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr ComputeGraph::AddNode(Nod

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr ComputeGraph::AddNode(OpDescPtr op) {
if (op == nullptr) {
GELOGE(GRAPH_FAILED, "The OpDesc ptr should be not null.");
GELOGE(GRAPH_FAILED, "The OpDesc ptr should not be null.");
return nullptr;
}
op->SetId(GetDirectNodesSize());
@@ -302,7 +332,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr ComputeGraph::AddNode(OpD

NodePtr ComputeGraph::AddNode(OpDescPtr op, int64_t id) { // for unserialize.
if (op == nullptr) {
GELOGE(GRAPH_FAILED, "The OpDesc ptr should be not null.");
GELOGE(GRAPH_FAILED, "The OpDesc ptr should not be null.");
return nullptr;
}
op->SetId(id);
@@ -315,7 +345,7 @@ NodePtr ComputeGraph::AddNode(OpDescPtr op, int64_t id) { // for unserialize.

NodePtr ComputeGraph::AddInputNode(NodePtr node) {
if (node == nullptr) {
GELOGE(GRAPH_FAILED, "The node ptr should be not null.");
GELOGE(GRAPH_FAILED, "The node ptr should not be null.");
return nullptr;
}
input_nodes_.push_back(node);
@@ -327,7 +357,7 @@ NodePtr ComputeGraph::AddInputNode(NodePtr node) {

NodePtr ComputeGraph::AddOutputNode(NodePtr node) {
if (node == nullptr || node->GetOpDesc() == nullptr) {
GELOGE(GRAPH_FAILED, "The node ptr or opdesc should be not null.");
GELOGE(GRAPH_FAILED, "The node ptr or opdesc should not be null.");
return nullptr;
}

@@ -363,7 +393,7 @@ graphStatus ComputeGraph::RemoveConstInput(const NodePtr &node) {
if (out_anchor->GetOwnerNode()->GetType() == CONSTANT || out_anchor->GetOwnerNode()->GetType() == CONSTANTOP) {
GE_CHK_BOOL_RET_STATUS(GraphUtils::RemoveEdge(out_anchor, in_anchor) == GRAPH_SUCCESS, GRAPH_FAILED,
"Remove edge from const op failed.");
if (out_anchor->GetOwnerNode()->GetOutDataNodes().size() == 0) {
if (out_anchor->GetOwnerNode()->GetOutNodes().size() == 0) {
GELOGI("Remove const op %s.", out_anchor->GetOwnerNode()->GetName().c_str());
auto iter = find(nodes_.begin(), nodes_.end(), out_anchor->GetOwnerNode());
if (iter != nodes_.end()) {
@@ -377,7 +407,7 @@ graphStatus ComputeGraph::RemoveConstInput(const NodePtr &node) {

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::RemoveNode(const NodePtr &node) {
if (node == nullptr) {
GELOGE(GRAPH_FAILED, "The node ptr should be not null.");
GELOGE(GRAPH_FAILED, "The node ptr should not be null.");
return GRAPH_FAILED;
}

@@ -406,7 +436,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::RemoveN
// Used in sub_graph scenes
graphStatus ComputeGraph::RemoveInputNode(const NodePtr &node) {
if (node == nullptr) {
GELOGE(GRAPH_FAILED, "The node ptr should be not null.");
GELOGE(GRAPH_FAILED, "The node ptr should not be null.");
return GRAPH_FAILED;
}

@@ -421,7 +451,7 @@ graphStatus ComputeGraph::RemoveInputNode(const NodePtr &node) {
// Used in sub_graph scenes
graphStatus ComputeGraph::RemoveOutputNode(const NodePtr &node) {
if (node == nullptr) {
GELOGE(GRAPH_FAILED, "The node ptr should be not null.");
GELOGE(GRAPH_FAILED, "The node ptr should not be null.");
return GRAPH_FAILED;
}

@@ -442,7 +472,7 @@ graphStatus ComputeGraph::RemoveOutputNode(const NodePtr &node) {

std::shared_ptr<ComputeGraph> ComputeGraph::AddSubGraph(std::shared_ptr<ComputeGraph> sub_graph) {
if (sub_graph == nullptr) {
GELOGE(GRAPH_FAILED, "The graph ptr should be not null.");
GELOGE(GRAPH_FAILED, "The graph ptr should not be null.");
return nullptr;
}
sub_graph_.push_back(sub_graph);
@@ -452,7 +482,7 @@ std::shared_ptr<ComputeGraph> ComputeGraph::AddSubGraph(std::shared_ptr<ComputeG

graphStatus ComputeGraph::RemoveSubGraph(const std::shared_ptr<ComputeGraph> &sub_graph) {
if (sub_graph == nullptr) {
GELOGE(GRAPH_FAILED, "The graph ptr should be not null.");
GELOGE(GRAPH_FAILED, "The graph ptr should not be null.");
return GRAPH_FAILED;
}

@@ -491,12 +521,15 @@ ComputeGraph::AddSubgraph(const std::string &name, const std::shared_ptr<Compute
return GRAPH_PARAM_INVALID;
}
if (!this->parent_graph_.expired()) {
GE_LOGE("The subgraphs can only be added to the root graph");
return GRAPH_PARAM_INVALID;
GELOGW("The subgraphs should only be added to the root graph");
}
if (name != subgraph->GetName()) {
GELOGW("The subgraph name %s is different with input %s", subgraph->GetName().c_str(), name.c_str());
}
if (names_to_subgraph_.find(name) != names_to_subgraph_.end()) {
GE_LOGE("The subgraph %s existed", name.c_str());
return GRAPH_PARAM_INVALID;
}
sub_graph_.push_back(subgraph);
names_to_subgraph_[name] = subgraph;
return GRAPH_SUCCESS;
@@ -640,7 +673,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InsertE
GELOGW("node or OpDescPtr is nullptr.");
continue;
}
GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "The node should be not null."); return GRAPH_FAILED);
GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "The node should not be null."); return GRAPH_FAILED);
if (node->GetOpDesc()->GetType() == RECV) {
auto iter = find(node_vec.begin(), node_vec.end(), node);
if (iter == node_vec.end()) {
@@ -786,7 +819,8 @@ graphStatus ComputeGraph::CollectBreadthOutNode(const NodePtr &node, std::map<No
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::TopologicalSorting() {
auto ret = TopologicalSortingGraph();
if (ret != SUCCESS) {
GELOGE(ret, "Sub graph partition Failed");
GraphUtils::DumpGEGraphToOnnx(*this, "black_box");
GELOGE(ret, "Graph [%s] topological sort failed, saved to file black_box", name_.c_str());
return ret;
}

@@ -1001,6 +1035,54 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::Dump() const {
}
}

// Exchanges the complete state of this graph with `graph`: attribute-holder
// state, identity (name / graph id), node and subgraph containers,
// input/output bookkeeping, and assorted status flags. Afterwards each graph
// re-binds the owner-graph pointer of its (newly acquired) nodes.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::Swap(ComputeGraph &graph) {
  // Swap the base-class attribute state (required + extended attrs).
  this->AttrHolder::Swap(graph);

  origGraph_.swap(graph.origGraph_);

  name_.swap(graph.name_);
  std::swap(graph_id_, graph.graph_id_);
  attrs_.Swap(graph.attrs_);
  nodes_.swap(graph.nodes_);
  all_nodes_infos_.swap(graph.all_nodes_infos_);
  target_nodes_info_.swap(graph.target_nodes_info_);

  input_nodes_.swap(graph.input_nodes_);
  inputs_order_.swap(graph.inputs_order_);
  std::swap(input_size_, graph.input_size_);
  out_nodes_map_.swap(graph.out_nodes_map_);
  std::swap(output_size_, graph.output_size_);
  output_nodes_info_.swap(graph.output_nodes_info_);

  sub_graph_.swap(graph.sub_graph_);
  names_to_subgraph_.swap(graph.names_to_subgraph_);
  parent_graph_.swap(graph.parent_graph_);
  parent_node_.swap(graph.parent_node_);

  // NOTE: the following members arguably do not belong in ComputeGraph itself.
  std::swap(is_valid_flag_, graph.is_valid_flag_);
  std::swap(is_summary_graph_, graph.is_summary_graph_);
  std::swap(need_iteration_, graph.need_iteration_);
  params_share_map_.swap(graph.params_share_map_);
  op_name_map_.swap(graph.op_name_map_);
  std::swap(session_id_, graph.session_id_);
  std::swap(data_format_, graph.data_format_);
  std::swap(is_unknown_shape_graph_, graph.is_unknown_shape_graph_);

  // Update Node owner: each node's owner-graph pointer must now point at the
  // graph object that currently holds it.
  SetNodesOwner();
  graph.SetNodesOwner();
}

// Re-binds every direct node's owner-graph pointer to this graph; null
// entries in nodes_ are skipped. Used after Swap() exchanges node containers.
void ComputeGraph::SetNodesOwner() {
  for (const auto &node : nodes_) {
    if (node == nullptr) {
      continue;
    }
    node->SetOwnerComputeGraph(shared_from_this());
  }
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::IsolateNode(const NodePtr &node) {
GE_CHECK_NOTNULL(node);
auto next_nodes = node->GetOutAllNodes();
@@ -1104,9 +1186,11 @@ graphStatus ComputeGraph::RemoveExtraOutEdge(const NodePtr &node) {
}

graphStatus ComputeGraph::Verify() {
bool is_unknown_graph = GetGraphUnknownFlag();
for (const auto &node_ptr : GetAllNodes()) {
GE_CHECK_NOTNULL(node_ptr);
GE_CHECK_NOTNULL(node_ptr->GetOpDesc());
GE_IF_BOOL_EXEC(is_unknown_graph, continue);
GE_CHK_BOOL_EXEC(node_ptr->GetOpDesc()->CommonVerify() == GRAPH_SUCCESS, return GRAPH_FAILED,
"Verifying %s failed.", node_ptr->GetName().c_str());
}


+ 4
- 0
src/common/graph/debug/ge_op_types.h View File

@@ -34,12 +34,16 @@ GE_REGISTER_OPTYPE(EXPANDDIMS, "ExpandDims");
GE_REGISTER_OPTYPE(SWITCH, "Switch");
GE_REGISTER_OPTYPE(MERGE, "Merge");
GE_REGISTER_OPTYPE(STREAMMERGE, "StreamMerge");
GE_REGISTER_OPTYPE(ENTER, "Enter");
GE_REGISTER_OPTYPE(REFENTER, "RefEnter");
GE_REGISTER_OPTYPE(NEXTITERATION, "NextIteration");
GE_REGISTER_OPTYPE(REFNEXTITERATION, "RefNextIteration");
GE_REGISTER_OPTYPE(CONSTANT, "Const");
GE_REGISTER_OPTYPE(PLACEHOLDER, "PlaceHolder");
GE_REGISTER_OPTYPE(FRAMEWORKOP, "FrameworkOp");
GE_REGISTER_OPTYPE(GETNEXT, "GetNext");
GE_REGISTER_OPTYPE(INITDATA, "InitData");
GE_REGISTER_OPTYPE(REFIDENTITY, "RefIdentity");
GE_REGISTER_OPTYPE(ANN_DATA, "AnnData");

GE_REGISTER_OPTYPE(CONSTANTOP, "Constant");


+ 69
- 21
src/common/graph/format_refiner.cc View File

@@ -41,11 +41,9 @@ using namespace ge;
using namespace std;
namespace ge {
namespace {
static const std::unordered_set<string> kChangeDimNodes = {PERMUTE, EXPANDDIMS, SQUEEZE};
static bool net_format_is_nd = true;
static Format g_user_set_format = FORMAT_ND;
static bool is_first_infer = true;
static RefRelations reflection_builder;
const std::unordered_set<string> kChangeDimNodes = {PERMUTE, EXPANDDIMS, SQUEEZE};
const string kIsGraphInferred = "_is_graph_inferred";
RefRelations reflection_builder;
} // namespace

graphStatus ReflectionProcess(const std::unordered_set<RefCell, RefCellHash> &reflection,
@@ -72,9 +70,49 @@ graphStatus ReflectionProcess(const std::unordered_set<RefCell, RefCellHash> &re
return GRAPH_SUCCESS;
}

graphStatus FormatRefiner::RefreshConstantOutProcess(const OpDescPtr &op_desc) {
// Works around TensorFlow's BiasAdd format semantics: when a BiasAdd
// input/output tensor has 5 dims but carries a 4-D origin format
// (NHWC / NCHW), rewrite both origin format and format to the matching 5-D
// one (NDHWC / NCDHW). Non-BiasAdd nodes are returned unchanged.
// @param node_ptr: node whose tensor descs are fixed in place.
// @return GRAPH_SUCCESS, or the GE_CHECK_NOTNULL error when a desc is null.
graphStatus BiasAddFormatFixProcess(ge::NodePtr &node_ptr) {
  // 5 means dim num
  if (node_ptr->GetType() != "BiasAdd") {
    return GRAPH_SUCCESS;
  }
  // Maps a 4-D origin format to its 5-D counterpart; unmapped formats are
  // left as-is below.
  std::unordered_map<string, Format> kTfFormatFix = {{"NHWC", FORMAT_NDHWC}, {"NCHW", FORMAT_NCDHW}};
  for (size_t i = 0; i < node_ptr->GetOpDesc()->GetInputsSize(); i++) {
    auto in_desc = node_ptr->GetOpDesc()->MutableInputDesc(i);
    GE_CHECK_NOTNULL(in_desc);
    if (in_desc->MutableShape().GetDimNum() != 5) {  // 5 means dim num
      continue;
    }
    auto format = in_desc->GetOriginFormat();
    auto key = TypeUtils::FormatToSerialString(format);
    auto fixed_format = (kTfFormatFix.count(key) == 0) ? format : kTfFormatFix[key];
    in_desc->SetOriginFormat(fixed_format);
    in_desc->SetFormat(fixed_format);
    GELOGD("fix the %zu'th input of node[%s]. Origin format is %s , after fixed it is %s", i,
           node_ptr->GetName().c_str(), TypeUtils::FormatToSerialString(format).c_str(),
           TypeUtils::FormatToSerialString(fixed_format).c_str());
  }
  // Same fix applied to the outputs.
  for (size_t i = 0; i < node_ptr->GetOpDesc()->GetOutputsSize(); i++) {
    auto out_desc = node_ptr->GetOpDesc()->MutableOutputDesc(i);
    GE_CHECK_NOTNULL(out_desc);
    if (out_desc->MutableShape().GetDimNum() != 5) {  // 5 means dim num
      continue;
    }
    auto format = out_desc->GetOriginFormat();
    auto key = TypeUtils::FormatToSerialString(format);
    auto fixed_format = (kTfFormatFix.count(key) == 0) ? format : kTfFormatFix[key];
    out_desc->SetOriginFormat(fixed_format);
    out_desc->SetFormat(fixed_format);
    GELOGD("fix the %zu'th output of node[%s]. Origin format is %s , after fixed it is %s", i,
           node_ptr->GetName().c_str(), TypeUtils::FormatToSerialString(format).c_str(),
           TypeUtils::FormatToSerialString(fixed_format).c_str());
  }
  return GRAPH_SUCCESS;
}

graphStatus FormatRefiner::RefreshConstantOutProcess(const ComputeGraphPtr &graph, const OpDescPtr &op_desc) {
GE_CHECK_NOTNULL(graph);
GE_CHECK_NOTNULL(op_desc);
if (op_desc->GetType() == CONSTANTOP && is_first_infer == true) {
if (op_desc->GetType() == CONSTANTOP && !IsGraphInferred(graph)) {
ConstGeTensorPtr tensor_value;
if (!AttrUtils::GetTensor(op_desc, "value", tensor_value)) {
GELOGE(GRAPH_FAILED, "Get value failed, node name:%s.", op_desc->GetName().c_str());
@@ -95,7 +133,7 @@ graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std
}
anchor_points.clear();
// Get all anchor point nodes and switch nodes
for (const auto &node_ptr : graph->GetAllNodes()) {
for (auto &node_ptr : graph->GetAllNodes()) {
if (node_ptr == nullptr) {
return GRAPH_FAILED;
}
@@ -103,7 +141,7 @@ graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std
if (op_desc == nullptr) {
return GRAPH_FAILED;
}
graphStatus status = RefreshConstantOutProcess(op_desc);
graphStatus status = RefreshConstantOutProcess(graph, op_desc);
if (status != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "refresh constant out process failed!");
return GRAPH_FAILED;
@@ -135,6 +173,16 @@ graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std
if (!node_is_all_nd) {
continue;
}
// Special process for the BiasAdd op.
// In TensorFlow, BiasAdd's format is always NHWC even when the arg
// "data_format" is set to NDHWC or NCDHW. That would break our format-infer
// mechanism, so it needs special handling here.
status = BiasAddFormatFixProcess(node_ptr);
if (status != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "fix biasAdd process failed!");
return GRAPH_FAILED;
}

GELOGD("Node[%s] is anchor point!", node_ptr->GetName().c_str());
anchor_points.push_back(node_ptr);
}
@@ -344,14 +392,11 @@ void FormatRefiner::RefreshOriginFormatOfAnchor(std::vector<ge::NodePtr> &anchor
}
}

void FormatRefiner::SetInferOrigineFormatFlag(bool is_first) { is_first_infer = is_first; }

graphStatus FormatRefiner::DataNodeFormatProcess(std::vector<ge::NodePtr> &data_nodes, ge::Format data_format,
graphStatus FormatRefiner::DataNodeFormatProcess(const ComputeGraphPtr &graph, std::vector<ge::NodePtr> &data_nodes,
ge::Format data_format,
std::unordered_map<ge::NodePtr, bool> &node_status) {
bool is_internal_format = TypeUtils::IsInternalFormat(data_format);
bool need_process = (!is_first_infer) && (!is_internal_format) && (data_format != FORMAT_ND);
if (!need_process) {
GELOGI("no necessary to do DataNodeFormatProcess.is_first_infer:%d, data_format:%s", is_first_infer,
if (!(IsGraphInferred(graph) && (!TypeUtils::IsInternalFormat(data_format)) && (data_format != FORMAT_ND))) {
GELOGI("no necessary to do DataNodeFormatProcess. is_graph_inferred:%d, data_format:%s", IsGraphInferred(graph),
TypeUtils::FormatToSerialString(data_format).c_str());
return GRAPH_SUCCESS;
}
@@ -410,8 +455,6 @@ graphStatus FormatRefiner::InferOrigineFormat(const ge::ComputeGraphPtr &graph)
std::vector<ge::NodePtr> anchor_points;
std::vector<ge::NodePtr> data_nodes;
// global net format
net_format_is_nd = true;
g_user_set_format = FORMAT_ND;

if (graph == nullptr) {
GELOGE(GRAPH_FAILED, "input graph is null");
@@ -448,10 +491,15 @@ graphStatus FormatRefiner::InferOrigineFormat(const ge::ComputeGraphPtr &graph)
/// format for these data nodes.
/// Notice: ignore 5D formats
auto data_format = graph->GetDataFormat();
status = DataNodeFormatProcess(data_nodes, data_format, node_status);
// Set infer flag to false
SetInferOrigineFormatFlag(false);
status = DataNodeFormatProcess(graph, data_nodes, data_format, node_status);
(void)AttrUtils::SetBool(graph, kIsGraphInferred, true);

return status;
}

/// @brief Check whether format inference has already run on this graph.
/// The kIsGraphInferred bool attr is written onto the graph at the end of
/// InferOrigineFormat (see the SetBool call there); absence of the attr is
/// treated the same as "not inferred".
bool FormatRefiner::IsGraphInferred(const ComputeGraphPtr &graph) {
  bool inferred = false;
  if (!AttrUtils::GetBool(graph, kIsGraphInferred, inferred)) {
    return false;  // attr missing: graph was never inferred
  }
  return inferred;
}
} // namespace ge

+ 4
- 4
src/common/graph/format_refiner.h View File

@@ -30,10 +30,9 @@ namespace ge {
class FormatRefiner {
public:
static graphStatus InferOrigineFormat(const ge::ComputeGraphPtr &graph);
static void SetInferOrigineFormatFlag(bool is_first = true);

private:
static graphStatus RefreshConstantOutProcess(const OpDescPtr &op_desc);
static graphStatus RefreshConstantOutProcess(const ComputeGraphPtr &graph, const OpDescPtr &op_desc);
static graphStatus GetAnchorPoints(const ge::ComputeGraphPtr &graph, std::vector<ge::NodePtr> &anchor_points,
std::vector<ge::NodePtr> &data_nodes,
std::unordered_map<ge::NodePtr, bool> &node_status);
@@ -43,8 +42,9 @@ class FormatRefiner {
std::unordered_map<ge::NodePtr, bool> &node_status);
static graphStatus ForwardInferProcess(std::deque<ge::NodePtr> &nodes, ge::NodePtr &node,
std::unordered_map<ge::NodePtr, bool> &node_status);
static graphStatus DataNodeFormatProcess(std::vector<ge::NodePtr> &data_nodes, ge::Format data_format,
std::unordered_map<ge::NodePtr, bool> &node_status);
static graphStatus DataNodeFormatProcess(const ComputeGraphPtr &graph, std::vector<ge::NodePtr> &data_nodes,
ge::Format data_format, std::unordered_map<ge::NodePtr, bool> &node_status);
static bool IsGraphInferred(const ComputeGraphPtr &graph);
};
} // namespace ge
#endif // COMMON_GRAPH_FORMAT_REFINER_H_

+ 49
- 1
src/common/graph/ge_attr_define.cc View File

@@ -158,6 +158,10 @@ const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE = "automic_add_mem_size";
const std::string ATTR_NAME_DYNAMIC_OUTPUT_DIMS = "_dynamic_output_dims";
const std::string ATTR_NAME_INPUT_ORIGIN_SIZE = "input_origin_size";

// Identify node connecting to input and output
const std::string ATTR_NAME_NODE_CONNECT_INPUT = "_is_connected_to_data";
const std::string ATTR_NAME_NODE_CONNECT_OUTPUT = "_is_connected_to_netoutput";

// To be deleted
const std::string ATTR_TO_BE_DELETED = "to_be_deleted";
const std::string PERMUTE_RESHAPE_FUSION = "permute_reshape_fusion";
@@ -725,6 +729,10 @@ const std::string ATTR_MODEL_TASK_INDEX_OP_NAME = "task_index_op_name";

const std::string ATTR_MODEL_CORE_TYPE = "core_type";

const std::string ATTR_MODEL_ATC_VERSION = "atc_version";

const std::string ATTR_MODEL_OPP_VERSION = "opp_version";

// Public attribute
const std::string ATTR_NAME_IMPLY_TYPE = "imply_type";

@@ -901,6 +909,7 @@ const std::string ATTR_NAME_IS_END_OF_INPUTMEM_LIFECYCLE = "is_end_of_inputmem_l
const std::string ATTR_NAME_PRED_VALUE = "_pred_value";
const std::string ATTR_NAME_BATCH_NUM = "_batch_num";
const std::string ATTR_NAME_BATCH_LABEL = "_batch_label";
const std::string ATTR_NAME_COMBINED_BATCH = "_combined_batch";

// Control flow
const std::string ATTR_NAME_STREAM_SWITCH_COND = "switch_condition";
@@ -910,6 +919,7 @@ const std::string ATTR_NAME_SWITCHN_PRED_VALUE = "switch_pred_value";
const std::string ATTR_NAME_ITERATORS_PER_LOOP = "iterations_per_loop";
const std::string ATTR_NAME_FLOW_CTRL_NODE_FLAG = "is_flow_ctrl_node";
const std::string ATTR_NAME_SUBGRAPH_FIRST_ACTIVE = "subgraph_first_active";
const std::string ATTR_NAME_COMBINED_DYNAMIC_DIMS = "combined_dynamic_dims";

const std::string ATTR_NAME_SWITCH_BRANCH_NODE_LABEL = "_switch_branch_node_label";
const std::string ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG = "_switch_true_branch_flag";
@@ -934,7 +944,7 @@ const std::string ATTR_NAME_MEMORY_TYPE_WORKSPACE = "memory_type_workspace";

const std::string MODEL_ATTR_SESSION_ID = "session_id";

// l1 fusion and other fusion in future
// lx fusion
const std::string ATTR_NAME_L1_FUSION_GROUP_ID = "_l1_fusion_group_id";
const std::string ATTR_NAME_FUSION_GROUP_KEY = "_fusion_group_key";
const std::string ATTR_NAME_L1_FUSION_GROUP_KEY = "_l1_fusion_group_key";
@@ -948,9 +958,17 @@ const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_L1_FUSION = "_output_offset_for_l1
const std::string ATTR_NAME_SWITCH_FOR_L1_FUSION = "_enable_l1_fusion";
const std::string ATTR_N_BATCH_SPILT = "_is_n_batch_split";
const std::string ATTR_NO_TASK_AND_DUMP_NEEDED = "_no_task_and_dump_needed";
const std::string ATTR_DATA_DUMP_REF = "_datadump_ref";
const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_BUFFER_FUSION = "_output_offset_for_buffer_fusion";
const std::string ATTR_NAME_L2_FUSION_GROUP_ID = "_l2_fusion_group_id";
const std::string ATTR_NAME_SWITCH_FOR_L2_FUSION = "_enable_l2_fusion";
const std::string ATTR_NAME_OP_INPUT_L1_FLAG = "_op_input_l1_flag";
const std::string ATTR_NAME_OP_INPUT_L1_ADDR = "_op_input_l1_addr";
const std::string ATTR_NAME_OP_INPUT_L1_VALID_SIZE = "_op_input_l1_valid_size";

// Op debug attrs
const std::string ATTR_OP_DEBUG_FLAG = "_op_debug_flag";
const std::string ATTR_OP_DEBUG_MODE = "_op_debug_mode";

// Atomic addr clean attrs
const std::string ATOMIC_ATTR_INPUT_INDEX = "atomic_input_index";
@@ -971,6 +989,8 @@ const std::string ATTR_INSERT_BY_MBATCH = "mbatch-inserted-node";

const std::string ATTR_MBATCH_ORIGIN_INPUT_DIMS = "_mbatch_origin_input_dims";

const std::string ATTR_DYNAMIC_TYPE = "mbatch_dynamic_type";

// For inserted op
const std::string ATTR_INSERTED_BY_GE = "_inserted_by_ge";

@@ -1009,10 +1029,38 @@ const std::string ATTR_NAME_VALID_OUTPUT_SHAPE_LIST_LIST = "_valid_output_shape_
const std::string ATTR_NAME_SLICE_INPUT_OFFSET_LIST_LIST = "_input_offset_list_list";
const std::string ATTR_NAME_SLICE_OUTPUT_OFFSET_LIST_LIST = "_output_offset_list_list";

// for unregistered op
const std::string ATTR_NAME_UNREGST_OPPATH = "_unregst_oppath";
const std::string ATTR_NAME_UNREGST_ATTRLIST = "_unregst_attrlist";

// used for Horovod
const std::string ATTR_INTER_EVENT_IDENTIFY = "event_id";
const std::string ATTR_HOROVOD_ATTR_REDUCE_TYPE = "reduce_op";
// used for allreduce tailing optimization
const std::string ATTR_NAME_HCCL_FUSED_GROUP = "_hccl_fused_group";
const std::string ATTR_NAME_HCCL_FUSED_FLAG = "_hccl_fused_node";

// dynamic shape attr
const std::string ATTR_DYNAMIC_SHAPE_FIXED_ADDR = "_alloc_fixed_addr";
const std::string ATTR_DYNAMIC_SHAPE_FIXED_ADDR_INDEX = "_alloc_fixed_addr_index";

// atc user def dtype&format
const std::string ATTR_ATC_USER_DEFINE_DATATYPE = "_user_defined_data_type";
const std::string ATTR_ATC_USER_DEFINE_FORMAT = "_user_defined_format";

// for fusion op plugin
const std::string ATTR_NAME_FUSIONOP_ORIGINAL_TYPE = "_fusionop_original_type";

// graph partition for aicpu
const std::string ATTR_NAME_PLD_FRONT_NODE_ENGINE_NAME = "pld_front_node_engine_name";
const std::string ATTR_NAME_END_REAR_NODE_ENGINE_NAME = "end_rear_node_engine_name";

// input and output memory type
const std::string ATTR_VARIABLE_PLACEMENT = "_variable_placement";
const std::string ATTR_INPUT_MEMORY_TYPE = "_input_memory_type";
const std::string ATTR_OUTPUT_MEMORY_TYPE = "_output_memory_type";

// input_output_offset
const std::string ATTR_ZERO_COPY_BASIC_OFFSET = "_zero_copy_basic_offset";
const std::string ATTR_ZERO_COPY_RELATIVE_OFFSET = "_zero_copy_relative_offset";
} // namespace ge

+ 26
- 29
src/common/graph/ge_attr_value.cc View File

@@ -33,7 +33,8 @@ using std::vector;
namespace ge {
NamedAttrs::NamedAttrs() { named_attrs_.InitDefault(); }

NamedAttrs::NamedAttrs(const ProtoMsgOwner &owner, proto::NamedAttrs *proto_msg) : named_attrs_(owner, proto_msg) {}
NamedAttrs::NamedAttrs(const ProtoMsgOwner &owner, proto::NamedAttrs *proto_msg)
: named_attrs_(owner, proto_msg) {} // lint !e1744

void NamedAttrs::SetName(const std::string &name) {
auto proto_msg = named_attrs_.GetProtoMsg();
@@ -238,7 +239,7 @@ ATTR_VALUE_SET_GET_IMP(GeAttrValue::STR)
ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::STR>)
ATTR_VALUE_SET_GET_IMP(GeAttrValue::INT)
ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::INT>)
ATTR_VALUE_SET_GET_IMP(GeAttrValue::FLOAT)
ATTR_VALUE_SET_GET_IMP(GeAttrValue::FLOAT) // lint !e524
ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::FLOAT>)
ATTR_VALUE_SET_GET_IMP(GeAttrValue::BOOL)
ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::BOOL>)
@@ -252,9 +253,11 @@ ATTR_VALUE_SET_GET_IMP(GeAttrValue::BYTES)
ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::BYTES>)
ATTR_VALUE_SET_GET_IMP(GeAttrValue::NAMED_ATTRS)
ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::NAMED_ATTRS>)
/*lint -e665*/
ATTR_VALUE_SET_GET_IMP(vector<vector<int64_t>>)
ATTR_VALUE_SET_GET_IMP(vector<DataType>)
ATTR_VALUE_SET_GET_IMP(GeAttrValue::DATA_TYPE)
/*lint +e665*/
ATTR_VALUE_SET_GET_IMP(vector<DataType>) // lint !e665
ATTR_VALUE_SET_GET_IMP(GeAttrValue::DATA_TYPE) // lint !e665

#undef ATTR_VALUE_SET_GET_IMP

@@ -782,14 +785,14 @@ bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoM
if (graph_def == nullptr) {
GELOGE(GRAPH_FAILED, "proto::GraphDef make shared failed");
graph_def = nullptr;
return false;
return false; // lint !e665
} else {
ModelSerializeImp imp;
imp.SetProtobufOwner(graph_def);
if (!imp.UnserializeGraph(graph, *graph_def)) {
GELOGE(GRAPH_FAILED, "UnserializeGraph Failed");
return false;
}
} // lint !e514
value = graph;
}
return true;
@@ -809,7 +812,7 @@ bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoM
if (graph_def == nullptr) {
GELOGE(GRAPH_FAILED, "proto::GraphDef make shared failed");
graph_def = nullptr;
return false;
return false; // lint !e665
} else {
ComputeGraphPtr graph = nullptr;
ModelSerializeImp imp;
@@ -817,7 +820,7 @@ bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoM
if (!imp.UnserializeGraph(graph, *graph_def)) {
GELOGE(GRAPH_FAILED, "UnserializeGraph Failed");
return false;
}
} // lint !e514
value.push_back(graph);
}
}
@@ -969,7 +972,9 @@ ATTR_UTILS_SET_IMP(Tensor, GeTensor)
ATTR_UTILS_SET_GET_IMP(NamedAttrs, GeAttrValue::NAMED_ATTRS)
ATTR_UTILS_SET_GET_IMP(Bytes, Buffer)
ATTR_UTILS_SET_GET_IMP(Graph, ComputeGraphPtr)
/*lint -e665*/
ATTR_UTILS_SET_GET_IMP(ListListInt, vector<vector<int64_t>>)
/*lint +e665*/

ATTR_UTILS_SET_GET_IMP(ListInt, vector<int64_t>)
ATTR_UTILS_SET_IMP(ListInt, vector<int32_t>)
@@ -984,8 +989,8 @@ ATTR_UTILS_SET_IMP(ListTensor, vector<GeTensor>)
ATTR_UTILS_SET_GET_IMP(ListNamedAttrs, vector<GeAttrValue::NAMED_ATTRS>)
ATTR_UTILS_SET_GET_IMP(ListBytes, vector<Buffer>)
ATTR_UTILS_SET_GET_IMP(ListGraph, vector<ComputeGraphPtr>)
ATTR_UTILS_SET_GET_IMP(ListDataType, vector<ge::DataType>)
ATTR_UTILS_SET_GET_IMP(DataType, ge::DataType)
ATTR_UTILS_SET_GET_IMP(ListDataType, vector<ge::DataType>) // lint !e665
ATTR_UTILS_SET_GET_IMP(DataType, ge::DataType) // lint !e665

bool AttrUtils::SetListTensor(AttrHolderAdapter &&obj, const string &name,
std::initializer_list<ConstGeTensorPtr> &&value) {
@@ -1154,7 +1159,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetListOpDesc(Con
}
for (const auto &item : bytes_vals) {
ModelSerialize serialize;
auto op_desc = serialize.UnserializeOpDesc(item.GetData(), item.GetSize());
auto op_desc = serialize.UnserializeOpDesc(item.GetData(), item.GetSize()); // lint !e732
value.push_back(op_desc);
}
return true;
@@ -1206,7 +1211,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CloneOpDesc(
op_def = ComGraphMakeShared<proto::OpDef>();
if (op_def == nullptr) {
GELOGE(GRAPH_FAILED, "proto::OpDef make shared failed");
return nullptr;
return nullptr; // lint !e665
}
ModelSerializeImp imp;
(void)imp.SerializeOpDesc(org_op_desc, op_def.get());
@@ -1216,27 +1221,16 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CloneOpDesc(
GE_CHK_BOOL_EXEC(imp.UnserializeOpDesc(op_desc, *op_def), return op_desc, "op_desc unserialize failed");
op_desc->extAttrs_ = org_op_desc->extAttrs_;

if (op_desc->HasAttr("_input_name_idx_key")) {
if (op_desc->DelAttr("_input_name_idx_key") != SUCCESS) {
GELOGE(GRAPH_FAILED, "DelAttr _input_name_idx_key failed.");
}
}

if (op_desc->HasAttr("_input_name_idx_value")) {
if (op_desc->DelAttr("_input_name_idx_value") != SUCCESS) {
GELOGE(GRAPH_FAILED, "DelAttr _input_name_idx_value failed.");
}
// This function may be called by some fusion-engine passes; in that case these attributes are not needed
if (!op_desc->input_name_idx_.empty()) {
op_desc->input_name_idx_.clear();
}

if (op_desc->HasAttr("_opt_input")) {
if (op_desc->DelAttr("_opt_input") != SUCCESS) {
GELOGE(GRAPH_FAILED, "DelAttr _opt_input failed.");
}
}

if (!op_desc->output_name_idx_.empty()) {
op_desc->output_name_idx_.clear();
}
if (!op_desc->optional_input_names_.empty()) {
op_desc->optional_input_names_.clear();
}

return op_desc;
}
@@ -1260,6 +1254,9 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CopyOpDesc(c

op_desc->extAttrs_ = org_op_desc->extAttrs_;

op_desc->input_name_idx_.insert(org_op_desc->input_name_idx_.begin(), org_op_desc->input_name_idx_.end());
op_desc->optional_input_names_.insert(org_op_desc->optional_input_names_.begin(),
org_op_desc->optional_input_names_.end());
op_desc->output_name_idx_.insert(org_op_desc->output_name_idx_.begin(), org_op_desc->output_name_idx_.end());

op_desc->infer_func_ = org_op_desc->infer_func_;


+ 11
- 0
src/common/graph/ge_tensor.cc View File

@@ -220,6 +220,7 @@ const string TENSOR_UTILS_ORIGIN_SHAPE = "origin_shape";
const string TENSOR_UTILS_ORIGIN_FORMAT = "origin_format";
const string TENSOR_UTILS_ORIGIN_DATA_TYPE = "origin_data_type";
const string TENSOR_UTILS_SHAPE_RANGE = "shape_range";
const string TENSOR_UTILS_REF_PORT_INDEX = "ref_port_index";

GeShape::GeShape(const ProtoMsgOwner &proto_owner, proto::ShapeDef *proto_msg) : shape_def_(proto_owner, proto_msg) {}

@@ -567,6 +568,16 @@ DataType GeTensorDesc::GetOriginDataType() const {
return TypeUtils::SerialStringToDataType(origin_data_type_str);
}

std::vector<uint32_t> GeTensorDesc::GetRefPortIndex() const {
vector<uint32_t> ref_port_index;
(void)AttrUtils::GetListInt(this, TENSOR_UTILS_REF_PORT_INDEX, ref_port_index);
return ref_port_index;
}

// Store the ref-port index list in the TENSOR_UTILS_REF_PORT_INDEX attr;
// the SetListInt result is deliberately ignored (best-effort write).
void GeTensorDesc::SetRefPortByIndex(const std::vector<uint32_t> &index) {
  (void)AttrUtils::SetListInt(this, TENSOR_UTILS_REF_PORT_INDEX, index);
}

graphStatus GeTensorDesc::IsValid() const {
auto dtype = this->GetDataType();
auto format = this->GetFormat();


+ 1
- 1
src/common/graph/graph.cc View File

@@ -210,7 +210,7 @@ class GraphImpl {

graphStatus FindOpByName(const string &name, ge::Operator &op) const {
auto it = op_list_.find(name);
GE_CHK_BOOL_EXEC(it != op_list_.end(), return GRAPH_FAILED, "Error: there is no op: %s.", name.c_str());
GE_CHK_BOOL_EXEC(it != op_list_.end(), return GRAPH_FAILED, "there is no op: %s.", name.c_str());
op = it->second;
return GRAPH_SUCCESS;
}


+ 68
- 8
src/common/graph/graph.mk View File

@@ -77,6 +77,7 @@ LOCAL_SHARED_LIBRARIES := \
libc_sec \
libprotobuf \
libslog \
liberror_manager \

LOCAL_LDFLAGS := -lrt -ldl

@@ -94,10 +95,36 @@ LOCAL_CPPFLAGS += -fexceptions

LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES)
LOCAL_SRC_FILES := \
../../out/atc/lib64/stub/graph.cc \
../../out/atc/lib64/stub/operator.cc \
../../out/atc/lib64/stub/tensor.cc \
../../out/atc/lib64/stub/operator_factory.cc \
../../out/graph/lib64/stub/graph.cc \
../../out/graph/lib64/stub/operator.cc \
../../out/graph/lib64/stub/tensor.cc \
../../out/graph/lib64/stub/operator_factory.cc \


LOCAL_SHARED_LIBRARIES :=

LOCAL_LDFLAGS := -lrt -ldl

LOCAL_MULTILIB := 64
LOCAL_PROPRIETARY_MODULE := true

include $(BUILD_HOST_SHARED_LIBRARY)

#compiler for host
include $(CLEAR_VARS)
LOCAL_MODULE := fwk_stub/libgraph

LOCAL_CFLAGS += -DFMK_SUPPORT_DUMP -O2
LOCAL_CPPFLAGS += -fexceptions

LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES)
LOCAL_SRC_FILES := \
../../out/graph/lib64/stub/attr_value.cc \
../../out/graph/lib64/stub/graph.cc \
../../out/graph/lib64/stub/operator.cc \
../../out/graph/lib64/stub/operator_factory.cc \
../../out/graph/lib64/stub/tensor.cc \
../../out/graph/lib64/stub/inference_context.cc \


LOCAL_SHARED_LIBRARIES :=
@@ -122,6 +149,7 @@ LOCAL_SHARED_LIBRARIES := \
libc_sec \
libprotobuf \
libslog \
liberror_manager \

LOCAL_LDFLAGS := -lrt -ldl

@@ -142,10 +170,39 @@ LOCAL_CFLAGS += -O2

LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES)
LOCAL_SRC_FILES := \
../../out/atc/lib64/stub/graph.cc \
../../out/atc/lib64/stub/operator.cc \
../../out/atc/lib64/stub/tensor.cc \
../../out/atc/lib64/stub/operator_factory.cc \
../../out/graph/lib64/stub/graph.cc \
../../out/graph/lib64/stub/operator.cc \
../../out/graph/lib64/stub/tensor.cc \
../../out/graph/lib64/stub/operator_factory.cc \


LOCAL_SHARED_LIBRARIES :=

LOCAL_LDFLAGS := -lrt -ldl

ifeq ($(device_os),android)
LOCAL_LDFLAGS := -ldl
endif

LOCAL_MULTILIB := 64
LOCAL_PROPRIETARY_MODULE := true

include $(BUILD_SHARED_LIBRARY)

#compiler for device
include $(CLEAR_VARS)
LOCAL_MODULE := fwk_stub/libgraph

LOCAL_CFLAGS += -O2

LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES)
LOCAL_SRC_FILES := \
../../out/graph/lib64/stub/attr_value.cc \
../../out/graph/lib64/stub/graph.cc \
../../out/graph/lib64/stub/operator.cc \
../../out/graph/lib64/stub/operator_factory.cc \
../../out/graph/lib64/stub/tensor.cc \
../../out/graph/lib64/stub/inference_context.cc \


LOCAL_SHARED_LIBRARIES :=
@@ -174,6 +231,7 @@ LOCAL_SHARED_LIBRARIES := \
libc_sec \
libprotobuf \
libslog \
liberror_manager \

LOCAL_LDFLAGS := -lrt -ldl

@@ -199,6 +257,7 @@ LOCAL_STATIC_LIBRARIES := \
LOCAL_SHARED_LIBRARIES := \
libc_sec \
libslog \
liberror_manager \

LOCAL_LDFLAGS := -lrt -ldl

@@ -222,6 +281,7 @@ LOCAL_STATIC_LIBRARIES := \
LOCAL_SHARED_LIBRARIES := \
libc_sec \
libslog \
liberror_manager \

LOCAL_LDFLAGS := -lrt -ldl



+ 97
- 27
src/common/graph/model_serialize.cc View File

@@ -88,10 +88,8 @@ bool ModelSerializeImp::SerializeEdge(const NodePtr &node, proto::OpDef *op_def_
}

bool ModelSerializeImp::SerializeOpDesc(const ConstOpDescPtr &op_desc, proto::OpDef *op_def_proto, bool is_dump) {
if (op_desc == nullptr || op_def_proto == nullptr) {
GELOGE(GRAPH_FAILED, "Input Para Invalid");
return false;
}
GE_CHK_BOOL_EXEC(op_desc != nullptr, return false, "op_desc is null.");
GE_CHK_BOOL_EXEC(op_def_proto != nullptr, return false, "op_def_proto is null.");
if (op_desc->op_def_.GetProtoMsg() != nullptr) {
*op_def_proto = *op_desc->op_def_.GetProtoMsg();
// Delete unnecessary attr
@@ -130,18 +128,40 @@ bool ModelSerializeImp::SerializeOpDesc(const ConstOpDescPtr &op_desc, proto::Op
for (const std::string &name : op_desc->GetSubgraphInstanceNames()) {
op_def_proto->add_subgraph_name(name);
}
OpDescToAttrDef(op_desc, op_def_proto);
}
return true;
}

proto::AttrDef key;
proto::AttrDef value;
void ModelSerializeImp::OpDescToAttrDef(const ConstOpDescPtr &op_desc, proto::OpDef *op_def_proto) {
proto::AttrDef key_in;
proto::AttrDef value_in;
auto op_desc_attr = op_def_proto->mutable_attr();
if (!op_desc->input_name_idx_.empty()) {
for (auto &item : op_desc->input_name_idx_) {
key_in.mutable_list()->add_s(item.first);
value_in.mutable_list()->add_i(item.second);
}
op_desc_attr->insert({"_input_name_key", key_in});
op_desc_attr->insert({"_input_name_value", value_in});
}
proto::AttrDef key_out;
proto::AttrDef value_out;
if (!op_desc->output_name_idx_.empty()) {
for (auto &item : op_desc->output_name_idx_) {
key.mutable_list()->add_s(item.first);
value.mutable_list()->add_i(item.second);
key_out.mutable_list()->add_s(item.first);
value_out.mutable_list()->add_i(item.second);
}
auto op_desc_attr = op_def_proto->mutable_attr();
op_desc_attr->insert({"_output_name_key", key});
op_desc_attr->insert({"_output_name_value", value});
op_desc_attr->insert({"_output_name_key", key_out});
op_desc_attr->insert({"_output_name_value", value_out});
}
proto::AttrDef opt_input;
if (!op_desc->optional_input_names_.empty()) {
for (auto &item : op_desc->optional_input_names_) {
opt_input.mutable_list()->add_s(item);
}
op_desc_attr->insert({"_opt_input", opt_input});
}
return true;
}

bool ModelSerializeImp::SerializeNode(const NodePtr &node, proto::OpDef *op_def_proto, bool is_dump) {
@@ -237,13 +257,70 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::Unseriali
}
}

/// @brief Restore the op_desc name->index bookkeeping from deserialized attr lists.
/// Fills input_name_idx_ from (key_in, value_in), output_name_idx_ from
/// (key_out, value_out) and optional_input_names_ from opt_input. A key/value
/// pair of vectors whose sizes disagree is skipped with a warning rather than
/// partially applied.
/// @param op_desc    target op desc (assumed non-null by the callers — TODO confirm)
void ModelSerializeImp::AttrDefToOpDesc(OpDescPtr &op_desc, std::vector<string> &key_in, std::vector<string> &key_out,
                                        std::vector<uint32_t> &value_in, std::vector<uint32_t> &value_out,
                                        std::vector<string> &opt_input) {
  if (!key_in.empty()) {
    if (key_in.size() != value_in.size()) {
      // Bug fix: the original logged key_out.size() here instead of key_in.size().
      GELOGW("Key and value vector size is different. key_size: %zu, value_size: %zu.", key_in.size(),
             value_in.size());
    } else {
      for (uint32_t i = 0; i < key_in.size(); ++i) {
        op_desc->input_name_idx_.insert(std::pair<string, uint32_t>(key_in.at(i), value_in.at(i)));
      }
    }
  }
  if (!key_out.empty()) {
    if (key_out.size() != value_out.size()) {
      GELOGW("Key and value vector size is different. key_size: %zu, value_size: %zu.", key_out.size(),
             value_out.size());
    } else {
      for (uint32_t i = 0; i < key_out.size(); ++i) {
        op_desc->output_name_idx_.insert(std::pair<string, uint32_t>(key_out.at(i), value_out.at(i)));
      }
    }
  }
  if (!opt_input.empty()) {
    for (const auto &i : opt_input) {
      op_desc->optional_input_names_.insert(i);
    }
  }
}

bool ModelSerializeImp::UnserializeOpDesc(OpDescPtr &op_desc, proto::OpDef &op_def_proto) {
std::vector<string> key;
std::vector<uint32_t> value;
std::vector<string> opt_input;
std::vector<string> key_in;
std::vector<uint32_t> value_in;
if (op_def_proto.attr().count("_opt_input") > 0) {
auto &name_list = op_def_proto.attr().at("_opt_input").list();
for (const auto &item_s : name_list.s()) {
opt_input.push_back(item_s);
}
auto op_desc_attr = op_def_proto.mutable_attr();
op_desc_attr->erase("_opt_input");
}
if (op_def_proto.attr().count("_input_name_key") > 0) {
auto &output_name_key_list = op_def_proto.attr().at("_input_name_key").list();
for (const auto &item_s : output_name_key_list.s()) {
key_in.push_back(item_s);
}
auto op_desc_attr = op_def_proto.mutable_attr();
op_desc_attr->erase("_input_name_key");
}
if (op_def_proto.attr().count("_input_name_value") > 0) {
auto &input_name_value_list = op_def_proto.attr().at("_input_name_value").list();
for (const auto &item_i : input_name_value_list.i()) {
value_in.push_back(static_cast<uint32_t>(item_i));
}
auto op_desc_attr = op_def_proto.mutable_attr();
op_desc_attr->erase("_input_name_value");
}
std::vector<string> key_out;
std::vector<uint32_t> value_out;
if (op_def_proto.attr().count("_output_name_key") > 0) {
auto &output_name_key_list = op_def_proto.attr().at("_output_name_key").list();
for (const auto &item_s : output_name_key_list.s()) {
key.push_back(item_s);
key_out.push_back(item_s);
}
auto op_desc_attr = op_def_proto.mutable_attr();
op_desc_attr->erase("_output_name_key");
@@ -251,7 +328,7 @@ bool ModelSerializeImp::UnserializeOpDesc(OpDescPtr &op_desc, proto::OpDef &op_d
if (op_def_proto.attr().count("_output_name_value") > 0) {
auto &output_name_value_list = op_def_proto.attr().at("_output_name_value").list();
for (const auto &item_i : output_name_value_list.i()) {
value.push_back(static_cast<uint32_t>(item_i));
value_out.push_back(static_cast<uint32_t>(item_i));
}
auto op_desc_attr = op_def_proto.mutable_attr();
op_desc_attr->erase("_output_name_value");
@@ -282,15 +359,8 @@ bool ModelSerializeImp::UnserializeOpDesc(OpDescPtr &op_desc, proto::OpDef &op_d
op_desc->SetSubgraphInstanceName(graph_index++, name);
}

if (key.size() != 0) {
if (key.size() != value.size()) {
GELOGE(GRAPH_FAILED, "twe vector size is different. key_size: %zu, value_size: %zu.", key.size(), value.size());
} else {
for (uint32_t i = 0; i < key.size(); ++i) {
op_desc->output_name_idx_.insert(std::pair<string, uint32_t>(key.at(i), value.at(i)));
}
}
}
// insert name index by key and value
AttrDefToOpDesc(op_desc, key_in, key_out, value_in, value_out, opt_input);

return true;
}
@@ -338,13 +408,13 @@ bool ModelSerializeImp::HandleNodeNameRef() {
item.dst_node_name.c_str(), item.dst_in_index);
return false;
}
GE_CHK_BOOL_ONLY_LOG((src_anchor->LinkTo(dst_anchor) == GRAPH_SUCCESS), " linkTo failed.");
GE_CHK_BOOL_ONLY_LOG((src_anchor->LinkTo(dst_anchor) == GRAPH_SUCCESS), " linkTo failed."); // lint !e737
} else {
// Control edge
auto src_anchor = src_node_it->second->GetOutControlAnchor();
auto dst_anchor = item.dst_node->GetInControlAnchor();
if (src_anchor != nullptr && dst_anchor != nullptr) {
GE_CHK_BOOL_ONLY_LOG((src_anchor->LinkTo(dst_anchor) == GRAPH_SUCCESS), " linkTo failed.");
GE_CHK_BOOL_ONLY_LOG((src_anchor->LinkTo(dst_anchor) == GRAPH_SUCCESS), " linkTo failed."); // lint !e737
}
}
}


+ 38
- 19
src/common/graph/node.cc View File

@@ -26,6 +26,7 @@
#include "utils/ge_ir_utils.h"
#include "utils/node_utils.h"
#include "utils/op_desc_utils.h"
#include "common/util/error_manager/error_manager.h"

using std::string;
using std::vector;
@@ -154,7 +155,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::NodeAnchorIsEqual(cons
const auto &peer_node = left_anchor->GetPeerAnchors().at(j)->GetOwnerNode();
const auto &r_peer_node = right_anchor->GetPeerAnchors().at(j)->GetOwnerNode();
if (peer_node == nullptr || r_peer_node == nullptr) {
GELOGE(GRAPH_FAILED, "Error: anchor's peer node is null, node name: %s index[%zu] peer node index[%zu]. ",
GELOGE(GRAPH_FAILED, "anchor's peer node is null, node name: %s index[%zu] peer node index[%zu]. ",
this->GetName().c_str(), i, j);
return false;
}
@@ -434,8 +435,11 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor<AnchorPtr> Node::Get

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InDataAnchorPtr Node::GetInDataAnchor(int idx) const {
if (idx < 0 || idx >= static_cast<int>(in_data_anchors_.size())) {
GELOGE(GRAPH_FAILED, "the node doesn't have %d th in_data_anchor, node %s:%s", idx, GetType().c_str(),
GetName().c_str());
ErrorManager::GetInstance().ATCReportErrMessage(
"E19019", {"opname", "index", "anchorname", "optype"},
{GetName().c_str(), std::to_string(idx), "in_data_anchor", GetType().c_str()});
GELOGE(GRAPH_FAILED, "Op[%s] doesn't have index[%d]'s in_data_anchor which optype is %s.", GetName().c_str(), idx,
GetType().c_str());
return nullptr;
} else {
return in_data_anchors_[idx];
@@ -445,7 +449,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InDataAnchorPtr Node::GetInDataAn
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorPtr Node::GetInAnchor(int idx) const {
// Idx can't be less than -1 or >= in_data_anchors_.size(), -1 means index of control anchor_
if (idx < -1 || idx >= static_cast<int>(in_data_anchors_.size())) {
GELOGW("the node doesn't have %d th in_anchor, node %s:%s", idx, GetType().c_str(), GetName().c_str());
GELOGW("Op[%s] doesn't have index[%d]'s in_anchor which optype is %s.", GetName().c_str(), idx, GetType().c_str());
return nullptr;
} else {
// Return control anchor
@@ -461,8 +465,15 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorPtr Node::GetInAnchor(int i
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorPtr Node::GetOutAnchor(int idx) const {
// Idx can't be less than -1 or >= out_data_anchors_.size(), -1 means index of control anchor_
if (idx < -1 || idx >= static_cast<int>(out_data_anchors_.size())) {
GELOGE(GRAPH_FAILED, "the node doesn't have %d th out_anchor, node %s:%s", idx, GetType().c_str(),
GetName().c_str());
ErrorManager::GetInstance().ATCReportErrMessage("E19019", {"opname", "index", "anchorname", "optype"},
{
GetName().c_str(),
std::to_string(idx),
"out_anchor",
GetType().c_str(),
});
GELOGE(GRAPH_FAILED, "Op[%s] doesn't have index[%d]'s out_anchor which optype is %s.", GetName().c_str(), idx,
GetType().c_str());
return nullptr;
} else {
// Return control anchor
@@ -477,8 +488,11 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorPtr Node::GetOutAnchor(int

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OutDataAnchorPtr Node::GetOutDataAnchor(int idx) const {
if (idx < 0 || idx >= static_cast<int>(out_data_anchors_.size())) {
GELOGE(GRAPH_FAILED, "the node doesn't have %d th out_data_anchor, node %s:%s", idx, GetType().c_str(),
GetName().c_str());
ErrorManager::GetInstance().ATCReportErrMessage(
"E19019", {"opname", "index", "anchorname", "optype"},
{GetName().c_str(), std::to_string(idx), "out_data_anchor", GetType().c_str()});
GELOGE(GRAPH_FAILED, "Op[%s] doesn't have index[%d]'s out_data_anchor which optype is %s.", GetName().c_str(), idx,
GetType().c_str());
return nullptr;
} else {
return out_data_anchors_[idx];
@@ -726,22 +740,27 @@ graphStatus Node::Verify() const {
const string aipp_data_type = "AippData";
const string const_type = "Const";
const string variable_type = "Variable";
bool is_unknown_graph = GetOwnerComputeGraph()->GetGraphUnknownFlag();
GE_CHK_BOOL_EXEC(op_ != nullptr, return GRAPH_FAILED, "original OpDesc is nullptr");

for (const auto &in_anchor_ptr : GetAllInDataAnchors()) {
if (in_anchor_ptr == nullptr) {
GELOGW("in anchor ptr is null");
continue;
if (!is_unknown_graph) {
for (const auto &in_anchor_ptr : GetAllInDataAnchors()) {
GE_IF_BOOL_EXEC(in_anchor_ptr == nullptr, GELOGW("in anchor ptr is null"); continue);
bool valid_anchor = op_->GetType() == data_type || op_->GetType() == aipp_data_type ||
op_->GetType() == const_type || op_->GetType() == variable_type ||
op_->IsOptionalInput(in_anchor_ptr->GetIdx()) || in_anchor_ptr->GetPeerAnchors().size() > 0;
if (!valid_anchor) {
ErrorManager::GetInstance().ATCReportErrMessage("E11019", {"opname", "index"},
{GetName(), std::to_string(in_anchor_ptr->GetIdx())});
GELOGE(GRAPH_FAILED, "operator %s's input %d is not linked.", GetName().c_str(), in_anchor_ptr->GetIdx());
return GRAPH_FAILED;
}
}
GE_CHK_BOOL_RET_STATUS(
op_->GetType() == data_type || op_->GetType() == aipp_data_type || op_->GetType() == const_type ||
op_->GetType() == variable_type || op_->IsOptionalInput(in_anchor_ptr->GetIdx()) ||
in_anchor_ptr->GetPeerAnchors().size() > 0,
GRAPH_FAILED, "operator %s's input %d is not linked.", GetName().c_str(), in_anchor_ptr->GetIdx());
}

string frameworkop_type = "FrameworkOp";
if (op_->GetType() != frameworkop_type) {
bool need_update_name = op_->GetType() != frameworkop_type && !is_unknown_graph;
if (need_update_name) {
auto node_op = ge::OperatorFactoryImpl::CreateOperator("node_op", op_->GetType());
if (node_op.IsEmpty()) {
GELOGW("get op from OperatorFactory fail. opType: %s", op_->GetType().c_str());
@@ -761,7 +780,7 @@ graphStatus Node::Verify() const {
}
node_op.BreakConnect();
}
GE_IF_BOOL_EXEC(is_unknown_graph, return GRAPH_SUCCESS;);
if (op_->CommonVerify() == GRAPH_SUCCESS) {
Operator op_proxy = ge::OpDescUtils::CreateOperatorFromNode(shared_from_this());
auto verify_func = op_->GetVerifyFunc();


+ 89
- 112
src/common/graph/op_desc.cc View File

@@ -19,6 +19,7 @@
#include "debug/ge_util.h"
#include "external/graph/operator.h"
#include "framework/common/debug/ge_log.h"
#include "common/util/error_manager/error_manager.h"
#include "graph/ge_attr_value.h"
#include "graph/ge_tensor.h"
#include "graph/operator_factory_impl.h"
@@ -32,6 +33,7 @@ using std::shared_ptr;
using std::string;
using std::vector;

/*lint -save -e521 -e681 -e732 -e737*/
namespace ge {
const std::string ATTR_NAME_ID = "id";

@@ -63,12 +65,6 @@ const std::string ATTR_NAME_IS_INPUT_CONST = "is_input_const";

const std::string ATTR_NAME_OP_INFER_DEPENDS = "_op_infer_depends";

const std::string ATTR_NAME_OPT_INPUT = "_opt_input";

const std::string ATTR_NAME_INPUT_NAME_IDX_KEY = "_input_name_idx_key";

const std::string ATTR_NAME_INPUT_NAME_IDX_VALUE = "_input_name_idx_value";

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::OpDesc() {
op_def_.InitDefault();
if (op_def_.GetProtoMsg() != nullptr) {
@@ -210,8 +206,7 @@ graphStatus OpDesc::AddInputDesc(uint32_t index, const ge::GeTensorDesc &input_d
}

graphStatus OpDesc::AddInputDesc(const string &name, const ge::GeTensorDesc &input_desc) {
auto input_name_idx = GetAllInputName();
if (input_name_idx.find(name) != input_name_idx.end()) {
if (input_name_idx_.find(name) != input_name_idx_.end()) {
GELOGI("input %s is exist, update it", name.c_str());
graphStatus ret = UpdateInputDesc(name, input_desc);
return ret;
@@ -223,17 +218,15 @@ graphStatus OpDesc::AddInputDesc(const string &name, const ge::GeTensorDesc &inp
return GRAPH_FAILED;
}
inputs_desc_.push_back(in_desc);
(void)input_name_idx.insert(make_pair(name, index));
SetAllInputName(input_name_idx);
(void)input_name_idx_.insert(make_pair(name, index));
return GRAPH_SUCCESS;
}
}

graphStatus OpDesc::AddInputDescMiddle(const string &name, const unsigned int num, size_t index) {
auto input_name_idx = GetAllInputName();
for (unsigned int i = 0; i < num; i++) {
string input_name = name + std::to_string(i);
GE_CHK_BOOL_RET_STATUS((input_name_idx.find(input_name) == input_name_idx.end()), GRAPH_FAILED,
GE_CHK_BOOL_RET_STATUS((input_name_idx_.find(input_name) == input_name_idx_.end()), GRAPH_FAILED,
"Add input tensor_desc is existed. name[%s]", input_name.c_str());

std::shared_ptr<GeTensorDesc> in_desc = ComGraphMakeShared<GeTensorDesc>(GeTensorDesc());
@@ -250,24 +243,22 @@ graphStatus OpDesc::AddInputDescMiddle(const string &name, const unsigned int nu
(void)inputs_desc_.insert(inputs_desc_.begin() + index + i, in_desc);

// Update index in input_name_idx
for (auto it = input_name_idx.begin(); it != input_name_idx.end(); ++it) {
for (auto it = input_name_idx_.begin(); it != input_name_idx_.end(); ++it) {
if (it->second >= (index + i)) {
it->second += 1;
}
}

(void)input_name_idx.insert(make_pair(input_name, i + index));
(void)input_name_idx_.insert(make_pair(input_name, i + index));
}
SetAllInputName(input_name_idx);

return GRAPH_SUCCESS;
}

graphStatus OpDesc::AddInputDescForward(const string &name, const unsigned int num) {
auto input_name_idx = GetAllInputName();
for (unsigned int i = 0; i < num; i++) {
string input_name = name + std::to_string(i);
GE_CHK_BOOL_RET_STATUS((input_name_idx.find(input_name) == input_name_idx.end()), GRAPH_FAILED,
GE_CHK_BOOL_RET_STATUS((input_name_idx_.find(input_name) == input_name_idx_.end()), GRAPH_FAILED,
"Add input tensor_desc is existed. name[%s]", input_name.c_str());

std::shared_ptr<GeTensorDesc> in_desc = ComGraphMakeShared<GeTensorDesc>(GeTensorDesc());
@@ -278,13 +269,12 @@ graphStatus OpDesc::AddInputDescForward(const string &name, const unsigned int n
(void)inputs_desc_.insert(inputs_desc_.begin(), in_desc);

// Update index in input_name_idx
for (auto it = input_name_idx.begin(); it != input_name_idx.end(); ++it) {
for (auto it = input_name_idx_.begin(); it != input_name_idx_.end(); ++it) {
it->second += 1;
}

(void)input_name_idx.insert(make_pair(input_name, 0));
(void)input_name_idx_.insert(make_pair(input_name, 0));
}
SetAllInputName(input_name_idx);

return GRAPH_SUCCESS;
}
@@ -315,19 +305,10 @@ graphStatus OpDesc::AddOutputDescForward(const string &name, const unsigned int

graphStatus OpDesc::AddOptionalInputDesc(const string &name, const ge::GeTensorDesc &input_desc) {
if (OpDesc::AddInputDesc(name, input_desc) == GRAPH_FAILED) return GRAPH_FAILED;
vector<string> optional_input_names;
(void)AttrUtils::GetListStr(this, ATTR_NAME_OPT_INPUT, optional_input_names);
optional_input_names.push_back(name);
(void)AttrUtils::SetListStr(this, ATTR_NAME_OPT_INPUT, optional_input_names);
(void)optional_input_names_.insert(name);
return GRAPH_SUCCESS;
}

std::vector<string> OpDesc::GetAllOptionalInputName() const {
vector<string> optional_input_names;
(void)AttrUtils::GetListStr(this, ATTR_NAME_OPT_INPUT, optional_input_names);
return optional_input_names;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
OpDesc::UpdateInputDesc(uint32_t index, const ge::GeTensorDesc &tensor_Desc) {
GE_CHK_BOOL_RET_STATUS((index < inputs_desc_.size()), GRAPH_FAILED, "The index is invalid. index[%u]", index);
@@ -342,12 +323,11 @@ OpDesc::UpdateInputDesc(uint32_t index, const ge::GeTensorDesc &tensor_Desc) {
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::OpDescMembersAreEqual(const OpDesc &r_op_desc) const {
return (
IsEqual(this->GetAllInputName(), r_op_desc.GetAllInputName(), "OpDesc.GetAllInputName()") &&
IsEqual(this->output_name_idx_, r_op_desc.output_name_idx_, "OpDesc.output_name_idx_") &&
IsEqual(this->GetAllOptionalInputName(), r_op_desc.GetAllOptionalInputName(), "OpDesc.GetAllOptionalInputName()") &&
IsEqual(this->engine_name_, r_op_desc.engine_name_, "OpDesc.engine_name_") &&
IsEqual(this->op_kernel_lib_name_, r_op_desc.op_kernel_lib_name_, "OpDesc.op_kernel_lib_name_"));
return (IsEqual(this->input_name_idx_, r_op_desc.input_name_idx_, "OpDesc.input_name_idx_") &&
IsEqual(this->output_name_idx_, r_op_desc.output_name_idx_, "OpDesc.output_name_idx_") &&
IsEqual(this->optional_input_names_, r_op_desc.optional_input_names_, "OpDesc.optional_input_names_") &&
IsEqual(this->engine_name_, r_op_desc.engine_name_, "OpDesc.engine_name_") &&
IsEqual(this->op_kernel_lib_name_, r_op_desc.op_kernel_lib_name_, "OpDesc.op_kernel_lib_name_"));
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::OpDescAttrsAreEqual(const OpDesc &r_op_desc) const {
@@ -421,9 +401,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::operator==(const OpD
}

graphStatus OpDesc::UpdateInputDesc(const string &name, const ge::GeTensorDesc &tensor_Desc) {
auto input_name_idx = GetAllInputName();
auto it = input_name_idx.find(name);
if (it == input_name_idx.end()) {
auto it = input_name_idx_.find(name);
if (it == input_name_idx_.end()) {
GELOGW("Cann't find the input desc. name[%s]", name.c_str());
return GRAPH_FAILED;
}
@@ -443,9 +422,8 @@ graphStatus OpDesc::UpdateInputDesc(const string &name, const ge::GeTensorDesc &
}

bool OpDesc::InputIsSet(const string &name) const {
auto input_name_idx = GetAllInputName();
auto it = input_name_idx.find(name);
if (it != input_name_idx.end()) {
auto it = input_name_idx_.find(name);
if (it != input_name_idx_.end()) {
GE_IF_BOOL_EXEC(it->second >= inputs_desc_.size(), GELOGE(GRAPH_FAILED, "it->second is invalid."); return false);
auto tensor_desc = inputs_desc_[it->second];
GE_IF_BOOL_EXEC(tensor_desc == nullptr, GELOGE(GRAPH_FAILED, "tensor_desc is null."); return false);
@@ -463,20 +441,40 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc OpDesc::GetInputDesc
}

GeTensorDesc OpDesc::GetInputDesc(const string &name) const {
auto input_name_idx = GetAllInputName();
auto it = input_name_idx.find(name);
GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx.end(), GeTensorDesc());
auto it = input_name_idx_.find(name);
GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx_.end(), GeTensorDesc());
GE_CHK_BOOL_RET_STATUS_NOLOG(it->second < inputs_desc_.size(), GeTensorDesc());
return *(inputs_desc_[it->second].get());
}

GE_FUNC_HOST_VISIBILITY OpDesc::Vistor<string> OpDesc::GetAllInputNames() const {
// Returns a mutable pointer to the input tensor descriptor at `index`.
// Returns nullptr when the index is out of range, the stored descriptor is
// null, or the descriptor fails its validity check.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableInputDesc(uint32_t index) const {
  // Bounds check first; the macro logs and returns nullptr on failure.
  GE_CHK_BOOL_RET_STATUS(index < inputs_desc_.size(), nullptr, "Can't find the input desc %u", index);
  const auto &desc = inputs_desc_[index];
  if (desc == nullptr) {
    return nullptr;
  }
  if (desc->IsValid() != GRAPH_SUCCESS) {
    GELOGW("input desc is invalid");
    return nullptr;
  }
  return desc;
}

// Looks up an input descriptor by name. Resolves the name to its positional
// index and delegates to the index-based overload; returns nullptr (with a
// warning) when the name is unknown.
GeTensorDescPtr OpDesc::MutableInputDesc(const string &name) const {
  const auto name_to_idx = GetAllInputName();
  const auto found = name_to_idx.find(name);
  if (found == name_to_idx.end()) {
    GELOGW("Failed to get [%s] input desc", name.c_str());
    return nullptr;
  }
  return MutableInputDesc(found->second);
}

GE_FUNC_HOST_VISIBILITY OpDesc::Vistor<string> OpDesc::GetAllInputNames() const {
vector<string> names;
if (input_name_idx.empty()) {
if (input_name_idx_.empty()) {
return OpDesc::Vistor<string>(shared_from_this(), names);
}
for (std::pair<string, uint32_t> input : input_name_idx) {
for (std::pair<string, uint32_t> input : input_name_idx_) {
names.push_back(input.first);
}
return OpDesc::Vistor<string>(shared_from_this(), names);
@@ -496,15 +494,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetOpEngineName(cons

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string OpDesc::GetOpEngineName() const { return engine_name_; }

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableInputDesc(uint32_t index) const {
GE_CHK_BOOL_RET_STATUS(index < inputs_desc_.size(), nullptr, "Can't find the input desc %u", index);
if (inputs_desc_[index] == nullptr) {
return nullptr;
}
GE_CHK_BOOL_RET_STATUS(inputs_desc_[index]->IsValid() == GRAPH_SUCCESS, nullptr, "input desc is invalid");
return inputs_desc_[index];
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::Vistor<GeTensorDesc> OpDesc::GetAllInputsDesc() const {
vector<GeTensorDesc> temp{};
for (const auto &it : inputs_desc_) {
@@ -609,6 +598,15 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableOu
return outputs_desc_[index];
}

// Looks up an output descriptor by name via the output name->index map and
// delegates to the index-based overload; returns nullptr (with a warning)
// when the name is not registered.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableOutputDesc(const string &name) const {
  const auto found = output_name_idx_.find(name);
  if (found == output_name_idx_.end()) {
    GELOGW("Failed to get [%s] output desc", name.c_str());
    return nullptr;
  }
  return MutableOutputDesc(found->second);
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t OpDesc::GetAllOutputsDescSize() const {
return static_cast<uint32_t>(outputs_desc_.size());
}
@@ -652,9 +650,8 @@ OpDesc::GetInputDescPtrDfault(uint32_t index) const {
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ConstGeTensorDescPtr OpDesc::GetInputDescPtr(const string &name) const {
auto input_name_idx = GetAllInputName();
auto it = input_name_idx.find(name);
GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx.end(), shared_ptr<const GeTensorDesc>());
auto it = input_name_idx_.find(name);
GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx_.end(), shared_ptr<const GeTensorDesc>());
return inputs_desc_[it->second];
}

@@ -687,47 +684,26 @@ graphStatus OpDesc::AddDynamicOutputDesc(const string &name, const unsigned int
return GRAPH_SUCCESS;
}

bool OpDesc::IsOptionalInput(const string &name) const {
vector<string> optional_input_names;
(void)AttrUtils::GetListStr(this, ATTR_NAME_OPT_INPUT, optional_input_names);
for (auto &item : optional_input_names) {
if (item == name) {
return true;
}
void OpDesc::RemoveInputDesc(uint32_t index) {
while (inputs_desc_.size() > index) {
inputs_desc_.pop_back();
}
return false;
}

bool OpDesc::IsOptionalInput(uint32_t index) const { return IsOptionalInput(GetInputNameByIndex(index)); }

std::map<string, uint32_t> OpDesc::GetAllInputName() const {
std::map<string, uint32_t> input_name_idx;
std::vector<string> key;
std::vector<uint32_t> value;
(void)AttrUtils::GetListStr(this, ATTR_NAME_INPUT_NAME_IDX_KEY, key);
(void)AttrUtils::GetListInt(this, ATTR_NAME_INPUT_NAME_IDX_VALUE, value);

if (key.size() != value.size()) {
GE_LOGE("twe vector size is different. key_size: %zu, value_size: %zu.", key.size(), value.size());
} else {
for (uint32_t i = 0; i < key.size(); ++i) {
input_name_idx.insert(std::pair<string, uint32_t>(key.at(i), value.at(i)));
}
void OpDesc::RemoveOutputDesc(uint32_t index) {
while (outputs_desc_.size() > index) {
outputs_desc_.pop_back();
}
return input_name_idx;
}

void OpDesc::SetAllInputName(const std::map<string, uint32_t> &input_name_idx) {
std::vector<string> key;
std::vector<uint32_t> value;
for (auto &item : input_name_idx) {
key.emplace_back(item.first);
value.emplace_back(item.second);
}
(void)AttrUtils::SetListStr(this, ATTR_NAME_INPUT_NAME_IDX_KEY, key);
(void)AttrUtils::SetListInt(this, ATTR_NAME_INPUT_NAME_IDX_VALUE, value);
// True iff `name` was registered as an optional input (see AddOptionalInputDesc).
bool OpDesc::IsOptionalInput(const string &name) const {
  return optional_input_names_.count(name) > 0;
}

// True iff the input at positional `index` is optional; maps the index back
// to its name and reuses the name-based overload.
bool OpDesc::IsOptionalInput(uint32_t index) const {
  return IsOptionalInput(GetInputNameByIndex(index));
}

// Returns a copy of the name->index mapping for all registered inputs.
std::map<string, uint32_t> OpDesc::GetAllInputName() const {
  return input_name_idx_;
}

// Returns a copy of the name->index mapping for all registered outputs.
// NOTE(review): unlike GetAllInputName this accessor is non-const — presumably
// historical; confirm before tightening.
std::map<string, uint32_t> OpDesc::GetAllOutputName() {
  return output_name_idx_;
}

bool OpDesc::UpdateInputName(std::map<string, uint32_t> input_name_idx) {
@@ -737,7 +713,6 @@ bool OpDesc::UpdateInputName(std::map<string, uint32_t> input_name_idx) {
auto factory_map_size = input_name_idx.size();
// It indicates that some inputs have no optionalname.
// The redundant optionalname of factory needs to be deleted and then assigned
auto all_input_name_idx = GetAllInputName();
if (input_map_size < factory_map_size) {
GELOGI("UpdateInputName org inputname map size: %zu, factory inputname map size: %zu", input_map_size,
factory_map_size);
@@ -750,18 +725,17 @@ bool OpDesc::UpdateInputName(std::map<string, uint32_t> input_name_idx) {
}
if (input_name_idx.size() == input_map_size) {
GELOGI("UpdateInputName");
all_input_name_idx = input_name_idx;
input_name_idx_ = input_name_idx;
} else {
ret = false;
GELOGW("after UpdateInputName factoryName map size : %zu", input_name_idx.size());
}
} else if (input_map_size == factory_map_size) {
all_input_name_idx = input_name_idx;
input_name_idx_ = input_name_idx;
} else {
ret = false;
GELOGW("org inputname map size: %zu, factory inputname map size: %zu", input_map_size, factory_map_size);
}
SetAllInputName(all_input_name_idx);
return ret;
}

@@ -882,36 +856,41 @@ graphStatus OpDesc::CommonVerify() const {
// Checking shape of all inputs
vector<int64_t> ishape = GetInputDescPtr(iname)->GetShape().GetDims();
for (int64_t dim : ishape) {
GE_CHK_BOOL_RET_STATUS(dim >= -2, GRAPH_FAILED, "operator input %s shape contains negative or zero dimension.",
iname.c_str());
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
dim < -2, ErrorManager::GetInstance().ATCReportErrMessage(
"E19014", {"opname", "value", "reason"},
{GetName(), "input " + iname + " shape", "contains negative or zero dimension"});
return GRAPH_FAILED, "Op[%s]'s input %s shape contains negative or zero dimension.", GetName().c_str(),
iname.c_str());
}
}
// Check all attributes defined
const auto &all_attributes = GetAllAttrs();
for (const auto &name : GetAllAttrNames()) {
GE_CHK_BOOL_RET_STATUS(all_attributes.find(name) != all_attributes.end(), GRAPH_FAILED,
"operator attribute %s is empty.", name.c_str());
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
all_attributes.find(name) == all_attributes.end(),
ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"},
{GetName(), "attribute " + name, "is empty"});
return GRAPH_FAILED, "operator attribute %s is empty.", name.c_str());
}

return GRAPH_SUCCESS;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY string OpDesc::GetInputNameByIndex(uint32_t index) const {
auto input_name_idx = GetAllInputName();
auto it = input_name_idx.begin();
for (; it != input_name_idx.end(); ++it) {
auto it = input_name_idx_.begin();
for (; it != input_name_idx_.end(); ++it) {
if (it->second == index) {
break;
}
}
GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx.end(), "");
GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx_.end(), "");
return it->first;
}

int OpDesc::GetInputIndexByName(const string &name) const {
auto input_name_idx = GetAllInputName();
auto it_find = input_name_idx.find(name);
GE_CHK_BOOL_RET_STATUS_NOLOG(it_find != input_name_idx.end(), -1);
auto it_find = input_name_idx_.find(name);
GE_CHK_BOOL_RET_STATUS_NOLOG(it_find != input_name_idx_.end(), -1);
return static_cast<int>(it_find->second);
}

@@ -1204,12 +1183,10 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<bool> OpDesc::GetIsInputCo

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::RestoreInputNameIdx(const string &name,
const int &index) {
auto input_name_idx = GetAllInputName();
if (input_name_idx.find(name) != input_name_idx.end()) {
if (input_name_idx_.find(name) != input_name_idx_.end()) {
GELOGI("Restore input name index is existed. name[%s]", name.c_str());
}
(void)input_name_idx.insert(make_pair(name, index));
SetAllInputName(input_name_idx);
(void)input_name_idx_.insert(make_pair(name, index));
return GRAPH_SUCCESS;
}



+ 46
- 48
src/common/graph/operator.cc View File

@@ -21,7 +21,7 @@
#include <mutex>
#include <queue>
#include <set>
#include "array_ops.h"
#include "./array_ops.h"
#include "debug/ge_log.h"
#include "debug/ge_op_types.h"
#include "debug/ge_util.h"
@@ -36,6 +36,8 @@
#include "graph/op_desc.h"
#include "graph/runtime_inference_context.h"
#include "graph/usr_types.h"
#include "graph/utils/node_utils.h"
#include "graph/debug/ge_attr_define.h"
#include "utils/graph_utils.h"
#include "utils/op_desc_utils.h"
#include "utils/tensor_adapter.h"
@@ -54,11 +56,13 @@ using std::string;
using std::to_string;
using std::vector;

/*lint -save -e529 -e728*/
/*lint -e446 -e732*/
/*lint -e665*/
namespace ge {
class OpIO {
public:
explicit OpIO(const string &name, int index, const OperatorImplPtr &owner)
: name_(name), index_(index), owner_(owner) {}
OpIO(const string &name, int index, const OperatorImplPtr &owner) : name_(name), index_(index), owner_(owner) {}

~OpIO() = default;

@@ -546,56 +550,46 @@ Operator &Operator::AddControlInput(const Operator &src_oprt) {
}

graphStatus Operator::GetInputConstData(const string &dst_name, Tensor &data) const {
if (operator_impl_ == nullptr) {
GELOGE(GRAPH_FAILED, "operator impl is nullptr.");
return GRAPH_FAILED;
}
ge::ConstNodePtr node_ptr = operator_impl_->GetNode();
if (node_ptr) {
GE_CHECK_NOTNULL(operator_impl_);
auto node_ptr = operator_impl_->GetNode();
if (node_ptr != nullptr) {
// For inner compute graph
auto op_desc = node_ptr->GetOpDesc();
if (op_desc == nullptr) {
GELOGE(GRAPH_FAILED, "op_desc is nullptr.");
return GRAPH_FAILED;
}
GE_CHECK_NOTNULL(op_desc);
auto index = op_desc->GetInputIndexByName(dst_name);
auto in_data_anchor = node_ptr->GetInDataAnchor(index);
if (in_data_anchor == nullptr) {
GELOGE(GRAPH_FAILED, "in_data_anchor is nullptr.");
return GRAPH_FAILED;
}
GE_CHECK_NOTNULL(in_data_anchor);
auto out_data_anchor = in_data_anchor->GetPeerOutAnchor();
if (out_data_anchor == nullptr) {
GELOGE(GRAPH_FAILED, "out_data_anchor is nullptr.");
return GRAPH_FAILED;
}
std::shared_ptr<Node> peer_node_ptr = out_data_anchor->GetOwnerNode();
if (peer_node_ptr == nullptr) {
GELOGE(GRAPH_FAILED, "peer_node_ptr is nullptr.");
return GRAPH_FAILED;
}
ge::OperatorImplPtr operator_impl_ptr = nullptr;
operator_impl_ptr = ComGraphMakeShared<OperatorImpl>(peer_node_ptr);
if (operator_impl_ptr == nullptr) {
GELOGE(GRAPH_FAILED, "OperatorImpl make shared failed");
return GRAPH_FAILED;
}
Operator const_op(std::move(operator_impl_ptr));
if (peer_node_ptr->GetOpDesc() != nullptr) {
const auto &op_descType = peer_node_ptr->GetOpDesc()->GetType();
if (op_descType == CONSTANTOP) {
return const_op.GetAttr(op::Constant::name_attr_value(), data);
} else if (op_descType == CONSTANT) {
return const_op.GetAttr(op::Const::name_attr_value(), data);
GE_CHECK_NOTNULL(out_data_anchor);
auto peer_node = out_data_anchor->GetOwnerNode();
GE_CHECK_NOTNULL(peer_node);
auto peer_op_desc = peer_node->GetOpDesc();
GE_CHECK_NOTNULL(peer_op_desc);
auto peer_op_type = peer_op_desc->GetType();
if (peer_op_type == CONSTANTOP || peer_op_type == CONSTANT) {
auto const_op_impl = ComGraphMakeShared<OperatorImpl>(peer_node);
GE_CHECK_NOTNULL(const_op_impl);
Operator const_op(std::move(const_op_impl));
return const_op.GetAttr(ATTR_NAME_WEIGHTS, data);
} else if (peer_op_type == DATA) {
auto parent_node = NodeUtils::GetParentInput(peer_node);
while ((parent_node != nullptr) && (parent_node->GetType() == DATA)) {
parent_node = NodeUtils::GetParentInput(parent_node);
}
if ((parent_node != nullptr) &&
((parent_node->GetType() == CONSTANT) || (parent_node->GetType() == CONSTANTOP))) {
auto const_op_impl = ComGraphMakeShared<OperatorImpl>(parent_node);
GE_CHECK_NOTNULL(const_op_impl);
Operator const_op(std::move(const_op_impl));
return const_op.GetAttr(ATTR_NAME_WEIGHTS, data);
}
}

// Try get from runtime inference context
auto session_id = std::to_string(GetContext().SessionId());
RuntimeInferenceContext *runtime_infer_ctx = nullptr;
if (RuntimeInferenceContext::GetContext(session_id, &runtime_infer_ctx) == GRAPH_SUCCESS) {
GELOGD("To get constant from runtime inference context. session_id = %s", session_id.c_str());
auto ret = runtime_infer_ctx->GetTensor(peer_node_ptr->GetOpDesc()->GetId(), out_data_anchor->GetIdx(), data);
auto ret = runtime_infer_ctx->GetTensor(peer_node->GetOpDesc()->GetId(), out_data_anchor->GetIdx(), data);
if (ret == GRAPH_SUCCESS) {
return GRAPH_SUCCESS;
}
@@ -604,6 +598,8 @@ graphStatus Operator::GetInputConstData(const string &dst_name, Tensor &data) co
// For outer graph
return GetInputConstDataOut(dst_name, data);
}
auto op_name = operator_impl_->GetName();
GELOGW("node[%s]'s input[%s]'s peer node is not const", op_name.c_str(), dst_name.c_str());
return GRAPH_FAILED;
}
graphStatus Operator::GetInputConstDataOut(const string &dst_name, Tensor &data) const {
@@ -914,7 +910,7 @@ OperatorImplPtr Operator::GetOperatorImplPtr() const { return operator_impl_; }
GELOGW("set attr name %s failed.", name.c_str()); \
} \
return *this; \
}
} // lint !e665

#define OP_ATTR_GET_IMP(ArgType, AttrUtilsFun) \
graphStatus Operator::GetAttr(const string &name, ArgType attr_value) const { \
@@ -927,7 +923,7 @@ OperatorImplPtr Operator::GetOperatorImplPtr() const { return operator_impl_; }
return GRAPH_FAILED; \
} \
return GRAPH_SUCCESS; \
}
} // lint !e665

void Operator::BreakConnect() const {
if (operator_impl_ == nullptr) {
@@ -948,7 +944,7 @@ void Operator::BreakConnect() const {
if (!AttrUtils::Set##AttrUtilsFun(operator_impl_->GetOpDescImpl(), name, attr_value)) { \
GELOGW("reg attr name %s failed.", name.c_str()); \
} \
}
} // lint !e665

OP_ATTR_SET_IMP(int64_t, Int)
OP_ATTR_SET_IMP(int32_t, Int)
@@ -969,22 +965,22 @@ OP_ATTR_SET_IMP(const vector<vector<int64_t>> &, ListListInt)
OP_ATTR_SET_IMP(float, Float)
OP_ATTR_GET_IMP(float &, Float)
OP_ATTR_SET_IMP(const vector<float> &, ListFloat)
OP_ATTR_GET_IMP(vector<float> &, ListFloat)
OP_ATTR_GET_IMP(vector<float> &, ListFloat) // lint !e665

OP_ATTR_SET_IMP(bool, Bool)
OP_ATTR_GET_IMP(bool &, Bool)
OP_ATTR_SET_IMP(const vector<bool> &, ListBool)
OP_ATTR_GET_IMP(vector<bool> &, ListBool)
OP_ATTR_GET_IMP(vector<bool> &, ListBool) // lint !e665

OP_ATTR_SET_IMP(const string &, Str)
OP_ATTR_GET_IMP(string &, Str)
OP_ATTR_SET_IMP(const vector<string> &, ListStr)
OP_ATTR_GET_IMP(vector<string> &, ListStr)
OP_ATTR_GET_IMP(vector<string> &, ListStr) // lint !e665

OP_ATTR_SET_IMP(const GeAttrValue::NAMED_ATTRS &, NamedAttrs)
OP_ATTR_GET_IMP(GeAttrValue::NAMED_ATTRS &, NamedAttrs)
OP_ATTR_SET_IMP(const vector<GeAttrValue::NAMED_ATTRS> &, ListNamedAttrs)
OP_ATTR_GET_IMP(vector<GeAttrValue::NAMED_ATTRS> &, ListNamedAttrs)
OP_ATTR_GET_IMP(vector<GeAttrValue::NAMED_ATTRS> &, ListNamedAttrs) // lint !e665

OP_ATTR_REG_IMP(int64_t, Int)
OP_ATTR_REG_IMP(const vector<int64_t> &, ListInt)
@@ -1547,3 +1543,5 @@ void GraphUtils::BreakConnect(const std::map<OperatorImplPtr, NodePtr> &all_node
}
}
} // namespace ge
/*lint +e446 +e732*/
/*lint +e665*/

+ 2
- 0
src/common/graph/opsproto/opsproto_manager.cc View File

@@ -31,7 +31,9 @@ OpsProtoManager *OpsProtoManager::Instance() {
}

bool OpsProtoManager::Initialize(const std::map<std::string, std::string> &options) {
/*lint -e1561*/
auto proto_iter = options.find("ge.opsProtoLibPath");
/*lint +e1561*/
if (proto_iter == options.end()) {
GELOGW("ge.opsProtoLibPath option not set, return.");
return false;


+ 2
- 0
src/common/graph/option/ge_context.cc View File

@@ -85,6 +85,8 @@ uint32_t GEContext::DeviceId() { return device_id_; }

uint64_t GEContext::TraceId() { return trace_id_; }

// Records the active session id on the global GE context.
void GEContext::SetSessionId(uint64_t session_id) {
  session_id_ = session_id;
}

// Records the active device id on the global GE context.
void GEContext::SetCtxDeviceId(uint32_t device_id) {
  device_id_ = device_id;
}

} // namespace ge

+ 47
- 4
src/common/graph/ref_relation.cc View File

@@ -37,7 +37,7 @@ const string kWhile = "While";
const string kIf = "If";
const string kCase = "Case";

const int kMaxElementNum = 100;
const uint16_t kMaxElementNum = 100;

std::unordered_set<string> function_op = {kWhile, kIf, kCase};
} // namespace
@@ -170,6 +170,7 @@ graphStatus RefRelations::Impl::BuildRefRelationsForWhile(
// data_nodes has been sorted
// for while, input num must be same as output num
auto input_num = root_node->GetAllInDataAnchorsSize();
NodePtr netoutput = nullptr;

size_t ref_i = 0;
while (ref_i < input_num) {
@@ -212,10 +213,44 @@ graphStatus RefRelations::Impl::BuildRefRelationsForWhile(
cell_netoutput_in.in_out = NODE_IN;
cell_netoutput_in.in_out_idx = ele.second;
ref_i_all_refs.emplace_back(cell_netoutput_in);
netoutput = ele.first;
}
node_refs.emplace_back(ref_i_all_refs);
ref_i++;
}
/* There exist scene like the follows, it means data0 data1 netoutput 0'th
* and 1'th tensor should be the same addr.
* Data0 Data1
* \/
* /\
* netoutput
*/
if (netoutput == nullptr) {
return GRAPH_SUCCESS;
}
for (const auto &in_anchor : netoutput->GetAllInDataAnchors()) {
auto peer_out_data_anchor = in_anchor->GetPeerOutAnchor();
if (peer_out_data_anchor == nullptr) {
continue;
}
auto peer_out_data_node = peer_out_data_anchor->GetOwnerNode();
if (peer_out_data_node == nullptr || peer_out_data_node->GetOpDesc() == nullptr) {
GELOGW("Node[%s]\'s peer_out_data_node or peer_out_data_node desc is null", (netoutput->GetName()).c_str());
continue;
}
if (peer_out_data_node->GetType() != DATA) {
continue;
}
auto in_data_anchor_idx = in_anchor->GetIdx();
auto net_in_desc = netoutput->GetOpDesc()->MutableInputDesc(static_cast<uint32_t>(in_data_anchor_idx));
int ref_d;
int ref_n;
(void)AttrUtils::GetInt(peer_out_data_node->GetOpDesc(), kRefIndex, ref_d);
(void)AttrUtils::GetInt(net_in_desc, kRefIndex, ref_n);

node_refs[ref_d].insert(node_refs[ref_d].end(), node_refs[ref_n].begin(), node_refs[ref_n].end());
node_refs[ref_n].insert(node_refs[ref_n].end(), node_refs[ref_d].begin(), node_refs[ref_d].end());
}

return GRAPH_SUCCESS;
}
@@ -242,6 +277,10 @@ void RefRelations::Impl::GetDataAndNetoutputOfSubGraph(const ge::ComputeGraph &r
int sub_graph_idx = 0;
for (const auto &name : sub_graph_names) {
auto sub_graph = root_graph.GetSubgraph(name);
if (sub_graph == nullptr) {
GELOGW("Can not find the sub graph %s for root graph %s.", name.c_str(), root_graph.GetName().c_str());
continue;
}
for (const auto &sub_graph_node : sub_graph->GetDirectNode()) {
auto sub_graph_node_type = sub_graph_node->GetType();

@@ -296,6 +335,9 @@ graphStatus RefRelations::Impl::ProcessSubgraphDataNodes(vector<NodePtr> &data_n
data_nodes.pop_back();
int ref_idx = 0;
(void)AttrUtils::GetInt(data->GetOpDesc(), kRefIndex, ref_idx);
if (ref_idx >= static_cast<int>(classed_data_nodes.size())) {
return GRAPH_FAILED;
}
classed_data_nodes[ref_idx].emplace_back(data);
}
return GRAPH_SUCCESS;
@@ -317,7 +359,7 @@ graphStatus RefRelations::Impl::ProcessSubgraphNetoutput(
}
int ref_o;
if (AttrUtils::GetInt(in_desc, kRefIndex, ref_o)) {
if (ref_o >= kMaxElementNum) {
if (ref_o >= static_cast<int>(classed_netoutput_nodes.size())) {
return GRAPH_FAILED;
}
classed_netoutput_nodes[ref_o].emplace_back(
@@ -349,8 +391,9 @@ graphStatus RefRelations::Impl::BuildRefRelations(ge::ComputeGraph &graph) {
vector<NodePtr> netoutput_nodes;
// Get data and netoutput of sub_graph
GetDataAndNetoutputOfSubGraph(root_graph, data_nodes, netoutput_nodes, sub_graph_names, node_type);
vector<vector<NodePtr>> classed_data_nodes(kMaxElementNum); // according to ref_idx
vector<vector<std::pair<NodePtr, size_t>>> classed_netoutput_nodes(kMaxElementNum); // according to ref_idx
size_t max_elem_num = (data_nodes.size() > kMaxElementNum) ? data_nodes.size() : kMaxElementNum;
vector<vector<NodePtr>> classed_data_nodes(max_elem_num); // according to ref_idx
vector<vector<std::pair<NodePtr, size_t>>> classed_netoutput_nodes(max_elem_num); // according to ref_idx
status = ProcessSubgraphDataNodes(data_nodes, classed_data_nodes);
if (status != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "classfy data nodes failed!");


+ 1
- 0
src/common/graph/runtime_inference_context.cc View File

@@ -30,6 +30,7 @@ graphStatus RuntimeInferenceContext::CreateContext(const std::string &context_id
return GRAPH_FAILED;
}

std::lock_guard<std::mutex> lk(ctx_mu_);
auto emplace_ret = contexts_.emplace(context_id, std::move(ctx));
if (!emplace_ret.second) {
GELOGE(GRAPH_FAILED, "Old context not destroyed");


+ 257
- 36
src/common/graph/shape_refiner.cc View File

@@ -37,6 +37,162 @@

namespace ge {
namespace {
const uint32_t kWhileBodySubGraphIdx = 1;

// Marks every input/output tensor desc inside the While node's body subgraph
// as unknown-dim-num, so a later inference pass re-derives all shapes.
// Returns GRAPH_FAILED only when the body subgraph itself cannot be found.
graphStatus ReverseBrushWhileBodySubGraph(const ConstNodePtr &node) {
  GELOGD("Enter reverse brush while body subgraph process!");

  auto sub_graph_body = NodeUtils::GetSubgraph(*node, kWhileBodySubGraphIdx);
  if (sub_graph_body == nullptr) {
    GELOGE(GRAPH_FAILED, "Get while body graph failed!");
    return GRAPH_FAILED;
  }

  for (const auto &node_sub : sub_graph_body->GetAllNodes()) {
    // Guard against nodes with a missing op desc instead of dereferencing null.
    auto op_desc = node_sub->GetOpDesc();
    if (op_desc == nullptr) {
      GELOGW("Node in while body subgraph has null op desc, skip it.");
      continue;
    }
    for (size_t i = 0; i < node_sub->GetAllInDataAnchorsSize(); i++) {
      auto input_desc = op_desc->MutableInputDesc(static_cast<uint32_t>(i));
      if (input_desc != nullptr) {
        (void)input_desc->SetUnknownDimNumShape();
      }
    }
    for (size_t i = 0; i < node_sub->GetAllOutDataAnchorsSize(); i++) {
      auto output_desc = op_desc->MutableOutputDesc(static_cast<uint32_t>(i));
      if (output_desc != nullptr) {
        (void)output_desc->SetUnknownDimNumShape();
      }
    }
  }

  return GRAPH_SUCCESS;
}

// Multi-batch case: every batch branch contributes one tensor per parent
// output index; pick the branch tensor with the largest element count and
// write it back as the parent node's output desc.
// (Name kept as-is -- "UpdataOutputForMultiBatcch" -- since callers use it.)
graphStatus UpdataOutputForMultiBatcch(const ConstNodePtr &node,
                                       std::vector<std::vector<GeTensorDesc>> &ref_out_tensors) {
  // check sub_graph shape. Get max for update.
  for (size_t i = 0; i < ref_out_tensors.size(); ++i) {
    if (ref_out_tensors[i].empty()) {
      continue;
    }

    int64_t max_size = 0;
    size_t max_shape_index = 0;
    auto &ref_out_tensor = ref_out_tensors[i].at(0);
    const auto &ref_out_tensor_shape = ref_out_tensor.MutableShape();
    for (size_t j = 0; j < ref_out_tensors[i].size(); ++j) {
      auto &tensor = ref_out_tensors[i].at(j);
      if (ref_out_tensor.GetDataType() != tensor.GetDataType()) {
        GELOGE(GRAPH_FAILED, "node[%s] does not support diff dtype output", node->GetName().c_str());
        return GRAPH_FAILED;
      }

      auto shape = tensor.MutableShape();
      if (shape.GetDims().size() != ref_out_tensor_shape.GetDims().size()) {
        // %zu/%ld: i is size_t, GetShapeSize() is int64_t -- the previous
        // %d/%lu specifiers were undefined behavior in varargs.
        GELOGE(GRAPH_FAILED, "node is %s, i : %zu, shape size: %ld, ref_out_tensor_shape size: %ld",
               node->GetName().c_str(), i, shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize());
        return GRAPH_FAILED;
      }

      int64_t size = 1;
      for (auto dim : shape.GetDims()) {
        // Skip unknown (-1) and zero dims: they carry no size information here,
        // and dividing INT64_MAX by 0 below was undefined behavior.
        if (dim <= 0) {
          continue;
        }
        if (INT64_MAX / dim < size) {
          GELOGE(PARAM_INVALID, "The shape size overflow");
          return PARAM_INVALID;
        }
        size *= dim;
      }

      if (size > max_size) {
        max_size = size;
        max_shape_index = j;
      }
    }

    (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensors[i].at(max_shape_index));
  }

  return GRAPH_SUCCESS;
}

// Propagates subgraph output shapes of branch-style control ops back to the
// parent node. Outputs whose shape differs across branches are degraded:
// rank mismatch -> unknown rank, dim mismatch -> that dim becomes unknown.
graphStatus UpdateParentNodeForBranch(const ConstNodePtr &node,
                                      std::vector<std::vector<GeTensorDesc>> &ref_out_tensors) {
  GELOGD("Enter update parent node shape for class branch op process");
  // Multi-batch graphs take the max-sized branch shape instead.
  if (node->GetOpDesc()->HasAttr(ATTR_NAME_BATCH_NUM)) {
    return UpdataOutputForMultiBatcch(node, ref_out_tensors);
  }

  // check sub_graph shape.If not same ,do unknown shape process
  for (size_t i = 0; i < ref_out_tensors.size(); i++) {
    if (ref_out_tensors[i].empty()) {
      continue;
    }
    auto ref_out_tensor = ref_out_tensors[i].at(0);
    ge::GeShape &ref_out_tensor_shape = ref_out_tensor.MutableShape();
    for (auto &tensor : ref_out_tensors[i]) {
      if (ref_out_tensor.GetDataType() != tensor.GetDataType()) {
        GELOGE(GRAPH_FAILED, "node[%s] does not support diff dtype output", node->GetName().c_str());
        return GRAPH_FAILED;
      }
      auto shape = tensor.MutableShape();
      if (shape.GetDims().size() != ref_out_tensor_shape.GetDims().size()) {
        // %zu/%ld fix: i is size_t and GetShapeSize() returns int64_t; the
        // old %d/%lu specifiers were a varargs type mismatch (UB).
        GELOGD("node is %s, i : %zu, shape size: %ld, ref_out_tensor_shape size: %ld", node->GetName().c_str(), i,
               shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize());
        ref_out_tensor_shape = GeShape(UNKNOWN_RANK);
        break;
      }
      for (size_t j = 0; j < ref_out_tensor_shape.GetDims().size(); j++) {
        if (ref_out_tensor_shape.GetDim(j) == shape.GetDim(j)) {
          continue;
        }
        GELOGD("node is %s, i : %zu, j: %zu ,shape size: %ld, ref_out_tensor_shape size: %ld",
               node->GetName().c_str(), i, j, shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize());
        (void)ref_out_tensor_shape.SetDim(j, UNKNOWN_DIM);
      }
    }
    (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensor);
  }
  return GRAPH_SUCCESS;
}

// Propagates subgraph shapes back to a parent While node.
// For each output index i, the body graph's single output tensor must match
// the shapes of all Data tensors feeding ref index i; on any mismatch the
// output is marked unknown-dim-num and the body subgraph is re-brushed.
graphStatus UpdateParentNodeForWhile(const ConstNodePtr &node, std::vector<std::vector<GeTensorDesc>> &ref_data_tensors,
                                     std::vector<std::vector<GeTensorDesc>> &ref_out_tensors) {
  GELOGD("Enter update parent node shape for class while op process");
  // While semantics: inputs and outputs pair up one-to-one.
  if (ref_data_tensors.size() != ref_out_tensors.size()) {
    GELOGE(GRAPH_FAILED, "while op [%s] input number[%zu] and output number[%zu] is not same!", node->GetName().c_str(),
           ref_data_tensors.size(), ref_out_tensors.size());
    return GRAPH_FAILED;
  }
  // Each While output must resolve to exactly one tensor across the graphs.
  for (size_t i = 0; i < ref_data_tensors.size(); i++) {
    if (ref_out_tensors[i].size() != 1) {
      GELOGE(GRAPH_FAILED, "while op, every output should only find one output tensor in all graph!");
      return GRAPH_FAILED;
    }
  }
  bool is_need_reverse_brush = false;
  // check input and output
  for (size_t i = 0; i < ref_out_tensors.size(); i++) {
    if (ref_out_tensors[i].empty()) {
      continue;
    }
    auto ref_out_tensor = ref_out_tensors[i].at(0);
    auto tmp_shape = ref_out_tensor.MutableShape();
    // ref_i's data and output tensor shape should be same
    for (auto &tensor : ref_data_tensors[i]) {
      if (ref_out_tensor.GetDataType() != tensor.GetDataType()) {
        GELOGE(GRAPH_FAILED, "node[%s] does not support diff dtype or format output.", node->GetName().c_str());
        return GRAPH_FAILED;
      }
      auto shape = tensor.MutableShape();
      if (shape.GetDims() != tmp_shape.GetDims()) {
        // Shape changes between loop iterations => dynamic shape: mark the
        // output unknown and remember to refresh the whole body graph.
        ref_out_tensor.SetUnknownDimNumShape();
        is_need_reverse_brush = true;
        break;
      }
    }
    (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensor);
  }
  // reverse refresh while body shape
  if (is_need_reverse_brush) {
    return ReverseBrushWhileBodySubGraph(node);
  }
  return GRAPH_SUCCESS;
}

graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) {
auto op_desc = node->GetOpDesc();
auto sub_graph_names = op_desc->GetSubgraphInstanceNames();
@@ -66,11 +222,14 @@ graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) {
node->GetName().c_str());
return GRAPH_FAILED;
}
if (!AttrUtils::GetInt(node_sub->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, ref_i)) {
if (!AttrUtils::GetInt(data_opdesc, ATTR_NAME_PARENT_NODE_INDEX, ref_i)) {
GE_LOGE("Invalid data node on the sub graph %s parent node %s, no ref-index attribute", name.c_str(),
node->GetName().c_str());
return GRAPH_FAILED;
}
if (data_opdesc->HasAttr(ATTR_MBATCH_ORIGIN_INPUT_DIMS)) {
continue;
}
auto input_desc = op_desc->MutableInputDesc(ref_i);
if (input_desc == nullptr) {
GE_LOGE(
@@ -98,6 +257,37 @@ graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) {
}
return GRAPH_SUCCESS;
}

// Scans a subgraph (back to front) for its NETOUTPUT node and collects, per
// parent input index, the first output desc of every DATA node.
// Precondition: ref_data_tensors is sized to node->GetAllInDataAnchorsSize().
graphStatus FindSubgraphDataAndNetoutput(std::shared_ptr<ComputeGraph> &sub_graph, NodePtr &netoutput,
                                         const ConstNodePtr &node,
                                         std::vector<std::vector<GeTensorDesc>> &ref_data_tensors) {
  auto sub_nodes = sub_graph->GetDirectNode();
  for (size_t i = sub_nodes.size(); i > 0; --i) {
    auto sub_node = sub_nodes.at(i - 1);
    if (sub_node->GetType() == NETOUTPUT) {
      netoutput = sub_node;
    }
    if (sub_node->GetType() == DATA) {
      if (sub_node->GetOpDesc() == nullptr) {
        return GRAPH_FAILED;
      }

      int ref_i = 0;  // initialized: GetInt leaves it untouched on failure
      if (!AttrUtils::GetInt(sub_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, ref_i)) {
        GELOGE(GRAPH_FAILED, "subgraph data node[%s] has no parent node!", sub_node->GetName().c_str());
        return GRAPH_FAILED;
      }
      // %u fix: GetAllInDataAnchorsSize() returns uint32_t; the old %zu was a
      // varargs format/argument mismatch on 64-bit targets.
      if (ref_i < 0 || static_cast<uint32_t>(ref_i) >= node->GetAllInDataAnchorsSize()) {
        GELOGE(GRAPH_FAILED, "data node[%s]'s ref index[%d] is not in range [0, %u)!", sub_node->GetName().c_str(),
               ref_i, node->GetAllInDataAnchorsSize());
        return GRAPH_FAILED;
      }
      ref_data_tensors[ref_i].emplace_back(sub_node->GetOpDesc()->GetOutputDesc(0));
    }
  }
  return GRAPH_SUCCESS;
}

graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) {
auto op_desc = node->GetOpDesc();
auto sub_graph_names = op_desc->GetSubgraphInstanceNames();
@@ -105,7 +295,10 @@ graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) {
return GRAPH_SUCCESS;
}

std::vector<std::vector<GeTensorDesc>> ref_data_tensors(node->GetAllInDataAnchorsSize());
std::vector<std::vector<GeTensorDesc>> ref_out_tensors(node->GetAllOutDataAnchorsSize());
auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph());

for (const auto &name : sub_graph_names) {
if (name.empty()) {
GELOGW("The node %s contains empty subgraph instance name", node->GetName().c_str());
@@ -117,13 +310,9 @@ graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) {
return GRAPH_FAILED;
}
NodePtr netoutput = nullptr;
auto sub_nodes = sub_graph->GetDirectNode();
for (size_t i = sub_nodes.size(); i > 0; --i) {
auto sub_node = sub_nodes.at(i - 1);
if (sub_node->GetType() == NETOUTPUT) {
netoutput = sub_node;
break;
}
auto ret = FindSubgraphDataAndNetoutput(sub_graph, netoutput, node, ref_data_tensors);
if (ret != GRAPH_SUCCESS) {
return ret;
}
if (netoutput == nullptr) {
GE_LOGE("No NetOutput node on sub graph %s, parent node %s", name.c_str(), node->GetName().c_str());
@@ -150,22 +339,23 @@ graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) {
continue;
}
GELOGI("Parent node index of edge desc is %d", ref_i);
auto output_desc = op_desc->MutableOutputDesc(static_cast<uint32_t>(ref_i));
if (output_desc == nullptr) {
GE_LOGE(
"The ref index(%d) on the input %d of netoutput %s on the sub graph %s "
"parent node %s are incompatible, outputs num %u",
ref_i, edge_anchor->GetIdx(), netoutput->GetName().c_str(), name.c_str(), node->GetName().c_str(),
node->GetAllOutDataAnchorsSize());
if (ref_i < 0 || static_cast<uint32_t>(ref_i) >= node->GetAllOutDataAnchorsSize()) {
return GRAPH_FAILED;
}
op_desc->UpdateOutputDesc(edge_anchor->GetIdx(), *edge_desc);
ref_out_tensors[ref_i].emplace_back(*edge_desc);
}
}
return GRAPH_SUCCESS;

if (node->GetType() == WHILE) {
return UpdateParentNodeForWhile(node, ref_data_tensors, ref_out_tensors);
}
return UpdateParentNodeForBranch(node, ref_out_tensors);
}
} // namespace
void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::string &phase) {
if (!IsLogEnable(GE, DLOG_DEBUG)) {
return;
}
if (node == nullptr) {
GELOGE(GRAPH_FAILED, "node is null");
return;
@@ -185,6 +375,18 @@ void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::str
TypeUtils::FormatToSerialString(input_desc->GetFormat()) + " ";
}
str += input_desc_str;

input_desc_str = "input origin shape: ";
for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) {
input_desc_str += "[";
for (int64_t dim : input_desc->GetOriginShape().GetDims()) {
input_desc_str += std::to_string(dim) + " ";
}
input_desc_str += "]";
input_desc_str += ":" + TypeUtils::DataTypeToSerialString(input_desc->GetOriginDataType()) + ":" +
TypeUtils::FormatToSerialString(input_desc->GetOriginFormat()) + " ";
}
str += input_desc_str;
}

if (op_desc->GetAllOutputsDescSize() != 0) {
@@ -202,6 +404,21 @@ void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::str
TypeUtils::FormatToSerialString(output_desc->GetFormat()) + " ";
}
str += output_desc_str;

output_desc_str = "output origin shape: ";
for (const auto &output_desc : op_desc->GetAllOutputsDescPtr()) {
if (output_desc == nullptr) {
continue;
}
output_desc_str += "[";
for (int64_t dim : output_desc->GetOriginShape().GetDims()) {
output_desc_str += std::to_string(dim) + " ";
}
output_desc_str += "]";
output_desc_str += ":" + TypeUtils::DataTypeToSerialString(output_desc->GetOriginDataType()) + ":" +
TypeUtils::FormatToSerialString(output_desc->GetOriginFormat()) + " ";
}
str += output_desc_str;
}
GELOGD("Shape dump [%s], Node name: [%s]. %s", phase.c_str(), node->GetName().c_str(), str.c_str());
}
@@ -222,7 +439,6 @@ graphStatus ShapeRefiner::InferShapeAndType(const ConstNodePtr &node, Operator &
return ret;
}
}

// Get infer func and execute
ret = op_desc->CallInferFunc(op);
if (ret == GRAPH_PARAM_INVALID) {
@@ -329,6 +545,9 @@ InferenceContextPtr CreateInferenceContext(const std::unordered_map<NodePtr, Inf
namespace {
std::unordered_map<NodePtr, InferenceContextPtr> context_map;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ShapeRefiner::ClearContextMap() { context_map.clear(); }

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ShapeRefiner::InferShapeAndType(const NodePtr &node) {
return InferShapeAndType(node, true);
}
@@ -339,19 +558,20 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ShapeRefiner::InferSh
GELOGE(GRAPH_FAILED, "Verifying %s failed.", node->GetName().c_str());
return GRAPH_FAILED;
}
PrintInOutTensorShape(node, "before_infershape");
Operator op = OpDescUtils::CreateOperatorFromNode(node);

auto inference_context = CreateInferenceContext(context_map, node);
if (inference_context == nullptr) {
GELOGE(GRAPH_FAILED, "inference context is null");
return GRAPH_FAILED;
bool is_unknown_graph = node->GetOwnerComputeGraph()->GetGraphUnknownFlag();
if (!is_unknown_graph) {
auto inference_context = CreateInferenceContext(context_map, node);
if (inference_context == nullptr) {
GELOGE(GRAPH_FAILED, "inference context is null");
return GRAPH_FAILED;
}
GELOGD("create context for node:%s, marks %zu", node->GetName().c_str(), inference_context->GetMarks().size());
op.SetInferenceContext(inference_context);
}

GELOGD("create context for node:%s, marks %zu", node->GetName().c_str(), inference_context->GetMarks().size());

PrintInOutTensorShape(node, "before_infershape");

Operator op = OpDescUtils::CreateOperatorFromNode(node);
op.SetInferenceContext(inference_context);
graphStatus status = InferShapeAndType(node, op, before_subgraph);
if (status == GRAPH_PARAM_INVALID || status == GRAPH_SUCCESS) {
(void)ge::NodeUtils::UpdatePeerNodeInputDesc(node);
@@ -359,16 +579,17 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ShapeRefiner::InferSh
GELOGE(GRAPH_FAILED, "%s call infer function failed.", node->GetName().c_str());
return GRAPH_FAILED;
}

auto ctx_after_infer = op.GetInferenceContext();
if (ctx_after_infer != nullptr) {
GELOGD("[%s] after infershape. mark:%zu", node->GetName().c_str(), ctx_after_infer->GetMarks().size());
if (!ctx_after_infer->GetOutputHandleShapesAndTypes().empty() || !ctx_after_infer->GetMarks().empty()) {
GELOGD("[%s] set inference context after. mark:%zu", node->GetName().c_str(), ctx_after_infer->GetMarks().size());
(void)context_map.emplace(node, ctx_after_infer);
if (!is_unknown_graph) {
auto ctx_after_infer = op.GetInferenceContext();
if (ctx_after_infer != nullptr) {
GELOGD("[%s] after infershape. mark:%zu", node->GetName().c_str(), ctx_after_infer->GetMarks().size());
if (!ctx_after_infer->GetOutputHandleShapesAndTypes().empty() || !ctx_after_infer->GetMarks().empty()) {
GELOGD("[%s] set inference context after. mark:%zu", node->GetName().c_str(),
ctx_after_infer->GetMarks().size());
(void)context_map.emplace(node, ctx_after_infer);
}
}
}

PrintInOutTensorShape(node, "after_infershape");

return GRAPH_SUCCESS;


+ 0
- 6
src/common/graph/stub/Makefile View File

@@ -1,6 +0,0 @@
inc_path := $(shell pwd)/inc/external/
out_path := $(shell pwd)/out/atc/lib64/stub/
stub_path := $(shell pwd)/common/graph/stub/

mkdir_stub := $(shell mkdir -p $(out_path))
graph_local_stub := $(shell $(HI_PYTHON) $(stub_path)/gen_stubapi.py $(inc_path) $(out_path))

+ 0
- 573
src/common/graph/stub/gen_stubapi.py View File

@@ -1,573 +0,0 @@
import os
import re
import sys
import logging

# Console logger used by the whole stub-generation run.
logging.basicConfig(stream=sys.stdout, format='[%(asctime)s] [%(lineno)s] %(levelname)s: %(message)s',
                    level=logging.INFO)

"""
this attr is used for symbol table visible
"""
# Visibility attribute that may prefix exported class declarations in headers.
GE_ATTR = 'GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY'

"""
generate stub func body by return type
"""
# Maps a C++ return type (as parsed from the declaration) to the single
# statement emitted as the generated stub's body.
RETURN_STATEMENTS = {
    'graphStatus': ' return GRAPH_SUCCESS;',
    'Status': ' return SUCCESS;',
    'Graph': ' return Graph();',
    'Graph&': ' return *this;',
    'Format': ' return Format();',
    'Format&': ' return *this;',
    'Shape': ' return Shape();',
    'Shape&': ' return *this;',
    'TensorDesc': ' return TensorDesc();',
    'TensorDesc&': ' return *this;',
    'Tensor': ' return Tensor();',
    'Tensor&': ' return *this;',
    'Operator': ' return Operator();',
    'Operator&': ' return *this;',
    'Ptr': ' return nullptr;',
    'std::string': ' return "";',
    'std::string&': ' return "";',
    'string': ' return "";',
    'int': ' return 0;',
    'DataType': ' return DT_FLOAT;',
    'InferenceContextPtr': ' return nullptr;',
    'SubgraphBuilder': ' return nullptr;',
    'OperatorImplPtr': ' return nullptr;',
    'OutHandler': ' return nullptr;',
    'std::vector<std::string>': ' return {};',
    'std::vector<int64_t>': ' return {};',
    'std::map': ' return {};',
    'uint32_t': ' return 0;',
    'int64_t': ' return 0;',
    'uint64_t': ' return 0;',
    'size_t': ' return 0;',
    'float': ' return 0.0f;',
    'bool': ' return false;',
}

"""
max code len per line in hua_wei software programming specifications
"""
max_code_len_per_line = 100

"""
white_list_for_debug, include_dir_key_words is to
determines which header files to generate cc files from
when DEBUG on
"""
# With DEBUG on, only these header basenames produce stub .cc files.
white_list_for_debug = ["operator.h", "tensor.h",
                        "graph.h", "operator_factory.h",
                        "ge_ir_build.h"]
# Include sub-directories (under the inc root) that are scanned for headers.
include_dir_key_words = ["ge", "graph"]
DEBUG = True


def need_generate_func(func_line):
    """Decide whether a declaration line needs a generated stub body.

    Defaulted/deleted members and typedef/using aliases are skipped.
    """
    stripped = func_line.strip()
    if stripped.endswith(("default", "delete")):
        return False
    if stripped.startswith(("typedef", "using")):
        return False
    return True


def file_endswith_white_list_suffix(file):
    """Tell whether ``file`` should be stubbed.

    When DEBUG is on, only paths ending in a whitelisted header name pass;
    otherwise every file passes.
    """
    if not DEBUG:
        return True
    return any(file.endswith(suffix) for suffix in white_list_for_debug)


"""
belows are patterns used for analyse .h file
"""
# pattern function
pattern_func = re.compile(r"""(^[\s]*) #leading with space,we will find and delete after
([a-zA-Z~_] # void int likely
.*
[)] #we find )
(?!.*{) # we do not want the case int abc() const { return 1;}
.*)
(;.*) #we want to find ; and after for we will replace these later
\n$
""", re.VERBOSE | re.MULTILINE | re.DOTALL)

# pattern comment
pattern_comment = re.compile(r'^\s*//')
pattern_comment_2_start = re.compile(r'^\s*/[*]')
pattern_comment_2_end = re.compile(r'[*]/\s*$')
# pattern define
pattern_define = re.compile(r'^\s*#define')
pattern_define_return = re.compile(r'\\\s*$')
# blank line
pattern_blank_line = re.compile(r'^\s*$')
# virtual,explicit,friend,static
pattern_keyword = re.compile(r'(virtual\s+|explicit\s+|friend\s+|static\s+)')
# lead space
pattern_leading_space = re.compile(r'(^[\s]*)[a-zA-Z~_]')
# functions will have patterns such as func ( or func(
# but operator is an exception; the class name is preceded by an operator, and the above mode does not exist
# format like :"operator = ()"
pattern_func_name = re.compile(r'([a-zA-Z0-9~_\-]+\s*|operator?.*)[(]')
# template
pattern_template = re.compile(r'^\s*template')
pattern_template_end = re.compile(r'>\s*$')
# namespace
pattern_namespace = re.compile(r'namespace.*{')
# class : which can handle classA a and {not on the same line, but if found ';' after class,then don't deal with
pattern_class = re.compile(r'^[\s]*(class|struct)\s+(%s\s+)?([a-zA-Z0-9_\-]+<?)(?!.*;)' % GE_ATTR)
# {}
pattern_start = re.compile('{')
pattern_end = re.compile('}')

line_index = 0


class H2CC(object):
    """Converts one C++ header (.h) into a stub .cc implementation.

    The header is parsed line by line with the module-level regex patterns;
    each declared function gets an empty body generated from
    RETURN_STATEMENTS and is written once to the output file.
    """

    def __init__(self, input_file, output_file, shared_includes_content):
        """
        :param input_file: path of the .h file to parse
        :param output_file: path of the .cc file to generate
        :param shared_includes_content: #include lines written to every output
        """
        self.input_file = input_file
        self.output_file = output_file
        self.shared_includes_content = shared_includes_content
        self.line_index = 0
        self.input_fd = open(self.input_file, 'r')
        self.input_content = self.input_fd.readlines()
        self.output_fd = open(self.output_file, 'w')

        # The state may be normal_now(in the middle of {}),class_now,namespace_now
        self.stack = []
        self.stack_class = []
        self.stack_template = []
        # record funcs generated by h2cc func
        self.func_list_exist = []

    def __del__(self):
        self.input_fd.close()
        self.output_fd.close()
        del self.stack
        del self.stack_class
        del self.stack_template
        del self.func_list_exist

    def just_skip(self):
        """Advance line_index past blank lines, comments and #define blocks."""
        # skip blank line or comment
        if pattern_blank_line.search(self.input_content[self.line_index]) or pattern_comment.search(
                self.input_content[self.line_index]):  # /n or comment using //
            self.line_index += 1
        if pattern_comment_2_start.search(self.input_content[self.line_index]):  # comment using /*
            while not pattern_comment_2_end.search(self.input_content[self.line_index]):  # */
                self.line_index += 1
            self.line_index += 1
        # skip define
        if pattern_define.search(self.input_content[self.line_index]):
            while pattern_blank_line.search(self.input_content[self.line_index]) or pattern_define_return.search(
                    self.input_content[self.line_index]):
                self.line_index += 1
            self.line_index += 1

    def write_inc_content(self):
        """Write the shared #include prologue to the output file."""
        for shared_include_content in self.shared_includes_content:
            self.output_fd.write(shared_include_content)

    def h2cc(self):
        """Main driver: parse the whole header and emit stub definitions."""
        logging.info("start generate cc_file[%s] from h_file[%s]", self.output_file, self.input_file)
        global pattern_comment
        global pattern_comment_2_start
        global pattern_comment_2_end
        global pattern_blank_line
        global pattern_func
        global pattern_keyword
        global pattern_leading_space
        global pattern_func_name
        global pattern_template
        global pattern_template_end
        global pattern_namespace
        global pattern_class
        global pattern_start
        global pattern_end
        global line_index
        # write inc content
        self.write_inc_content()
        # core processing cycle, process the input .h file by line
        while self.line_index < len(self.input_content):
            # handle comment and blank line
            self.just_skip()

            # match namespace
            self.handle_namespace()

            # match template
            template_string = self.handle_template()
            # match class
            line = self.input_content[self.line_index]
            match_class = pattern_class.search(line)
            match_start = pattern_start.search(line)
            handle_class_result = self.handle_class(template_string, line, match_start, match_class)
            if handle_class_result == "continue":
                continue

            # match "}"
            handle_stack_result = self.handle_stack(match_start)
            if handle_stack_result == "continue":
                continue
            # handle func
            handle_func1_result, line, start_i = self.handle_func1(line)
            if handle_func1_result == "continue":
                continue

            # here means func is found
            # delete key word
            line = pattern_keyword.sub('', line)
            logging.info("line[%s]", line)

            # Class member function
            # if friend we will not add class name
            friend_match = re.search('friend ', line)
            if len(self.stack_class) > 0 and not friend_match:
                line, func_name = self.handle_class_member_func(line, template_string)
            # Normal functions
            else:
                line, func_name = self.handle_normal_func(line, template_string)

            need_generate = need_generate_func(line)
            # func body
            line += self.implement_function(line)
            # comment
            line = self.gen_comment(start_i) + line
            # write to out file
            self.write_func_content(line, func_name, need_generate)
            # next loop
            self.line_index += 1

        logging.info('Added %s functions', len(self.func_list_exist))
        logging.info('Successfully converted,please see ' + self.output_file)

    def handle_func1(self, line):
        """Detect a function declaration (possibly multi-line).

        :param line: current source line
        :return: ("continue"|"pass", possibly-joined line, start index)
        """
        find1 = re.search('[(]', line)
        if not find1:
            self.line_index += 1
            return "continue", line, None
        find2 = re.search('[)]', line)
        start_i = self.line_index
        space_match = pattern_leading_space.search(line)
        # deal with
        # int abc(int a,
        #        int b)
        if find1 and (not find2):
            self.line_index += 1
            line2 = self.input_content[self.line_index]
            if space_match:
                line2 = re.sub('^' + space_match.group(1), '', line2)
            line += line2
            while self.line_index < len(self.input_content) and (not re.search('[)]', line2)):
                self.line_index += 1
                line2 = self.input_content[self.line_index]
                line2 = re.sub('^' + space_match.group(1), '', line2)
                line += line2

        match_start = pattern_start.search(self.input_content[self.line_index])
        match_end = pattern_end.search(self.input_content[self.line_index])
        if match_start:  # like ) { or ) {} int the last line
            if not match_end:
                self.stack.append('normal_now')
                ii = start_i
                while ii <= self.line_index:
                    ii += 1
            self.line_index += 1
            return "continue", line, start_i
        logging.info("line[%s]", line)
        # ' int abc();'->'int abc()'
        (line, match) = pattern_func.subn(r'\2\n', line)
        logging.info("line[%s]", line)
        # deal with case:
        # 'int \n abc(int a, int b)'
        if re.search(r'^\s*(inline)?\s*[a-zA-Z0-9_]+\s*$', self.input_content[start_i - 1]):
            line = self.input_content[start_i - 1] + line
            line = line.lstrip()
        if not match:
            self.line_index += 1
            return "continue", line, start_i
        return "pass", line, start_i

    def handle_stack(self, match_start):
        """Track '{'/'}' nesting and pop namespace/class scopes.

        :param match_start: result of searching '{' on the current line
        :return: "continue" when the line was consumed, else "pass"
        """
        line = self.input_content[self.line_index]
        match_end = pattern_end.search(line)
        if match_start:
            self.stack.append('normal_now')
        if match_end:
            top_status = self.stack.pop()
            if top_status == 'namespace_now':
                self.output_fd.write(line + '\n')
            elif top_status == 'class_now':
                self.stack_class.pop()
                self.stack_template.pop()
        if match_start or match_end:
            self.line_index += 1
            return "continue"

        if len(self.stack) > 0 and self.stack[-1] == 'normal_now':
            self.line_index += 1
            return "continue"
        return "pass"

    def handle_class(self, template_string, line, match_start, match_class):
        """Push a class scope (and its template header) onto the stacks.

        :param template_string: template header collected just before the class
        :param line: current source line
        :param match_start: '{' match on the line
        :param match_class: class/struct declaration match
        :return: "continue" when a class was consumed, else "pass"
        """
        if match_class:  # we face a class
            self.stack_template.append(template_string)
            self.stack.append('class_now')
            class_name = match_class.group(3)

            # class template specializations: class A<u,Node<u> >
            if '<' in class_name:
                k = line.index('<')
                fit = 1
                for ii in range(k + 1, len(line)):
                    if line[ii] == '<':
                        fit += 1
                    if line[ii] == '>':
                        fit -= 1
                    if fit == 0:
                        break
                class_name += line[k + 1:ii + 1]
            logging.info('class_name[%s]', class_name)
            self.stack_class.append(class_name)
            while not match_start:
                self.line_index += 1
                line = self.input_content[self.line_index]
                match_start = pattern_start.search(line)
            self.line_index += 1
            return "continue"
        return "pass"

    def handle_template(self):
        """Collect a (possibly multi-line) 'template<...>' header, or ''."""
        line = self.input_content[self.line_index]
        match_template = pattern_template.search(line)
        template_string = ''
        if match_template:
            match_template_end = pattern_template_end.search(line)
            template_string = line
            while not match_template_end:
                self.line_index += 1
                line = self.input_content[self.line_index]
                template_string += line
                match_template_end = pattern_template_end.search(line)
            self.line_index += 1
        return template_string

    def handle_namespace(self):
        """Copy a 'namespace ... {' line to the output and push its scope."""
        line = self.input_content[self.line_index]
        match_namespace = pattern_namespace.search(line)
        if match_namespace:  # we face namespace
            self.output_fd.write(line + '\n')
            self.stack.append('namespace_now')
            self.line_index += 1

    def handle_normal_func(self, line, template_string):
        """Build the definition line for a free (non-member) function."""
        template_line = ''
        self.stack_template.append(template_string)
        if self.stack_template[-1] != '':
            template_line = re.sub(r'\s*template', 'template', self.stack_template[-1])
            # change '< class T = a, class U = A(3)>' to '<class T, class U>'
            template_line = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_line)
            template_line = re.sub(r'\s*=.*,', ',', template_line)
            template_line = re.sub(r'\s*=.*', '', template_line)
        line = re.sub(r'\s*=.*,', ',', line)
        line = re.sub(r'\s*=.*\)', ')', line)
        line = template_line + line
        self.stack_template.pop()
        func_name = re.search(r'^.*\)', line, re.MULTILINE | re.DOTALL).group()
        logging.info("line[%s]", line)
        logging.info("func_name[%s]", func_name)
        return line, func_name

    def handle_class_member_func(self, line, template_string):
        """Build the definition line for a class member function.

        Qualifies the function name with 'ClassName::' (plus template
        arguments when the enclosing class is a template).
        """
        template_line = ''
        x = ''
        if template_string != '':
            template_string = re.sub(r'\s*template', 'template', template_string)
            template_string = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_string)
            template_string = re.sub(r'\s*=.*,', ',', template_string)
            template_string = re.sub(r'\s*=.*', '', template_string)
        if self.stack_template[-1] != '':
            # BUGFIX: was 'stack_template[-1]' (bare name) twice below, which
            # raised NameError for every member of a template class -- no
            # module-level 'stack_template' exists; it is instance state.
            if not (re.search(r'<\s*>', self.stack_template[-1])):
                template_line = re.sub(r'^\s*template', 'template', self.stack_template[-1])
                if not (re.search(r'<.*>', self.stack_class[-1])):
                    # for x we get like template<class T, typename U> -> <T,U>
                    x = re.sub(r'template\s*<', '<', template_line)  # remove template -> <class T, typename U>
                    x = re.sub(r'\n', '', x)
                    x = re.sub(r'\s*=.*,', ',', x)
                    x = re.sub(r'\s*=.*\>', '>', x)
                    x = x.rstrip()  # remove \n
                    x = re.sub(r'(class|typename)\s+|(<class>|<typename>\s*class)', '',
                               x)  # remove class,typename -> <T, U>
                    x = re.sub(r'<\s+', '<', x)
                    x = re.sub(r'\s+>', '>', x)
                    x = re.sub(r'\s+,', ',', x)
                    x = re.sub(r',\s+', ', ', x)
        line = re.sub(r'\s*=\s+0', '', line)
        line = re.sub(r'\s*=\s+.*,', ',', line)
        line = re.sub(r'\s*=\s+.*\)', ')', line)
        logging.info("x[%s]\nline[%s]", x, line)
        # if the function is long, void ABC::foo()
        # breaks into two lines void ABC::\n foo()
        temp_line = pattern_func_name.sub(self.stack_class[-1] + x + '::' + r'\1(', line, count=1)
        if len(temp_line) > max_code_len_per_line:
            line = pattern_func_name.sub(self.stack_class[-1] + x + '::\n' + r'\1(', line, count=1)
        else:
            line = temp_line
        logging.info("line[%s]", line)
        # add template as the above if there is one
        template_line = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_line)
        template_line = re.sub(r'\s*=.*,', ',', template_line)
        template_line = re.sub(r'\s*=.*', '', template_line)
        line = template_line + template_string + line
        func_name = re.search(r'^.*\)', line, re.MULTILINE | re.DOTALL).group()
        logging.info("line[%s]", line)
        logging.info("func_name[%s]", func_name)
        return line, func_name

    def write_func_content(self, content, func_name, need_generate):
        """Write a generated stub to the output, deduplicating by name."""
        if not (func_name in self.func_list_exist) and need_generate:
            self.output_fd.write(content)
            self.func_list_exist.append(func_name)
            logging.info('add func:[%s]', func_name)

    def gen_comment(self, start_i):
        """Copy the comment block that sits directly above a declaration."""
        comment_line = ''
        # Function comments are on top of function declarations, copy them over
        k = start_i - 1  # one line before this func start
        if pattern_template.search(self.input_content[k]):
            k -= 1
        if pattern_comment_2_end.search(self.input_content[k]):
            comment_line = self.input_content[k].lstrip()
            while not pattern_comment_2_start.search(self.input_content[k]):
                k -= 1
                comment_line = self.input_content[k].lstrip() + comment_line
        else:
            for j in range(k, 0, -1):
                c_line = self.input_content[j]
                if pattern_comment.search(c_line):
                    c_line = re.sub(r'\s*//', '//', c_line)
                    comment_line = c_line + comment_line
                else:
                    break
        return comment_line

    @staticmethod
    def implement_function(func):
        """Produce a '{ return ...; }' body for the declaration in ``func``."""
        function_def = ''
        function_def += '{\n'

        all_items = func.split()
        start = 0
        return_type = all_items[start]
        if return_type == "const":
            start += 1
            return_type = all_items[start]
        if return_type.startswith(('std::map', 'std::set', 'std::vector')):
            return_type = "std::map"
        if return_type.endswith('*') or (len(all_items) > start + 1 and all_items[start + 1].startswith('*')):
            return_type = "Ptr"
        if len(all_items) > start + 1 and all_items[start + 1].startswith('&'):
            return_type += "&"
        if RETURN_STATEMENTS.__contains__(return_type):
            function_def += RETURN_STATEMENTS[return_type]
        else:
            logging.warning("Unhandled return type[%s]", return_type)

        function_def += '\n'
        function_def += '}\n'
        function_def += '\n'
        return function_def


def collect_header_files(path):
    """Recursively collect C/C++ header files under `path`.

    Traversal is made deterministic by sorting both directories and files,
    so generated output is stable across platforms and runs.

    :param path: root directory to scan; must contain at least one '/'.
    :return: tuple (header_files, shared_includes_content) where
             header_files lists header paths ('/'-separated) and
             shared_includes_content holds the matching '#include' lines.
    """
    header_files = []
    shared_includes_content = []
    for root, dirs, files in os.walk(path):
        dirs.sort()   # sort in place so os.walk visits subdirs deterministically
        files.sort()
        for file in files:
            if file.find("git") >= 0:   # skip VCS-related files
                continue
            if not file.endswith('.h'):
                continue
            file_path = os.path.join(root, file)
            file_path = file_path.replace('\\', '/')
            header_files.append(file_path)
            # Include path is taken relative to the component above `path`.
            include_str = '#include "{}"\n'.format(file_path[path.rindex('/') + 1:])
            shared_includes_content.append(include_str)
    return header_files, shared_includes_content


def generate_stub_file(inc_dir, out_cc_dir):
    """Generate a stub .cc implementation for each header under `inc_dir`.

    :param inc_dir: directory scanned for header files.
    :param out_cc_dir: output directory (with trailing '/') for the .cc files.
    :return: None
    """
    target_header_files, shared_includes_content = collect_header_files(inc_dir)
    for header_file in target_header_files:
        if not file_endswith_white_list_suffix(header_file):
            continue
        # Replace only the literal '.h' extension. The previous pattern
        # '.h*$' treated '.' as a wildcard and 'h*' as optional, so it could
        # mangle names whose suffix is not exactly '.h'.
        cc_file = re.sub(r'\.h$', '.cc', header_file)
        h_2_cc = H2CC(header_file, out_cc_dir + cc_file[cc_file.rindex('/') + 1:], shared_includes_content)
        h_2_cc.h2cc()


def gen_code(inc_dir, out_cc_dir):
    """Generate stub .cc files for every configured include sub-directory.

    :param inc_dir: root directory holding the include sub-directories.
    :param out_cc_dir: directory where generated .cc files are written.
    :return: None
    """
    # Normalize both paths to end with '/' so simple concatenation works below.
    if not inc_dir.endswith('/'):
        inc_dir = inc_dir + '/'
    if not out_cc_dir.endswith('/'):
        out_cc_dir = out_cc_dir + '/'
    for key_word in include_dir_key_words:
        generate_stub_file(inc_dir + key_word, out_cc_dir)


if __name__ == '__main__':
    # Usage: python <script> <inc_dir> <out_cc_dir>
    # Validate argument count up front instead of dying with a bare IndexError.
    if len(sys.argv) != 3:
        logging.error("Usage: %s <inc_dir> <out_cc_dir>", sys.argv[0])
        sys.exit(1)
    inc_dir = sys.argv[1]
    out_cc_dir = sys.argv[2]
    gen_code(inc_dir, out_cc_dir)

+ 14
- 10
src/common/graph/tensor.cc View File

@@ -178,16 +178,18 @@ int64_t Shape::GetShapeSize() const {
return 0;
}

TensorDesc::TensorDesc() { impl = ComGraphMakeShared<TensorDescImpl>(); }
TensorDesc::TensorDesc() {
impl = ComGraphMakeShared<TensorDescImpl>(); // lint !e665
}

TensorDesc::TensorDesc(Shape shape, Format format, DataType dt) {
impl = ComGraphMakeShared<TensorDescImpl>(shape, format, dt);
impl = ComGraphMakeShared<TensorDescImpl>(shape, format, dt); // lint !e665
SetRealDimCnt(shape.GetDimNum());
}

TensorDesc::TensorDesc(const TensorDesc &desc) {
// Copy
impl = ComGraphMakeShared<TensorDescImpl>();
impl = ComGraphMakeShared<TensorDescImpl>(); // lint !e665
if (desc.impl != nullptr && impl != nullptr) {
*impl = *desc.impl;
}
@@ -358,7 +360,9 @@ void TensorDesc::SetName(const std::string &name) {

Tensor::Tensor() { impl = ComGraphMakeShared<TensorImpl>(); }

Tensor::Tensor(const TensorDesc &tensor_desc) { impl = ComGraphMakeShared<TensorImpl>(tensor_desc); }
Tensor::Tensor(const TensorDesc &tensor_desc) {
impl = ComGraphMakeShared<TensorImpl>(tensor_desc); // lint !e665
}

Tensor::Tensor(const TensorDesc &tensor_desc, const std::vector<uint8_t> &data) {
uint64_t shape_size = tensor_desc.GetShape().GetShapeSize();
@@ -380,7 +384,7 @@ Tensor::Tensor(const TensorDesc &tensor_desc, const std::vector<uint8_t> &data)
}
}
}
impl = ComGraphMakeShared<TensorImpl>(tensor_desc, data);
impl = ComGraphMakeShared<TensorImpl>(tensor_desc, data); // lint !e665
}

Tensor::Tensor(const TensorDesc &tensor_desc, const uint8_t *data, size_t size) {
@@ -402,7 +406,7 @@ Tensor::Tensor(const TensorDesc &tensor_desc, const uint8_t *data, size_t size)
}
}

impl = ComGraphMakeShared<TensorImpl>(tensor_desc, data, size);
impl = ComGraphMakeShared<TensorImpl>(tensor_desc, data, size); // lint !e665
}

Tensor::Tensor(TensorDesc &&tensor_desc, std::vector<uint8_t> &&data) {
@@ -425,7 +429,7 @@ Tensor::Tensor(TensorDesc &&tensor_desc, std::vector<uint8_t> &&data) {
}
}
}
impl = ComGraphMakeShared<TensorImpl>(std::move(tensor_desc), std::move(data));
impl = ComGraphMakeShared<TensorImpl>(std::move(tensor_desc), std::move(data)); // lint !e665
}

TensorDesc Tensor::GetTensorDesc() const {
@@ -639,7 +643,7 @@ TensorDesc TensorAdapter::GeTensorDesc2TensorDesc(const GeTensorDesc &ge_tensor_
GeTensorPtr TensorAdapter::Tensor2GeTensor(const Tensor &tensor) {
GeTensorPtr ge_tensor;
if (tensor.impl != nullptr) {
ge_tensor = ComGraphMakeShared<GeTensor>(tensor.impl->ge_tensor.Clone());
ge_tensor = ComGraphMakeShared<GeTensor>(tensor.impl->ge_tensor.Clone()); // lint !e665
}
return ge_tensor;
}
@@ -655,7 +659,7 @@ Tensor TensorAdapter::GeTensor2Tensor(const ConstGeTensorPtr &ge_tensor) {
ConstGeTensorPtr TensorAdapter::AsGeTensorPtr(const Tensor &tensor) {
GeTensorPtr ge_tensor;
if (tensor.impl != nullptr) {
ge_tensor = ComGraphMakeShared<GeTensor>(tensor.impl->ge_tensor);
ge_tensor = ComGraphMakeShared<GeTensor>(tensor.impl->ge_tensor); // lint !e665
}
return ge_tensor;
}
@@ -663,7 +667,7 @@ ConstGeTensorPtr TensorAdapter::AsGeTensorPtr(const Tensor &tensor) {
GeTensorPtr TensorAdapter::AsGeTensorPtr(Tensor &tensor) {
GeTensorPtr ge_tensor;
if (tensor.impl != nullptr) {
ge_tensor = ComGraphMakeShared<GeTensor>(tensor.impl->ge_tensor);
ge_tensor = ComGraphMakeShared<GeTensor>(tensor.impl->ge_tensor); // lint !e665
}
return ge_tensor;
}


+ 202
- 24
src/common/graph/utils/graph_utils.cc View File

@@ -38,6 +38,7 @@
#include "utils/ge_ir_utils.h"
#include "utils/node_utils.h"
#include "debug/ge_op_types.h"
#include "external/ge/ge_api_types.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/utils/op_desc_utils.h"
#include "graph/utils/tensor_utils.h"
@@ -410,8 +411,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::InsertTra
/// @return graphStatus
///
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
GraphUtils::InsertNodeBefore(const OutDataAnchorPtr &src, const std::vector<InDataAnchorPtr> &dsts,
const NodePtr &insert_node, uint32_t input_index, uint32_t output_index) {
GraphUtils::InsertNodeAfter(const OutDataAnchorPtr &src, const std::vector<InDataAnchorPtr> &dsts,
const NodePtr &insert_node, uint32_t input_index, uint32_t output_index) {
GE_CHECK_NOTNULL(src);
GE_CHECK_NOTNULL(insert_node);

@@ -570,7 +571,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGraph(cons
static int max_dumpfile_num = 0;
if (max_dumpfile_num == 0) {
string opt = "0";
(void)GetContext().GetOption("ge.maxDumpFileNum", opt);
(void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_NUM, opt);
max_dumpfile_num = std::strtol(opt.c_str(), nullptr, kBaseOfIntegerValue);
}
if (max_dumpfile_num != 0 && file_idx > max_dumpfile_num) {
@@ -670,7 +671,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::WriteProtoToText
if (maxDumpFileSize == 0) {
string opt = "0";
// Can not check return value
(void)GetContext().GetOption("ge.maxDumpFileSize", opt);
(void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_SIZE, opt);
maxDumpFileSize = atol(opt.c_str());
}
if (maxDumpFileSize != 0 && fileSize != -1 && fileSize > maxDumpFileSize) {
@@ -740,7 +741,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGraphToOnn
static int max_dumpfile_num = 0;
if (max_dumpfile_num == 0) {
string opt = "0";
(void)GetContext().GetOption("ge.maxDumpFileNum", opt);
(void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_NUM, opt);
max_dumpfile_num = std::strtol(opt.c_str(), nullptr, kBaseOfIntegerValue);
}
if (max_dumpfile_num != 0 && file_index > max_dumpfile_num) {
@@ -920,7 +921,7 @@ graphStatus RelinkDataIO(const NodePtr &node, const std::vector<int> &io_map, In
InNodesToOut GetFullConnectIONodes(const NodePtr &node) {
InNodesToOut in_nodes_to_out;
if (node == nullptr) {
GELOGE(GRAPH_FAILED, "Node is nullptr,node is %s", node->GetName().c_str());
GELOGE(GRAPH_FAILED, "Node is nullptr");
return in_nodes_to_out;
}
auto in_nodes_list = node->GetInNodes();
@@ -1308,6 +1309,36 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::MoveOutCt
return GRAPH_SUCCESS;
}

///
/// Copy all in-data edges from `src_node` to `dst_node`.
/// @param src_node
/// @param dst_node
/// @return
///
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::CopyInDataEdges(const NodePtr &src_node,
NodePtr &dst_node) {
if ((src_node == nullptr) || (dst_node == nullptr)) {
GELOGE(GRAPH_FAILED, "Parameter is nullptr");
return GRAPH_PARAM_INVALID;
}
auto src_data_in_nodes = src_node->GetInDataNodes();
if (src_data_in_nodes.empty()) {
return GRAPH_SUCCESS;
}
for (const auto &in_data_anchor : src_node->GetAllInDataAnchors()) {
auto input_desc = src_node->GetOpDesc()->GetInputDesc(in_data_anchor->GetIdx());
auto ret =
GraphUtils::AddEdge(in_data_anchor->GetPeerOutAnchor(), dst_node->GetInDataAnchor(in_data_anchor->GetIdx()));
if (ret != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "Failed to add data edge from %s to %s when copy in data edge from %s to %s",
in_data_anchor->GetPeerOutAnchor()->GetOwnerNode()->GetName().c_str(), dst_node->GetName().c_str(),
src_node->GetName().c_str(), dst_node->GetName().c_str());
return ret;
}
}
return GRAPH_SUCCESS;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::AppendInputNode(const ComputeGraphPtr &graph,
const NodePtr &node) {
if (graph->AddInputNode(node) == nullptr) {
@@ -1328,6 +1359,153 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraphPtr GraphUtils::FindR
return result;
}

///
/// Make a copy of ComputeGraph.
/// @param graph: original graph.
/// @param prefix: node name prefix of new graph.
/// @return ComputeGraphPtr
///
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraphPtr
GraphUtils::CloneGraph(const ComputeGraphPtr &graph, const std::string &prefix, std::vector<NodePtr> &input_nodes,
std::vector<NodePtr> &output_nodes) {
GE_CHK_BOOL_EXEC(graph != nullptr, return nullptr, "Original graph is null");
ComputeGraphPtr new_graph = ComGraphMakeShared<ComputeGraph>(graph->GetName());
GE_CHK_BOOL_EXEC(new_graph != nullptr, return nullptr, "Create new graph failed");

std::unordered_map<std::string, NodePtr> all_new_nodes;
for (const auto &n : graph->GetDirectNode()) {
OpDescPtr op_desc = AttrUtils::CopyOpDesc(n->GetOpDesc());
GE_CHK_BOOL_EXEC(op_desc != nullptr, return nullptr, "Create new node failed");

if (CopyTensorAttrs(op_desc, n) != GRAPH_SUCCESS) {
return nullptr;
}

op_desc->SetName(prefix + n->GetName());
NodePtr node = new_graph->AddNode(op_desc);
GE_CHK_BOOL_EXEC(node != nullptr, return nullptr, "Add node[%s] to graph failed", op_desc->GetName().c_str());
all_new_nodes[node->GetName()] = node;

if (node->GetType() == DATA) {
input_nodes.emplace_back(node);
} else if (node->GetType() == NETOUTPUT) {
output_nodes.emplace_back(node);
}
}

for (const auto &n : graph->GetDirectNode()) {
if (RelinkGraphEdges(n, prefix, all_new_nodes) != GRAPH_SUCCESS) {
return nullptr;
}
}

return new_graph;
}

///
/// Copy tensor attribute to new node.
/// @param [in] dst_node: cloned node.
/// @param [in] src_node: original node.
/// @return success: GRAPH_SUCESS
///
graphStatus GraphUtils::CopyTensorAttrs(const OpDescPtr &dst_desc, const NodePtr &src_node) {
if (dst_desc == nullptr) {
GELOGE(GRAPH_FAILED, "Input param dst node not valid");
return GRAPH_FAILED;
}
if (src_node == nullptr || src_node->GetOpDesc() == nullptr) {
GELOGE(GRAPH_FAILED, "Input param src node not valid");
return GRAPH_FAILED;
}

const auto &src_desc = src_node->GetOpDesc();
dst_desc->CopyAttrsFrom(*src_desc);

for (uint32_t i = 0; i < src_node->GetAllInDataAnchorsSize(); ++i) {
auto input_desc = dst_desc->MutableInputDesc(i);
if (input_desc == nullptr) {
GELOGE(GRAPH_FAILED, "Param dst node not valid");
return GRAPH_FAILED;
}
input_desc->CopyAttrsFrom(src_desc->GetInputDesc(i));
}

for (uint32_t i = 0; i < src_node->GetAllOutDataAnchorsSize(); ++i) {
auto output_desc = dst_desc->MutableOutputDesc(i);
if (output_desc == nullptr) {
GELOGE(GRAPH_FAILED, "Param dst node not valid");
return GRAPH_FAILED;
}
output_desc->CopyAttrsFrom(src_desc->GetOutputDesc(i));
}

return GRAPH_SUCCESS;
}

///
/// Relink all edges for cloned ComputeGraph.
/// @param [in] node: original node.
/// @param [in] prefix: node name prefix of new node.
/// @param [in] all_nodes: all nodes in new graph.
/// @return success: GRAPH_SUCESS
///
graphStatus GraphUtils::RelinkGraphEdges(const NodePtr &node, const string &prefix,
const std::unordered_map<string, NodePtr> &all_nodes) {
if (node == nullptr || node->GetOpDesc() == nullptr) {
GELOGE(GRAPH_FAILED, "Input node not valid");
return GRAPH_FAILED;
}

auto it = all_nodes.find(prefix + node->GetName());
if (it == all_nodes.end()) {
GELOGE(GRAPH_FAILED, "node[%s] not found", node->GetName().c_str());
return GRAPH_FAILED;
}
const auto &new_node = it->second;

for (const auto &in_anchor : node->GetAllInDataAnchors()) {
GE_CHK_BOOL_EXEC(in_anchor != nullptr, return GRAPH_FAILED, "In data anchor is null");
const auto &out_anchor = in_anchor->GetPeerOutAnchor();
if (out_anchor == nullptr) {
GELOGW("Peer out anchor is null: %s", node->GetName().c_str());
continue;
}
GE_CHK_BOOL_EXEC(out_anchor->GetOwnerNode() != nullptr, return GRAPH_FAILED, "Peer out node is null");

it = all_nodes.find(prefix + out_anchor->GetOwnerNode()->GetName());
if (it == all_nodes.end()) {
GELOGE(GRAPH_FAILED, "node[%s] not found", out_anchor->GetOwnerNode()->GetName().c_str());
return GRAPH_FAILED;
}
const auto &new_out_node = it->second;

auto rslt =
GraphUtils::AddEdge(new_out_node->GetOutAnchor(out_anchor->GetIdx()), new_node->GetInAnchor(in_anchor->GetIdx()));
GE_CHK_BOOL_EXEC(rslt == GRAPH_SUCCESS, return GRAPH_FAILED, "link failed[%s to %s]",
new_out_node->GetName().c_str(), new_node->GetName().c_str());
}

if (node->GetInControlAnchor() != nullptr) {
for (const auto &out_anchor : node->GetInControlAnchor()->GetPeerAnchors()) {
GE_CHK_BOOL_EXEC(out_anchor != nullptr, continue, "Peer out anchor is null: %s", node->GetName().c_str());
GE_CHK_BOOL_EXEC(out_anchor->GetOwnerNode() != nullptr, return GRAPH_FAILED, "Peer out node is null");

it = all_nodes.find(prefix + out_anchor->GetOwnerNode()->GetName());
if (it == all_nodes.end()) {
GELOGE(GRAPH_FAILED, "node[%s] not found", out_anchor->GetOwnerNode()->GetName().c_str());
return GRAPH_FAILED;
}
const auto &new_out_node = it->second;

auto rslt = GraphUtils::AddEdge(new_out_node->GetOutAnchor(out_anchor->GetIdx()), new_node->GetInControlAnchor());
GE_CHK_BOOL_EXEC(rslt == GRAPH_SUCCESS, return GRAPH_FAILED, "link failed[%s to %s]",
new_out_node->GetName().c_str(), new_node->GetName().c_str());
}
}

return GRAPH_SUCCESS;
}

///
/// Get reference-mapping of all data_anchors in graph
/// @param [in] graph
@@ -1339,7 +1517,7 @@ graphStatus GraphUtils::GetRefMapping(const ComputeGraphPtr &graph,
std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors,
std::map<std::string, std::string> &anchor_to_symbol) {
GE_CHECK_NOTNULL(graph);
for (auto &node : graph->GetAllNodes()) {
for (const auto &node : graph->GetAllNodes()) {
// in_data_anchor
if (HandleInAnchorMapping(node, symbol_to_anchors, anchor_to_symbol) != GRAPH_SUCCESS) {
GE_LOGE("Find ref_mapping for in_data_anchors of node %s failed.", node->GetName().c_str());
@@ -1396,16 +1574,16 @@ graphStatus GraphUtils::HandleInAnchorMapping(const NodePtr &node,
return HandleSubgraphInput(node, symbol_to_anchors, anchor_to_symbol);
}

std::string type = node->GetType();
const std::string &type = node->GetType();
if ((type == MERGE) || (type == STREAMMERGE)) {
return HandleMergeInput(node, symbol_to_anchors, anchor_to_symbol);
}

for (auto &in_data_anchor : node->GetAllInDataAnchors()) {
for (const auto &in_data_anchor : node->GetAllInDataAnchors()) {
NodeIndexIO cur_node_info(node, in_data_anchor->GetIdx(), kIn);
OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
if (peer_out_anchor == nullptr) {
std::string symbol = cur_node_info.ToString();
const std::string &symbol = cur_node_info.ToString();
GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str());
symbol_to_anchors[symbol] = {cur_node_info};
anchor_to_symbol[symbol] = symbol;
@@ -1432,7 +1610,7 @@ graphStatus GraphUtils::HandleOutAnchorMapping(const NodePtr &node,
std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors,
std::map<std::string, std::string> &anchor_to_symbol) {
GE_CHECK_NOTNULL(node);
for (auto &out_data_anchor : node->GetAllOutDataAnchors()) {
for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) {
NodeIndexIO cur_node_info(node, out_data_anchor->GetIdx(), kOut);
if (anchor_to_symbol.find(cur_node_info.ToString()) != anchor_to_symbol.end()) {
continue;
@@ -1446,7 +1624,7 @@ graphStatus GraphUtils::HandleOutAnchorMapping(const NodePtr &node,
return GRAPH_FAILED;
}
} else {
std::string symbol = cur_node_info.ToString();
const std::string &symbol = cur_node_info.ToString();
GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str());
symbol_to_anchors.emplace(std::make_pair(symbol, std::list<NodeIndexIO>{cur_node_info}));
anchor_to_symbol.emplace(std::make_pair(symbol, symbol));
@@ -1506,7 +1684,7 @@ graphStatus GraphUtils::HandleMergeInput(const NodePtr &node,
GE_CHECK_NOTNULL(node);
std::vector<NodeIndexIO> exist_node_infos;
std::vector<NodeIndexIO> cur_node_infos;
for (auto &in_data_anchor : node->GetAllInDataAnchors()) {
for (const auto &in_data_anchor : node->GetAllInDataAnchors()) {
auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
if (peer_out_anchor == nullptr) {
std::string next_name;
@@ -1529,10 +1707,10 @@ graphStatus GraphUtils::HandleMergeInput(const NodePtr &node,

size_t anchor_nums = 0;
NodeIndexIO max_node_index_io(nullptr, 0, kOut);
for (auto &temp_node_info : exist_node_infos) {
for (const auto &temp_node_info : exist_node_infos) {
auto iter1 = anchor_to_symbol.find(temp_node_info.ToString());
if (iter1 != anchor_to_symbol.end()) {
std::string temp_symbol = iter1->second;
const std::string &temp_symbol = iter1->second;
auto iter2 = symbol_to_anchors.find(temp_symbol);
if (iter2 != symbol_to_anchors.end()) {
if (iter2->second.size() > anchor_nums) {
@@ -1544,7 +1722,7 @@ graphStatus GraphUtils::HandleMergeInput(const NodePtr &node,
}

std::string symbol;
for (auto &temp_node_info : exist_node_infos) {
for (const auto &temp_node_info : exist_node_infos) {
if ((UnionSymbolMapping(max_node_index_io, temp_node_info, symbol_to_anchors, anchor_to_symbol, symbol) !=
GRAPH_SUCCESS) ||
symbol.empty()) {
@@ -1556,7 +1734,7 @@ graphStatus GraphUtils::HandleMergeInput(const NodePtr &node,

auto iter = symbol_to_anchors.find(symbol);
if (iter != symbol_to_anchors.end()) {
for (auto &temp_node_info : cur_node_infos) {
for (const auto &temp_node_info : cur_node_infos) {
GELOGD("Add anchor %s, symbol %s.", temp_node_info.ToString().c_str(), symbol.c_str());
iter->second.emplace_back(temp_node_info);
anchor_to_symbol.emplace(std::make_pair(temp_node_info.ToString(), symbol));
@@ -1584,7 +1762,7 @@ graphStatus GraphUtils::HandleSubgraphOutput(const NodePtr &node,

OpDescPtr op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
for (auto &in_data_anchor : node->GetAllInDataAnchors()) {
for (const auto &in_data_anchor : node->GetAllInDataAnchors()) {
OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
GE_CHECK_NOTNULL(peer_out_anchor);

@@ -1627,8 +1805,8 @@ graphStatus GraphUtils::HandleSubgraphOutput(const NodePtr &node,
graphStatus GraphUtils::UnionSymbolMapping(const NodeIndexIO &exist_node_info1, const NodeIndexIO &exist_node_info2,
std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors,
std::map<std::string, std::string> &anchor_to_symbol, std::string &symbol) {
std::string symbol1 = anchor_to_symbol[exist_node_info1.ToString()];
std::string symbol2 = anchor_to_symbol[exist_node_info2.ToString()];
const std::string &symbol1 = anchor_to_symbol[exist_node_info1.ToString()];
const std::string &symbol2 = anchor_to_symbol[exist_node_info2.ToString()];
if (symbol1 == symbol2) {
symbol = symbol1;
GELOGI("no need to union.");
@@ -1684,7 +1862,7 @@ graphStatus GraphUtils::UpdateRefMapping(const NodeIndexIO &cur_node_info, const
return GRAPH_FAILED;
}

std::string symbol = iter1->second;
const std::string &symbol = iter1->second;
auto iter2 = symbol_to_anchors.find(symbol);
if (iter2 == symbol_to_anchors.end()) {
GE_LOGE("symbol %s not found.", symbol.c_str());
@@ -1712,7 +1890,7 @@ bool GraphUtils::IsRefFromInput(const OutDataAnchorPtr &out_data_anchor, int32_t

// pass-through op
NodePtr node = out_data_anchor->GetOwnerNode();
std::string type = node->GetType();
const std::string &type = node->GetType();
const std::set<std::string> pass_through_set = {NETOUTPUT, WHILE, _WHILE, STATELESSWHILE};
if ((pass_through_set.count(type) > 0) || (NodeUtils::IsSubgraphInput(node))) {
reuse_in_index = output_index;
@@ -1755,7 +1933,7 @@ bool GraphUtils::IsRefFromInput(const OutDataAnchorPtr &out_data_anchor, int32_t
uint32_t reuse_input_index = 0;
if (TensorUtils::GetReuseInputIndex(*output_op_desc, reuse_input_index) == GRAPH_SUCCESS) {
reuse_in_index = static_cast<int32_t>(reuse_input_index);
GELOGI("ReuseInput name[%s] output[%u] reuse input[%d].", op_desc->GetName().c_str(), output_index,
GELOGI("ReuseInput name[%s] output[%d] reuse input[%d].", op_desc->GetName().c_str(), output_index,
reuse_in_index);
return true;
}
@@ -2297,7 +2475,7 @@ void CompleteGraphBuilder::AddRetValNodes(graphStatus &error_code, std::string &
return;
}

std::string name = node->GetName() + "_RetVal";
std::string name = node->GetName() + "_RetVal_" + std::to_string(index);
OpDescPtr ret_val_desc = shared_ptr<OpDesc>(new (std::nothrow) OpDesc(name, FRAMEWORKOP));
if (ret_val_desc == nullptr) {
error_code = GRAPH_FAILED;


+ 217
- 18
src/common/graph/utils/node_utils.cc View File

@@ -295,16 +295,21 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::UpdatePeer
if (op_desc == nullptr) {
return GRAPH_FAILED;
}
bool is_unknown_graph = node_ptr->GetOwnerComputeGraph()->GetGraphUnknownFlag();
if (is_unknown_graph) {
return GRAPH_SUCCESS;
}
for (const auto &out_anchor : node_ptr->GetAllOutDataAnchors()) {
GeTensorDesc output_tensor = op_desc->GetOutputDesc(out_anchor->GetIdx());
ge::TensorUtils::SetRealDimCnt(output_tensor, static_cast<uint32_t>(output_tensor.GetShape().GetDims().size()));
output_tensor.SetOriginShape(output_tensor.GetShape());
output_tensor.SetOriginDataType(output_tensor.GetDataType());
auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx());
ge::TensorUtils::SetRealDimCnt(*output_tensor, static_cast<uint32_t>(output_tensor->GetShape().GetDims().size()));
output_tensor->SetOriginShape(output_tensor->GetShape());
output_tensor->SetOriginDataType(output_tensor->GetDataType());

GELOGD("node name is %s, origin shape is %ld, origin format is %s, origin data type is %s",
node_ptr->GetName().c_str(), output_tensor.GetOriginShape().GetShapeSize(),
TypeUtils::FormatToSerialString(output_tensor.GetOriginFormat()).c_str(),
TypeUtils::DataTypeToSerialString(output_tensor.GetOriginDataType()).c_str());
(void)op_desc->UpdateOutputDesc(out_anchor->GetIdx(), output_tensor);
node_ptr->GetName().c_str(), output_tensor->GetOriginShape().GetShapeSize(),
TypeUtils::FormatToSerialString(output_tensor->GetOriginFormat()).c_str(),
TypeUtils::DataTypeToSerialString(output_tensor->GetOriginDataType()).c_str());
for (const auto &peer_anchor : out_anchor->GetPeerInDataAnchors()) {
if (peer_anchor->GetOwnerNode()->GetOpDesc() == nullptr) {
GELOGE(GRAPH_FAILED, "peer_anchor opdesc is null");
@@ -316,17 +321,17 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::UpdatePeer
continue;
}
GELOGI("Peer input opdesc name is %s, need to flush: shape size is %zu, datatype is %d, original datatype is %d",
peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), output_tensor.GetShape().GetDimNum(),
output_tensor.GetDataType(), output_tensor.GetOriginDataType());
peer_input_desc->SetShape(output_tensor.GetShape());
peer_input_desc->SetOriginShape(output_tensor.GetOriginShape());
peer_input_desc->SetDataType(output_tensor.GetDataType());
peer_input_desc->SetOriginDataType(output_tensor.GetOriginDataType());
peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), output_tensor->GetShape().GetDimNum(),
output_tensor->GetDataType(), output_tensor->GetOriginDataType());
peer_input_desc->SetOriginShape(output_tensor->GetOriginShape());
peer_input_desc->SetShape(output_tensor->GetShape());
peer_input_desc->SetDataType(output_tensor->GetDataType());
peer_input_desc->SetOriginDataType(output_tensor->GetOriginDataType());
std::vector<std::pair<int64_t, int64_t>> shape_range;
(void)output_tensor.GetShapeRange(shape_range);
(void)output_tensor->GetShapeRange(shape_range);
peer_input_desc->SetShapeRange(shape_range);
ge::TensorUtils::SetRealDimCnt(*peer_input_desc,
static_cast<uint32_t>(output_tensor.GetShape().GetDims().size()));
static_cast<uint32_t>(output_tensor->GetShape().GetDims().size()));
GELOGI("Peer input opdesc name is %s, shape size is %zu, datatype is %d, original datatype is %d",
peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), peer_input_desc->GetShape().GetDimNum(),
peer_input_desc->GetDataType(), peer_input_desc->GetOriginDataType());
@@ -334,6 +339,50 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::UpdatePeer
}
return GRAPH_SUCCESS;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AppendInputAnchor(const NodePtr &node,
uint32_t index) {
if (node == nullptr) {
GELOGE(GRAPH_FAILED, "Nodeptr is nullptr");
return GRAPH_FAILED;
}

GeTensorDesc data_desc(GeShape(), FORMAT_ND, DT_FLOAT);
OpDescPtr op_desc = node->op_;
for (size_t i = op_desc->GetInputsSize(); i < index; ++i) {
if (op_desc->AddInputDesc(data_desc) != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "Add input desc failed");
return GRAPH_FAILED;
}

auto anchor = ComGraphMakeShared<InDataAnchor>(node, i);
if (anchor == nullptr) {
GELOGE(GRAPH_FAILED, "Current in_data_anchor is null, malloc shared_ptr failed.");
return GRAPH_FAILED;
}
node->in_data_anchors_.push_back(anchor);
}

return GRAPH_SUCCESS;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::RemoveInputAnchor(const NodePtr &node,
uint32_t index) {
if (node == nullptr) {
GELOGE(GRAPH_FAILED, "Nodeptr is nullptr");
return GRAPH_FAILED;
}

OpDescPtr op_desc = node->op_;
op_desc->RemoveInputDesc(index);

while (node->in_data_anchors_.size() > index) {
node->in_data_anchors_.pop_back();
}

return GRAPH_SUCCESS;
}

bool NodeUtils::IsInNodesEmpty(const Node &node) {
for (const auto &in_anchor : node.in_data_anchors_) {
if (in_anchor != nullptr) {
@@ -401,10 +450,13 @@ graphStatus NodeUtils::UpdateInputShape(const Node &node, uint32_t index, const
graphStatus NodeUtils::GetNodeUnknownShapeStatus(const Node &node, bool &is_unknow) {
auto desc = node.GetOpDesc();
GE_CHECK_NOTNULL(desc);

// check self
is_unknow = OpShapeIsUnknown(desc);
if (is_unknow) {
return GRAPH_SUCCESS;
}
auto sub_graph_names = desc->GetSubgraphInstanceNames();
if (sub_graph_names.empty()) {
is_unknow = OpShapeIsUnknown(desc);
return GRAPH_SUCCESS;
} else {
auto owner_graph = node.GetOwnerComputeGraph();
@@ -440,6 +492,7 @@ std::string NodeUtils::GetNodeType(const Node &node) {
(void)AttrUtils::GetStr(node.GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type);
return type;
}

ComputeGraphPtr NodeUtils::GetSubgraph(const Node &node, uint32_t index) {
auto op_desc = node.GetOpDesc();
if (op_desc == nullptr) {
@@ -492,6 +545,14 @@ bool NodeUtils::IsSubgraphInput(const NodePtr &node) {
return false;
}
if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE)) {
bool is_unknown_shape = false;
(void)AttrUtils::GetBool(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE, is_unknown_shape);
if (is_unknown_shape) return false;
}

if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE) &&
kCaseOpTypes.count(parent_op_desc->GetType()) == 0 && kWhileOpTypes.count(parent_op_desc->GetType()) == 0 &&
kForOpTypes.count(parent_op_desc->GetType()) == 0 && kIfOpTypes.count(parent_op_desc->GetType()) == 0) {
return false;
}

@@ -513,7 +574,16 @@ bool NodeUtils::IsSubgraphOutput(const NodePtr &node) {
if (parent_op_desc == nullptr) {
return false;
}

if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE)) {
bool is_unknown_shape = false;
(void)AttrUtils::GetBool(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE, is_unknown_shape);
if (is_unknown_shape) return false;
}

if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE) &&
kCaseOpTypes.count(parent_op_desc->GetType()) == 0 && kWhileOpTypes.count(parent_op_desc->GetType()) == 0 &&
kForOpTypes.count(parent_op_desc->GetType()) == 0 && kIfOpTypes.count(parent_op_desc->GetType()) == 0) {
return false;
}

@@ -555,6 +625,53 @@ NodePtr NodeUtils::GetParentInput(const NodePtr &node) {
return peer_out_anchor->GetOwnerNode();
}

///
/// @brief Check is varying_input for while node
/// @param [in] node: Data node for subgraph
/// @return bool
///
bool NodeUtils::IsWhileVaryingInput(const ge::NodePtr &node) {
if (node == nullptr) {
return false;
}
if (node->GetType() != DATA) {
return false; // not input_node for subgraph
}

const NodePtr &parent_node = node->GetOwnerComputeGraph()->GetParentNode();
if (parent_node == nullptr) {
return false; // root graph
}

if (kWhileOpTypes.count(parent_node->GetType()) == 0) {
return false; // not input_node for while subgraph
}

uint32_t index_i = 0;
if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, index_i)) {
GELOGW("Node %s has no attr PARENT_NODE_INDEX.", node->GetName().c_str());
return false;
}
bool varying_flag = true;
for (const auto &item : node->GetOutDataNodesAndAnchors()) {
if (item.first->GetType() != NETOUTPUT) {
continue;
}
OpDescPtr op_desc = item.first->GetOpDesc();
uint32_t index_o = 0;
if ((op_desc == nullptr) ||
!AttrUtils::GetInt(op_desc->GetInputDesc(item.second->GetIdx()), ATTR_NAME_PARENT_NODE_INDEX, index_o)) {
continue; // input for while-cond subgraph
}
if (index_i != index_o) {
continue; // varying input for while-body subgraph
}
varying_flag = false;
break;
}
return varying_flag;
}

///
/// @brief Get subgraph input is constant.
/// @param [in] node
@@ -637,4 +754,86 @@ Status NodeUtils::RemoveSubgraphsOnNode(const NodePtr &node) {

return GRAPH_SUCCESS;
}
///
/// @brief Get subgraph input data node by index.
/// @param [in] node
/// @return Node
///
vector<NodePtr> NodeUtils::GetSubgraphDataNodesByIndex(const Node &node, int index) {
vector<NodePtr> in_data_node_vec;
auto op_desc = node.GetOpDesc();
GE_CHECK_NOTNULL_EXEC(op_desc, return in_data_node_vec);
auto subgraph_names = op_desc->GetSubgraphInstanceNames();
if (subgraph_names.empty()) {
GELOGW("Node %s is single node without sub graph.", node.GetName().c_str());
return in_data_node_vec;
}
auto compute_graph = node.GetOwnerComputeGraph();
for (const std::string &instance_name : subgraph_names) {
auto subgraph = compute_graph->GetSubgraph(instance_name);
for (const auto &node_in_subgraph : subgraph->GetDirectNode()) {
int parent_index = -1;
if (NodeUtils::IsSubgraphInput(node_in_subgraph)) {
(void)AttrUtils::GetInt(node_in_subgraph->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index);
if (parent_index == index) {
in_data_node_vec.emplace_back(node_in_subgraph);
}
}
}
}
return in_data_node_vec;
}
///
/// @brief Get subgraph input data node by index.
/// @param [in] node
/// @return Node
///
vector<NodePtr> NodeUtils::GetSubgraphOutputNodes(const Node &node) {
vector<NodePtr> out_data_node_vec;
auto op_desc = node.GetOpDesc();
GE_CHECK_NOTNULL_EXEC(op_desc, return out_data_node_vec);
auto subgraph_names = op_desc->GetSubgraphInstanceNames();
if (subgraph_names.empty()) {
GELOGI("Node %s is single node without sub graph.", node.GetName().c_str());
return out_data_node_vec;
}
auto compute_graph = node.GetOwnerComputeGraph();
for (const std::string &instance_name : subgraph_names) {
auto subgraph = compute_graph->GetSubgraph(instance_name);
for (const auto &node_in_subgraph : subgraph->GetDirectNode()) {
if (NodeUtils::IsSubgraphOutput(node_in_subgraph)) {
out_data_node_vec.emplace_back(node_in_subgraph);
}
}
}
return out_data_node_vec;
}

NodePtr NodeUtils::GetInDataNodeByIndex(const Node &node, int index) {
if (node.GetInDataAnchor(index) == nullptr) {
return nullptr;
}
if (node.GetInDataAnchor(index)->GetPeerOutAnchor() == nullptr) {
return nullptr;
}
return node.GetInDataAnchor(index)->GetPeerOutAnchor()->GetOwnerNode();
}

///
/// @brief Get all nodes consuming the given output anchor of a node.
/// @param [in] node
/// @param [in] index output data anchor index
/// @return vector<NodePtr> owner nodes of the peer input anchors; empty when
///         the anchor does not exist or has no valid peers
///
vector<NodePtr> NodeUtils::GetOutDataNodesByIndex(const Node &node, int index) {
  vector<NodePtr> out_data_nodes;
  const auto out_data_anchor = node.GetOutDataAnchor(index);
  if (out_data_anchor == nullptr) {
    return out_data_nodes;
  }
  // Iterate by const reference: the original `const auto` copied the shared
  // anchor pointer (atomic refcount inc/dec) on every iteration.
  for (const auto &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) {
    if (peer_in_anchor == nullptr) {
      continue;
    }
    auto owner_node = peer_in_anchor->GetOwnerNode();
    if (owner_node == nullptr) {
      continue;
    }
    out_data_nodes.emplace_back(std::move(owner_node));
  }
  return out_data_nodes;
}
} // namespace ge

+ 52
- 24
src/common/graph/utils/op_desc_utils.cc View File

@@ -28,6 +28,7 @@

using std::vector;

/*lint -e512 -e737 -e752*/
namespace ge {
const char OP_DESC_QUANT_PARAMS[] = "quantize_factor";
static const int CONST_OP_NORMAL_WEIGHT_SIZE = 1;
@@ -132,11 +133,11 @@ graphStatus OpDescUtils::GetQuantizeFactorParams(const OpDesc &op_desc, Quantize
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
OpDescUtils::SetQuantizeFactorParams(const OpDescPtr &op_desc, const QuantizeFactorParams &quant) {
GE_CHK_BOOL_EXEC_INFO(op_desc != nullptr, return GRAPH_FAILED, "op_desc is nullptr");
return op_desc->SetAttr(OP_DESC_QUANT_PARAMS, GeAttrValue::CreateFrom<QuantizeFactorParams>(quant));
return op_desc->SetAttr(OP_DESC_QUANT_PARAMS, GeAttrValue::CreateFrom<QuantizeFactorParams>(quant)); // lint !e732
}

graphStatus OpDescUtils::SetQuantizeFactorParams(OpDesc &op_desc, const QuantizeFactorParams &quant) {
return op_desc.SetAttr(OP_DESC_QUANT_PARAMS, GeAttrValue::CreateFrom<QuantizeFactorParams>(quant));
return op_desc.SetAttr(OP_DESC_QUANT_PARAMS, GeAttrValue::CreateFrom<QuantizeFactorParams>(quant)); // lint !e732
}

GeTensorPtr OpDescUtils::MutableWeights(OpDesc &op_desc) {
@@ -197,24 +198,33 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::NodePtr> OpDescUtils::
continue;
}
auto in_node = out_anchor->GetOwnerNode();
if ((in_node->GetType() == CONSTANT) || (in_node->GetType() == CONSTANTOP)) {
ret.push_back(in_node);
} else if (in_node->GetType() == DATA) {
const ComputeGraphPtr &graph = node.GetOwnerComputeGraph();
GE_CHK_BOOL_EXEC(graph != nullptr, continue, "Owner graph is null");

const NodePtr &parent_node = graph->GetParentNode();
if (parent_node == nullptr) {
continue; // Root graph.
}

if (kWhileOpTypes.count(parent_node->GetType()) > 0) {
continue; // Subgraph of While cond or body.
while (true) {
if (in_node == nullptr) {
break;
}

NodePtr input_node = NodeUtils::GetParentInput(in_node);
if ((input_node != nullptr) && ((input_node->GetType() == CONSTANT) || (input_node->GetType() == CONSTANTOP))) {
ret.push_back(input_node);
if ((in_node->GetType() == CONSTANT) || (in_node->GetType() == CONSTANTOP)) {
ret.push_back(in_node);
break;
} else if (in_node->GetType() == DATA) {
if (NodeUtils::IsWhileVaryingInput(in_node)) {
break;
}
in_node = NodeUtils::GetParentInput(in_node);
} else if ((in_node->GetType() == ENTER) || (in_node->GetType() == REFENTER)) {
bool is_constant = false;
(void)AttrUtils::GetBool(in_node->GetOpDesc(), ENTER_ATTR_CONSTANT_FLAG, is_constant);
if (!is_constant) {
break;
}
// Enter node has and only has one input
if (in_node->GetInDataNodes().size() != 1) {
GELOGW("Check number of input_nodes for Enter node %s failed, size=%zu.", node.GetName().c_str(),
in_node->GetInDataNodes().size());
break;
}
in_node = in_node->GetInDataNodes().at(0);
} else {
break;
}
}
}
@@ -245,7 +255,7 @@ size_t OpDescUtils::GetNonConstInputsSize(const ge::Node &node) {
continue;
}
}
return input_num;
return input_num; // lint !e712
} else {
GE_IF_BOOL_EXEC(
node.GetInDataNodes().size() < GetConstInputs(node).size(),
@@ -350,7 +360,7 @@ bool OpDescUtils::IsNonConstInput(const ge::Node &node, const size_t index) {
bool ret = false;
if (index < node.GetAllInDataAnchors().size()) {
if (NodeUtils::IsAnchorStatusSet(node)) {
ret = (ge::AnchorUtils::GetStatus(node.GetInDataAnchor(static_cast<int>(index))) == ANCHOR_DATA);
ret = (ge::AnchorUtils::GetStatus(node.GetInDataAnchor(static_cast<int>(index))) == ANCHOR_DATA); // lint !e712
} else {
for (const auto &anchor : node.GetAllInDataAnchors()) {
if (anchor->GetIdx() != static_cast<int>(index)) {
@@ -435,10 +445,27 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::NodePtr> OpDescUtils::

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<GeTensorPtr> OpDescUtils::MutableWeights(const ge::Node &node) {
vector<GeTensorPtr> ret;
GE_CHK_BOOL_EXEC(node.GetOpDesc() != nullptr, return ret, "node.GetOpDesc is nullptr!");
auto op_desc = node.GetOpDesc();
GE_CHK_BOOL_EXEC(op_desc != nullptr, return ret, "op_desc is nullptr!");
// Place holder operator, try to get the weight from parent node
// when parent node is const operator
if (node.GetType() == PLACEHOLDER) {
std::string parent_op;
(void)AttrUtils::GetStr(op_desc, "parentOpType", parent_op);
// This if judgment is necessary because the current subgraph optimization is multithreaded
// and the parent node of the PLD operation should be a stable type, such as const
if (parent_op == CONSTANT || parent_op == CONSTANTOP) {
NodePtr parent_node = nullptr;
parent_node = op_desc->TryGetExtAttr("parentNode", parent_node);
if (parent_node != nullptr) {
op_desc = parent_node->GetOpDesc();
GELOGD("pld[%s] get weight from const[%s]", node.GetName().c_str(), op_desc->GetName().c_str());
}
}
}
// Const operator, take the weight directly
if (node.GetOpDesc()->GetType() == CONSTANT || (node.GetOpDesc()->GetType() == CONSTANTOP)) {
auto weight = MutableWeights(node.GetOpDesc());
if (op_desc->GetType() == CONSTANT || (op_desc->GetType() == CONSTANTOP)) {
auto weight = MutableWeights(op_desc);
if (weight == nullptr) {
GELOGI("const op has no weight, op name:%s", node.GetName().c_str());
return ret;
@@ -733,3 +760,4 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDescUtils::SetSubgr
return op_desc->SetSubgraphInstanceName(iter->second, subgraph_instance_name);
}
} // namespace ge
/*lint +e512 +e737 +e752*/

+ 6
- 2
src/common/graph/utils/tensor_utils.cc View File

@@ -19,6 +19,7 @@

#include "debug/ge_log.h"
#include "framework/common/debug/ge_log.h"
#include "common/util/error_manager/error_manager.h"
#include "graph/ge_tensor.h"
#include "graph/types.h"
#include "graph/utils/type_utils.h"
@@ -105,7 +106,10 @@ static graphStatus CalcElementCntByDims(const std::vector<int64_t> &dims, int64_
element_cnt = 1;
for (int64_t dim : dims) {
if (CheckMultiplyOverflowInt64(element_cnt, dim)) {
GELOGE(GRAPH_FAILED, "CalcElementCntByDims failed, as when multiplying %ld and %ld.", element_cnt, dim);
ErrorManager::GetInstance().ATCReportErrMessage(
"E19013", {"function", "var1", "var2"},
{"CheckMultiplyOverflowInt64", std::to_string(element_cnt), std::to_string(dim)});
GELOGE(GRAPH_FAILED, "CalcElementCntByDims failed, when multiplying %ld and %ld.", element_cnt, dim);
return GRAPH_FAILED;
}
element_cnt *= dim;
@@ -273,7 +277,6 @@ static graphStatus CalcTensorElementCnt(const std::vector<int64_t> &dims, Format
case FORMAT_FRACTAL_Z:
graph_status = CalcElementCntOfFractalZ(dims, data_type, element_cnt);
break;
case FORMAT_NC1HWC0_C04:
case FORMAT_FRACTAL_NZ:
case FORMAT_FRACTAL_ZZ:
case FORMAT_NDHWC:
@@ -285,6 +288,7 @@ static graphStatus CalcTensorElementCnt(const std::vector<int64_t> &dims, Format
case FORMAT_NDC1HWC0:
case FORMAT_FRACTAL_Z_C04:
case FORMAT_FRACTAL_ZN_LSTM:
case FORMAT_NC1HWC0_C04:
graph_status = CalcElementCntByDims(dims, element_cnt);
break;
default:


+ 2
- 1
src/common/graph/utils/type_utils.cc View File

@@ -147,7 +147,8 @@ static const std::map<std::string, Format> kStringToFormatMap = {
{"FRACTAL_ZN_LSTM", FORMAT_FRACTAL_ZN_LSTM},
{"FRACTAL_Z_G", FORMAT_FRACTAL_Z_G},
{"FORMAT_RESERVED", FORMAT_RESERVED},
{"ALL", FORMAT_ALL}};
{"ALL", FORMAT_ALL},
{"NULL", FORMAT_NULL}};

static const std::map<DataType, std::string> kDataTypeToStringMap = {
{DT_UNDEFINED, "DT_UNDEFINED"}, // Used to indicate a DataType field has not been set.


+ 26
- 4
src/ge/CMakeLists.txt View File

@@ -60,6 +60,7 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR}
"common/formats/formats.cc"
"common/formats/utils/formats_trans_utils.cc"
"common/fp16_t.cc"
"common/ge/op_tiling_manager.cc"
"common/ge/plugin_manager.cc"
"common/helper/model_cache_helper.cc"
"common/profiling/profiling_manager.cc"
@@ -94,14 +95,25 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR}
"graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc"
"graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc"
"graph/load/new_model_manager/task_info/task_info.cc"
"graph/load/output/output.cc"
"graph/manager/*.cc"
"graph/manager/graph_context.cc"
"graph/manager/graph_manager.cc"
"graph/manager/graph_manager_utils.cc"
"graph/manager/graph_mem_allocator.cc"
"graph/manager/graph_caching_allocator.cc"
"graph/manager/graph_var_manager.cc"
"graph/manager/model_manager/event_manager.cc"
"graph/manager/trans_var_data_utils.cc"
"graph/manager/util/debug.cc"
"graph/manager/util/hcom_util.cc"
"graph/manager/util/rt_context_util.cc"
"graph/manager/util/variable_accelerate_ctrl.cc"
"graph/manager/model_manager/event_manager.cc"
"graph/manager/util/debug.cc"
"graph/manager/util/hcom_util.cc"
"graph/manager/util/rt_context_util.cc"
"graph/manager/util/variable_accelerate_ctrl.cc"
"graph/optimize/graph_optimize.cc"
"graph/optimize/mem_rw_conflict_optimize.cc"
"graph/optimize/optimizer/allreduce_fusion_pass.cc"
"graph/optimize/summary_optimize.cc"
"graph/partition/dynamic_shape_partition.cc"
@@ -159,8 +171,11 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR}
"hybrid/node_executor/aicpu/aicpu_ext_info.cc"
"hybrid/node_executor/aicpu/aicpu_node_executor.cc"
"hybrid/node_executor/compiledsubgraph/known_node_executor.cc"
"hybrid/node_executor/controlop/control_op_executor.cc"
"hybrid/node_executor/hccl/hccl_node_executor.cc"
"hybrid/node_executor/hostcpu/ge_local_node_executor.cc"
"hybrid/node_executor/node_executor.cc"
"hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc"
"hybrid/node_executor/task_context.cc"
"init/gelib.cc"
"model/ge_model.cc"
@@ -204,6 +219,7 @@ file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR}
"common/formats/formats.cc"
"common/formats/utils/formats_trans_utils.cc"
"common/fp16_t.cc"
"common/ge/op_tiling_manager.cc"
"common/ge/plugin_manager.cc"
"common/helper/model_cache_helper.cc"
"common/profiling/profiling_manager.cc"
@@ -236,13 +252,19 @@ file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR}
"graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc"
"graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc"
"graph/load/new_model_manager/task_info/task_info.cc"
"graph/load/output/output.cc"
"graph/manager/*.cc"
"graph/manager/graph_caching_allocator.cc"
"graph/manager/graph_context.cc"
"graph/manager/graph_manager.cc"
"graph/manager/graph_manager_utils.cc"
"graph/manager/graph_mem_allocator.cc"
"graph/manager/trans_var_data_utils.cc"
"graph/manager/graph_var_manager.cc"
"graph/manager/model_manager/event_manager.cc"
"graph/manager/util/debug.cc"
"graph/manager/util/rt_context_util.cc"
"graph/manager/util/variable_accelerate_ctrl.cc"
"graph/optimize/graph_optimize.cc"
"graph/optimize/mem_rw_conflict_optimize.cc"
"graph/optimize/summary_optimize.cc"
"graph/partition/dynamic_shape_partition.cc"
"graph/partition/engine_place.cc"


+ 14
- 43
src/ge/client/ge_api.cc View File

@@ -28,6 +28,7 @@
#include "graph/opsproto_manager.h"
#include "graph/utils/type_utils.h"
#include "graph/manager/util/rt_context_util.h"
#include "graph/common/ge_call_wrapper.h"
#include "register/op_registry.h"
#include "common/ge/tbe_plugin_manager.h"

@@ -41,8 +42,8 @@ namespace {
const int32_t kMaxStrLen = 128;
}

static bool kGeInitialized = false;
static std::mutex kGeReleaseMutex; // GEFinalize and ~Session use
static bool g_ge_initialized = false;
static std::mutex g_ge_release_mutex; // GEFinalize and ~Session use

namespace ge {
void GetOpsProtoPath(std::string &opsproto_path) {
@@ -61,31 +62,6 @@ void GetOpsProtoPath(std::string &opsproto_path) {
opsproto_path = (path_base + "ops/op_proto/custom/" + ":") + (path_base + "ops/op_proto/built-in/");
}

Status CheckDumpAndReuseMemory(const std::map<string, string> &options) {
const int kDecimal = 10;
auto dump_op_env = std::getenv("DUMP_OP");
int dump_op_flag = (dump_op_env != nullptr) ? std::strtol(dump_op_env, nullptr, kDecimal) : 0;
auto disableReuseMemoryIter = options.find("ge.exec.disableReuseMemory");
if (disableReuseMemoryIter != options.end()) {
if (disableReuseMemoryIter->second == "0") {
GELOGD("ge.exec.disableReuseMemory=0, reuse memory is open");
if (dump_op_flag) {
GELOGW("Will dump incorrect op data with GE Option ge.exec.disableReuseMemory=0");
}
} else if (disableReuseMemoryIter->second == "1") {
GELOGD("ge.exec.disableReuseMemory=1, reuse memory is close");
} else {
GELOGE(PARAM_INVALID, "CheckDumpAndReuseMemory ge.exec.disableReuseMemory is valid");
return FAILED;
}
} else {
if (dump_op_flag) {
GELOGW("Will dump incorrect op data with default reuse memory");
}
}
return SUCCESS;
}

Status CheckOptionsValid(const std::map<string, string> &options) {
// check job_id is valid
auto job_id_iter = options.find(OPTION_EXEC_JOB_ID);
@@ -96,11 +72,6 @@ Status CheckOptionsValid(const std::map<string, string> &options) {
}
}

// Check ge.exec.disableReuseMemory and env DUMP_OP
if (CheckDumpAndReuseMemory(options) != SUCCESS) {
return FAILED;
}

return SUCCESS;
}

@@ -108,7 +79,7 @@ Status CheckOptionsValid(const std::map<string, string> &options) {
Status GEInitialize(const std::map<string, string> &options) {
GELOGT(TRACE_INIT, "GEInitialize start");
// 0.check init status
if (kGeInitialized) {
if (g_ge_initialized) {
GELOGW("GEInitialize is called more than once");
return SUCCESS;
}
@@ -147,9 +118,9 @@ Status GEInitialize(const std::map<string, string> &options) {
}

// 7.check return status, return
if (!kGeInitialized) {
if (!g_ge_initialized) {
// Initialize success, first time calling initialize
kGeInitialized = true;
g_ge_initialized = true;
}

GELOGT(TRACE_STOP, "GEInitialize finished");
@@ -160,12 +131,12 @@ Status GEInitialize(const std::map<string, string> &options) {
Status GEFinalize() {
GELOGT(TRACE_INIT, "GEFinalize start");
// check init status
if (!kGeInitialized) {
if (!g_ge_initialized) {
GELOGW("GEFinalize is called before GEInitialize");
return SUCCESS;
}

std::lock_guard<std::mutex> lock(kGeReleaseMutex);
std::lock_guard<std::mutex> lock(g_ge_release_mutex);
// call Finalize
Status ret = SUCCESS;
Status middle_ret;
@@ -187,10 +158,10 @@ Status GEFinalize() {
ret = middle_ret;
}

if (kGeInitialized && ret == SUCCESS) {
if (g_ge_initialized && ret == SUCCESS) {
// Unified destruct rt_context
RtContextUtil::GetInstance().DestroyrtContexts();
kGeInitialized = false;
RtContextUtil::GetInstance().DestroyAllRtContexts();
g_ge_initialized = false;
}

GELOGT(TRACE_STOP, "GEFinalize finished");
@@ -202,7 +173,7 @@ Session::Session(const std::map<string, string> &options) {
GELOGT(TRACE_INIT, "Session Constructor start");
// check init status
sessionId_ = 0;
if (!kGeInitialized) {
if (!g_ge_initialized) {
GELOGE(GE_CLI_GE_NOT_INITIALIZED);
return;
}
@@ -232,13 +203,13 @@ Session::Session(const std::map<string, string> &options) {
Session::~Session() {
GELOGT(TRACE_INIT, "Session Destructor start");
// 0.check init status
if (!kGeInitialized) {
if (!g_ge_initialized) {
GELOGW("GE is not yet initialized or is finalized.");
return;
}

Status ret = FAILED;
std::lock_guard<std::mutex> lock(kGeReleaseMutex);
std::lock_guard<std::mutex> lock(g_ge_release_mutex);
try {
uint64_t session_id = sessionId_;
// call DestroySession


+ 6
- 5
src/ge/common/convert/pb2json.cc View File

@@ -72,9 +72,6 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void Pb2Json::Message2Json(cons
void Pb2Json::OneField2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field,
const ProtobufReflection *reflection, const set<string> &black_fields, Json &json,
bool enum2str) {
if (field == nullptr || reflection == nullptr) {
return;
}
switch (field->type()) {
case ProtobufFieldDescriptor::TYPE_MESSAGE: {
const ProtobufMsg &tmp_message = reflection->GetMessage(message, field);
@@ -118,8 +115,12 @@ void Pb2Json::OneField2Json(const ProtobufMsg &message, const ProtobufFieldDescr

case ProtobufFieldDescriptor::TYPE_FLOAT:
char str[kSignificantDigits];
sprintf_s(str, kSignificantDigits, "%g", reflection->GetFloat(message, field));
json[field->name()] = str;
if (sprintf_s(str, kSignificantDigits, "%g", reflection->GetFloat(message, field)) != -1) {
json[field->name()] = str;
} else {
json[field->name()] = reflection->GetFloat(message, field);
}

break;

case ProtobufFieldDescriptor::TYPE_STRING:


+ 0
- 1
src/ge/common/formats/format_transfers/datatype_transfer.cc View File

@@ -29,7 +29,6 @@

namespace ge {
namespace formats {

namespace {
enum DataTypeTransMode {
kTransferWithDatatypeFloatToFloat16,


+ 0
- 1
src/ge/common/formats/format_transfers/datatype_transfer.h View File

@@ -27,7 +27,6 @@

namespace ge {
namespace formats {

struct CastArgs {
const uint8_t *data;
size_t src_data_size;


+ 0
- 1
src/ge/common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc View File

@@ -179,6 +179,5 @@ Status FormatTransferDhwcnFractalZ3D::TransShape(Format src_format, const std::v
}

REGISTER_FORMAT_TRANSFER(FormatTransferDhwcnFractalZ3D, FORMAT_DHWCN, FORMAT_FRACTAL_Z_3D)

} // namespace formats
} // namespace ge

+ 0
- 1
src/ge/common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc View File

@@ -180,6 +180,5 @@ Status FormatTransferDhwncFractalZ3DTranspose::TransShape(Format src_format, con
}

REGISTER_FORMAT_TRANSFER(FormatTransferDhwncFractalZ3DTranspose, FORMAT_DHWNC, FORMAT_FRACTAL_Z_3D_TRANSPOSE)

} // namespace formats
} // namespace ge

+ 1
- 1
src/ge/common/formats/format_transfers/format_transfer_fractal_nz.cc View File

@@ -56,7 +56,7 @@ Status TransShapeToFracNz(const ShapeVector &src_shape, DataType data_type, Shap
dst_shape.clear();
hw_shape.clear();
auto w0 = GetCubeSizeByDataType(data_type);
auto h0 = GetCubeSizeByDataType(data_type);
int64_t h0 = kCubeSize;
switch (src_shape.size()) {
case 1:
dst_shape.push_back(Ceil(src_shape[0], w0));


+ 55
- 40
src/ge/common/formats/format_transfers/format_transfer_fractal_z.cc View File

@@ -19,6 +19,7 @@
#include <securec.h>
#include <memory>

#include "common/debug/log.h"
#include "common/formats/utils/formats_definitions.h"
#include "common/formats/utils/formats_trans_utils.h"
#include "framework/common/debug/ge_log.h"
@@ -107,8 +108,8 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) {

int64_t hw = h * w;
int64_t chw = c * hw;
int64_t hwc0 = hw * c0;
int64_t nchw = n * chw;
int64_t hwc0 = hw * c0;

// horizontal fractal matrix count (N)
int64_t hf_cnt = Ceil(n, static_cast<int64_t>(kNiSize));
@@ -119,18 +120,15 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) {
int64_t total_ele_cnt = hf_cnt * vf_cnt * fractal_ele_cnt;
int size = GetSizeByDataType(args.src_data_type);
int64_t dst_size = total_ele_cnt * size;
if (dst_size == 0) {
result.length = static_cast<size_t>(dst_size);
return SUCCESS;
}
GE_CHK_BOOL_EXEC_NOLOG(dst_size != 0, result.length = static_cast<size_t>(dst_size); return SUCCESS;);

std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>());
if (dst == nullptr) {
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
dst == nullptr,
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld",
TypeUtils::FormatToSerialString(args.src_format).c_str(),
TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size);
return OUT_OF_MEMORY;
}
return OUT_OF_MEMORY;);

for (int64_t vfi = 0; vfi < vf_cnt; vfi++) {
// vertical fractal matrix base index
@@ -156,12 +154,20 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) {
auto protected_size = dst_size - offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN)
? dst_size - offset
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN);
errno_t ret;
errno_t ret = EOK;
if (need_pad_zero) {
ret = memset_s(dst.get() + offset, static_cast<size_t>(protected_size), 0, static_cast<size_t>(size));
} else {
ret = memcpy_s(dst.get() + offset, static_cast<size_t>(protected_size), args.data + src_offset * size,
static_cast<size_t>(size));
if (protected_size < size) {
GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory, protected_size is %ld and size is %ld",
protected_size, size);
return INTERNAL_ERROR;
}
char *dst_data = reinterpret_cast<char *>(dst.get() + offset);
const char *src_data = reinterpret_cast<const char *>(args.data + src_offset * size);
for (int64_t index = 0; index < size; index++) {
*dst_data++ = *src_data++;
}
}
if (ret != EOK) {
GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory at offset %ld, error-code %d pad mode %d", offset,
@@ -199,18 +205,15 @@ Status TransFormatHwcnToFz(const TransArgs &args, TransResult &result) {
dst_size *= dim;
}
dst_size *= data_size;
if (dst_size == 0) {
result.length = static_cast<size_t>(dst_size);
return SUCCESS;
}
GE_CHK_BOOL_EXEC_NOLOG(dst_size != 0, result.length = static_cast<size_t>(dst_size); return SUCCESS;);

std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>());
if (dst == nullptr) {
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
dst == nullptr,
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld",
TypeUtils::FormatToSerialString(args.src_format).c_str(),
TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size);
return OUT_OF_MEMORY;
}
return OUT_OF_MEMORY;);

for (int64_t c1i = 0; c1i < c1; c1i++) {
for (int64_t hi = 0; hi < h; hi++) {
@@ -223,14 +226,22 @@ Status TransFormatHwcnToFz(const TransArgs &args, TransResult &result) {
? dst_size - dst_offset
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN);
auto pad_zero = ((c1i * c0 + c0i) >= c) || (n1n0i >= n);
errno_t ret;
errno_t ret = EOK;
if (pad_zero) {
ret = memset_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), 0,
static_cast<size_t>(data_size));
} else {
if (protected_size < data_size) {
GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory, protected_size is %ld and size is %ld",
protected_size, data_size);
return INTERNAL_ERROR;
}
int64_t src_idx = hi * wcn + wi * cn + (c1i * c0 + c0i) * n + n1n0i;
ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size),
args.data + src_idx * data_size, static_cast<size_t>(data_size));
char *dst_data = reinterpret_cast<char *>(dst.get() + dst_offset);
const char *src_data = reinterpret_cast<const char *>(args.data + src_idx * data_size);
for (int64_t index = 0; index < data_size; index++) {
*dst_data++ = *src_data++;
}
}
if (ret != EOK) {
GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory at offset %ld, error-code %d, pad mode %d",
@@ -269,18 +280,15 @@ Status TransFormatNhwcToFz(const TransArgs &args, TransResult &result) {
dst_size *= dim;
}
dst_size *= data_size;
if (dst_size == 0) {
result.length = static_cast<size_t>(dst_size);
return SUCCESS;
}
GE_CHK_BOOL_EXEC_NOLOG(dst_size != 0, result.length = static_cast<size_t>(dst_size); return SUCCESS;);

std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>());
if (dst == nullptr) {
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
dst == nullptr,
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld",
TypeUtils::FormatToSerialString(args.src_format).c_str(),
TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size);
return OUT_OF_MEMORY;
}
return OUT_OF_MEMORY;);

for (int64_t c1i = 0; c1i < c1; c1i++) {
for (int64_t hi = 0; hi < h; hi++) {
@@ -293,14 +301,22 @@ Status TransFormatNhwcToFz(const TransArgs &args, TransResult &result) {
? dst_size - dst_offset
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN);
auto pad_zero = ((c1i * c0 + c0i) >= c) || (n1n0i >= n);
errno_t ret;
errno_t ret = EOK;
if (pad_zero) {
ret = memset_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), 0,
static_cast<size_t>(data_size));
} else {
if (protected_size < data_size) {
GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory, protected_size is %ld and size is %ld",
protected_size, data_size);
return INTERNAL_ERROR;
}
int64_t src_idx = n1n0i * hwc + hi * wc + wi * c + (c1i * c0 + c0i);
ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size),
args.data + src_idx * data_size, static_cast<size_t>(data_size));
char *dst_data = reinterpret_cast<char *>(dst.get() + dst_offset);
const char *src_data = reinterpret_cast<const char *>(args.data + src_idx * data_size);
for (int64_t index = 0; index < data_size; index++) {
*dst_data++ = *src_data++;
}
}
if (ret != EOK) {
GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory at offset %ld, error-code %d, pad mode %d",
@@ -337,16 +353,16 @@ Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &r
return PARAM_INVALID;
}

if (args.src_format == FORMAT_NCHW && args.dst_format == FORMAT_FRACTAL_Z) {
return TransFormatFromNchwToFz(args, result);
if (args.src_format == FORMAT_NHWC && args.dst_format == FORMAT_FRACTAL_Z) {
return TransFormatNhwcToFz(args, result);
}

if (args.src_format == FORMAT_HWCN && args.dst_format == FORMAT_FRACTAL_Z) {
return TransFormatHwcnToFz(args, result);
}

if (args.src_format == FORMAT_NHWC && args.dst_format == FORMAT_FRACTAL_Z) {
return TransFormatNhwcToFz(args, result);
if (args.src_format == FORMAT_NCHW && args.dst_format == FORMAT_FRACTAL_Z) {
return TransFormatFromNchwToFz(args, result);
}

return UNSUPPORTED;
@@ -358,14 +374,14 @@ Status FormatTransferFractalZ::TransShape(Format src_format, const std::vector<i
return UNSUPPORTED;
}

if (src_format == FORMAT_NCHW && dst_format == FORMAT_FRACTAL_Z) {
return TransShapeNchwToFz(src_shape, data_type, dst_shape);
if (src_format == FORMAT_NHWC && dst_format == FORMAT_FRACTAL_Z) {
return TransShapeNhwcToFz(src_shape, data_type, dst_shape);
}
if (src_format == FORMAT_HWCN && dst_format == FORMAT_FRACTAL_Z) {
return TransShapeHwcnToFz(src_shape, data_type, dst_shape);
}
if (src_format == FORMAT_NHWC && dst_format == FORMAT_FRACTAL_Z) {
return TransShapeNhwcToFz(src_shape, data_type, dst_shape);
if (src_format == FORMAT_NCHW && dst_format == FORMAT_FRACTAL_Z) {
return TransShapeNchwToFz(src_shape, data_type, dst_shape);
}

return UNSUPPORTED;
@@ -374,6 +390,5 @@ Status FormatTransferFractalZ::TransShape(Format src_format, const std::vector<i
REGISTER_FORMAT_TRANSFER(FormatTransferFractalZ, FORMAT_NCHW, FORMAT_FRACTAL_Z)
REGISTER_FORMAT_TRANSFER(FormatTransferFractalZ, FORMAT_HWCN, FORMAT_FRACTAL_Z)
REGISTER_FORMAT_TRANSFER(FormatTransferFractalZ, FORMAT_NHWC, FORMAT_FRACTAL_Z)

} // namespace formats
} // namespace ge

+ 1
- 3
src/ge/common/formats/format_transfers/format_transfer_nchw_fz_c04.cc View File

@@ -39,7 +39,6 @@
namespace ge {
namespace formats {
namespace {

constexpr int64_t kMaxDimsNumC = 4;

Status CheckDataTypeSupport(DataType data_type) { return GetSizeByDataType(data_type) > 0 ? SUCCESS : UNSUPPORTED; }
@@ -109,7 +108,7 @@ Status TransFormatFromNchwToFzC04(const TransArgs &args, TransResult &result) {
return NOT_CHANGED;
}

/* prepare for padding in chw*/
// prepare for padding in chw
int64_t tmp = h * w * c;
int64_t n_o = Ceil(n, static_cast<int64_t>(c0));
int64_t c_o = c0;
@@ -309,6 +308,5 @@ Status FormatTransferNchwToFZC04::TransShape(Format src_format, const std::vecto
}

REGISTER_FORMAT_TRANSFER(FormatTransferNchwToFZC04, FORMAT_NCHW, FORMAT_FRACTAL_Z_C04)

} // namespace formats
} // namespace ge

+ 0
- 2
src/ge/common/formats/utils/formats_definitions.h View File

@@ -19,7 +19,6 @@

namespace ge {
namespace formats {

static const int kCubeSize = 16;
static const int kNiSize = 16;
static const int64_t kShapeItemNumMAX = 1024UL * 1024UL * 1024UL * 1024UL;
@@ -47,7 +46,6 @@ enum FracZDimIndex { kFracZHWC1, kFracZN0, kFracZNi, kFracZC0, kFracZDimsNum };
enum DhwcnDimIndex { kDhwcnD, kDhwcnH, kDhwcnW, kDhwcnC, kDhwcnN, kDhwcnDimsNum };

enum DhwncDimIndex { kDhwncD, kDhwncH, kDhwncW, kDhwncN, kDhwncC, kDhwncDimsNum };

} // namespace formats
} // namespace ge
#endif // GE_COMMON_FORMATS_UTILS_FORMATS_DEFINITIONS_H_

+ 0
- 2
src/ge/common/formats/utils/formats_trans_utils.h View File

@@ -21,7 +21,6 @@
#include <sstream>
#include <string>
#include <vector>

#include "external/graph/types.h"
#include "graph/ge_tensor.h"

@@ -69,7 +68,6 @@ T Ceil(T n1, T n2) {
}
return (n2 != 0) ? (n1 - 1) / n2 + 1 : 0;
}

} // namespace formats
} // namespace ge
#endif // GE_COMMON_FORMATS_UTILS_FORMATS_TRANS_UTILS_H_

+ 1
- 1
src/ge/common/fp16_t.h View File

@@ -600,5 +600,5 @@ int16_t GetManBitLength(T man) {
}
return len;
}
}; // namespace ge
} // namespace ge
#endif // GE_COMMON_FP16_T_H_

+ 81
- 0
src/ge/common/ge/op_tiling_manager.cc View File

@@ -0,0 +1,81 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "common/ge/op_tiling_manager.h"
#include "framework/common/debug/log.h"
#include <string>

namespace {
const char *const kEnvName = "ASCEND_OPP_PATH";
const std::string kDefaultPath = "/usr/local/Ascend/opp";
const std::string kDefaultBuiltInTilingPath = "/op_impl/built-in/liboptiling.so";
const std::string kDefaultCustomTilingPath = "/op_impl/custom/liboptiling.so";
const uint8_t kPrefixIndex = 9;
} // namespace

namespace ge {
void OpTilingManager::ClearHandles() noexcept {
for (const auto &handle : handles_) {
if (dlclose(handle.second) != 0) {
GELOGE(FAILED, "Failed to close handle of %s: %s", handle.first.c_str(), dlerror());
}
}
handles_.clear();
}

OpTilingManager::~OpTilingManager() { ClearHandles(); }

std::string OpTilingManager::GetPath() {
const char *opp_path_env = std::getenv(kEnvName);
std::string opp_path = kDefaultPath;
if (opp_path_env != nullptr) {
char resolved_path[PATH_MAX];
if (realpath(opp_path_env, resolved_path) == NULL) {
GELOGE(PARAM_INVALID, "Failed load tiling lib as env 'ASCEND_OPP_PATH'(%s) is invalid path.", opp_path_env);
return std::string();
}
opp_path = resolved_path;
}
return opp_path;
}

void OpTilingManager::LoadSo() {
std::string opp_path = GetPath();
if (opp_path.empty()) {
GELOGW("Skip load tiling lib.");
return;
}
std::string built_in_tiling_lib = opp_path + kDefaultBuiltInTilingPath;
std::string custom_tiling_lib = opp_path + kDefaultCustomTilingPath;
std::string built_in_name = kDefaultBuiltInTilingPath.substr(kPrefixIndex);
std::string custom_name = kDefaultCustomTilingPath.substr(kPrefixIndex);

void *handle_bi = dlopen(built_in_tiling_lib.c_str(), RTLD_NOW | RTLD_GLOBAL);
if (handle_bi == nullptr) {
GELOGW("Failed to dlopen %s!", dlerror());
} else {
handles_[built_in_name] = handle_bi;
}

void *handle_ct = dlopen(custom_tiling_lib.c_str(), RTLD_NOW | RTLD_GLOBAL);
if (handle_ct == nullptr) {
GELOGW("Failed to dlopen %s!", dlerror());
} else {
handles_[custom_name] = handle_ct;
}
}

} // namespace ge

+ 38
- 0
src/ge/common/ge/op_tiling_manager.h View File

@@ -0,0 +1,38 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_COMMON_GE_OP_TILING_MANAGER_H_
#define GE_COMMON_GE_OP_TILING_MANAGER_H_

#include <map>

namespace ge {
using SoToHandleMap = std::map<std::string, void *>;

// Loads the op tiling shared libraries (built-in and custom) via dlopen and
// keeps the returned handles keyed by so-file name.
class OpTilingManager {
 public:
  OpTilingManager() = default;
  ~OpTilingManager();
  // Open the built-in and custom tiling libraries and record their handles.
  void LoadSo();

 private:
  // Resolve the opp install directory (env 'ASCEND_OPP_PATH' or default).
  static std::string GetPath();
  // Clear the recorded handles; called from the destructor.
  void ClearHandles() noexcept;
  SoToHandleMap handles_;  // so-file name -> dlopen handle
};
} // namespace ge

#endif // GE_COMMON_GE_OP_TILING_MANAGER_H_

+ 5
- 2
src/ge/common/ge/tbe_plugin_manager.cc View File

@@ -182,7 +182,7 @@ void TBEPluginManager::GetCustomOpPath(std::string &customop_path) {
}

void TBEPluginManager::LoadCustomOpLib() {
LoadPluginSo();
LoadPluginSo(options_);

std::vector<OpRegistrationData> registration_datas = domi::OpRegistry::Instance()->registrationDatas;
GELOGI("The size of registration_datas is: %zu", registration_datas.size());
@@ -193,10 +193,13 @@ void TBEPluginManager::LoadCustomOpLib() {
}
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void TBEPluginManager::LoadPluginSo() {
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void TBEPluginManager::LoadPluginSo(
const std::map<string, string> &options) {
vector<string> file_list;
string caffe_parser_path;
std::string plugin_path;

options_ = options;
GetCustomOpPath(plugin_path);

// Whether there are files in the plugin so path


+ 1
- 1
src/ge/common/ge/tbe_plugin_manager.h View File

@@ -48,7 +48,7 @@ class TBEPluginManager {

static void InitPreparation(const std::map<string, string> &options);

void LoadPluginSo();
void LoadPluginSo(const std::map<string, string> &options);

private:
TBEPluginManager() = default;


+ 1
- 0
src/ge/common/ge_common.mk View File

@@ -36,6 +36,7 @@ GE_COMMON_LOCAL_SRC_FILES := \
properties_manager.cc \
types.cc\
model_parser/base.cc \
model_parser/graph_parser_util.cc \
tbe_kernel_store.cc \
op/attr_value_util.cc \
op/ge_op_utils.cc \


+ 31
- 125
src/ge/common/helper/model_helper.cc View File

@@ -17,6 +17,7 @@
#include "framework/common/helper/model_helper.h"

#include "common/ge/ge_util.h"
#include "common/util/error_manager/error_manager.h"
#include "framework/common/debug/log.h"
#include "framework/common/util.h"
#include "framework/common/debug/ge_log.h"
@@ -89,10 +90,12 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::SaveToOmMod
}
}
auto ge_model_weight = ge_model->GetWeight();
GELOGI("WEIGHTS_DATA size is %zu , %p", ge_model_weight.GetSize(), ge_model_weight.GetData());
if (SaveModelPartition(om_file_save_helper, ModelPartitionType::WEIGHTS_DATA, ge_model_weight.GetData(),
ge_model_weight.GetSize()) != SUCCESS) {
GELOGW("Add weight partition failed"); // weight is not necessary
GELOGI("WEIGHTS_DATA size is %zu, %p", ge_model_weight.GetSize(), ge_model_weight.GetData());
// weight is not necessary
if (ge_model_weight.GetSize() > 0) {
GE_CHK_STATUS_RET(SaveModelPartition(om_file_save_helper, ModelPartitionType::WEIGHTS_DATA,
ge_model_weight.GetData(), ge_model_weight.GetSize()),
"Add weight partition failed");
}

TBEKernelStore tbe_kernel_store = ge_model->GetTBEKernelStore();
@@ -238,44 +241,48 @@ ModelHelper::SaveOriginalGraphToOmModel(const ge::Graph &graph, const std::strin

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::LoadModel(const ge::ModelData &model_data) {
if (model_data.model_data == nullptr || model_data.model_len == 0) {
GELOGE(FAILED, "Model_data is nullptr, or model_data_size is 0");
return FAILED;
GELOGE(GE_EXEC_MODEL_DATA_SIZE_INVALID, "Model_data is nullptr, or model_data_size is 0");
return GE_EXEC_MODEL_DATA_SIZE_INVALID;
}

if (is_assign_model_) {
GELOGE(FAILED, "Model helper has already loaded!");
return FAILED;
GELOGE(GE_EXEC_LOAD_MODEL_REPEATED, "Model helper has already loaded!");
return GE_EXEC_LOAD_MODEL_REPEATED;
}

if (ReleaseLocalModelData() != SUCCESS) {
GELOGE(FAILED, "ReleaseLocalModelData failed.");
return FAILED;
GELOGE(INTERNAL_ERROR, "ReleaseLocalModelData failed.");
return INTERNAL_ERROR;
}

Status status = ge::DavinciModelParser::ParseModelContent(model_data, model_addr_tmp_, model_len_tmp_);
if (ge::DavinciModelParser::ParseModelContent(model_data, model_addr_tmp_, model_len_tmp_) != SUCCESS) {
GELOGE(FAILED, "Parse model content failed!");
return FAILED;
GELOGE(status, "Parse model content failed!");
return status;
}

file_header_ = reinterpret_cast<ModelFileHeader *>(model_data.model_data);

OmFileLoadHelper om_load_helper;
if (om_load_helper.Init(model_addr_tmp_, model_len_tmp_) != SUCCESS) {
GELOGE(FAILED, "Om_load_helper init failed");
status = om_load_helper.Init(model_addr_tmp_, model_len_tmp_);
if (status != SUCCESS) {
GELOGE(status, "Om_load_helper init failed");
model_addr_tmp_ = nullptr;
return FAILED;
return status;
}
auto partition_table = reinterpret_cast<ModelPartitionTable *>(model_addr_tmp_);
if (partition_table->num == kOriginalOmPartitionNum) {
GELOGE(FAILED, "om model is error,please use executable om model");
return FAILED;
model_addr_tmp_ = nullptr;
GELOGE(GE_EXEC_MODEL_PARTITION_NUM_INVALID, "om model is error,please use executable om model");
return GE_EXEC_MODEL_PARTITION_NUM_INVALID;
}
// Encrypt model need to del temp model/no encrypt model don't need to del model
model_addr_tmp_ = nullptr;

if (GenerateGeModel(om_load_helper) != SUCCESS) {
GELOGE(FAILED, "GenerateGeModel failed");
return FAILED;
status = GenerateGeModel(om_load_helper);
if (status != SUCCESS) {
GELOGE(status, "GenerateGeModel failed");
return status;
}

is_assign_model_ = true;
@@ -287,19 +294,19 @@ Status ModelHelper::GenerateGeModel(OmFileLoadHelper &om_load_helper) {
GE_CHECK_NOTNULL(model_);
Status ret = LoadModelData(om_load_helper);
if (ret != SUCCESS) {
return ret;
return GE_EXEC_LOAD_MODEL_PARTITION_FAILED;
}
ret = LoadWeights(om_load_helper);
if (ret != SUCCESS) {
return ret;
return GE_EXEC_LOAD_WEIGHT_PARTITION_FAILED;
}
ret = LoadTask(om_load_helper);
if (ret != SUCCESS) {
return ret;
return GE_EXEC_LOAD_TASK_PARTITION_FAILED;
}
ret = LoadTBEKernelStore(om_load_helper);
if (ret != SUCCESS) {
return ret;
return GE_EXEC_LOAD_KERNEL_PARTITION_FAILED;
}
return SUCCESS;
}
@@ -390,107 +397,6 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY GeModelPtr ModelHelper::GetGeMo
return out_model;
}

// Transit func for model to ge_model. It will be removed when load and build support ge_model in future
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::TransModelToGeModel(const ModelPtr &model,
GeModelPtr &ge_model) {
if (model == nullptr) {
GELOGE(FAILED, "Model is null");
return FAILED;
}
ge_model = ge::MakeShared<ge::GeModel>();
GE_CHECK_NOTNULL(ge_model);
ge_model->SetGraph(model->GetGraph());
ge_model->SetName(model->GetName());
ge_model->SetVersion(model->GetVersion());
ge_model->SetPlatformVersion(model->GetPlatformVersion());
ge_model->SetAttr(model->MutableAttrMap());

// Copy weight info
auto compute_graph = ge::GraphUtils::GetComputeGraph(model->GetGraph());
// ge::Buffer weight;
ge::Buffer weight;
(void)ge::AttrUtils::GetZeroCopyBytes(compute_graph, ge::ATTR_NAME_WEIGHTS_DATA, weight);
ge_model->SetWeight(weight);
// Copy task info
if (model->HasAttr(MODEL_ATTR_TASKS)) {
ge::Buffer task_buffer;
GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::GetZeroCopyBytes(model, MODEL_ATTR_TASKS, task_buffer), FAILED,
"Get bytes failed.");

std::shared_ptr<ModelTaskDef> task = ge::MakeShared<ModelTaskDef>();
GE_CHECK_NOTNULL(task);
GE_IF_BOOL_EXEC(task_buffer.GetData() == nullptr, GELOGE(FAILED, "Get data fail"); return FAILED);
GE_IF_BOOL_EXEC(task_buffer.GetSize() == 0, GELOGE(FAILED, "Get size fail"); return FAILED);

GE_CHK_BOOL_EXEC(ReadProtoFromArray(task_buffer.GetData(), static_cast<int>(task_buffer.GetSize()), task.get()),
return INTERNAL_ERROR, "ReadProtoFromArray failed.");

ge_model->SetModelTaskDef(task);
}
// Copy tbe kernel info
// TBEKernelStore kernel_store;
TBEKernelStore kernel_store;
if (compute_graph != nullptr && compute_graph->GetDirectNodesSize() != 0) {
for (const ge::NodePtr &n : compute_graph->GetDirectNode()) {
auto node_op_desc = n->GetOpDesc();
GE_IF_BOOL_EXEC(node_op_desc == nullptr, continue);
TBEKernelPtr tbe_kernel = node_op_desc->TryGetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, TBEKernelPtr());
GE_IF_BOOL_EXEC(tbe_kernel == nullptr, continue);
kernel_store.AddTBEKernel(tbe_kernel);
GELOGI("Add tbe kernel bin %s", tbe_kernel->GetName().c_str());
}
}
if (!kernel_store.Build()) {
GELOGE(FAILED, "TBE Kernels store build failed!");
return FAILED;
}
ge_model->SetTBEKernelStore(kernel_store);

return SUCCESS;
}

// trasit func for ge_model to Model. will be removed when load and build support ge_model in future
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::TransGeModelToModel(const GeModelPtr &ge_model,
ModelPtr &model) {
if (ge_model == nullptr) {
GELOGE(FAILED, "Ge_model is null");
return FAILED;
}
model = ge::MakeShared<ge::Model>();
GE_CHECK_NOTNULL(model);
model->SetGraph(ge_model->GetGraph());
model->SetName(ge_model->GetName());
model->SetVersion(ge_model->GetVersion());
model->SetPlatformVersion(ge_model->GetPlatformVersion());
model->SetAttr(ge_model->MutableAttrMap());
// Copy weight info
auto compute_graph = ge::GraphUtils::GetComputeGraph(model->GetGraph());
bool ret = ge::AttrUtils::SetZeroCopyBytes(compute_graph, ge::ATTR_NAME_WEIGHTS_DATA, ge_model->GetWeight());
if (!ret) {
GELOGE(FAILED, "Copy weight buffer failed!");
return FAILED;
}
// Copy task info
std::shared_ptr<ModelTaskDef> model_task = ge_model->GetModelTaskDefPtr();

if (model_task != nullptr) {
int size = model_task->ByteSize();
ge::Buffer buffer(static_cast<size_t>(size));
if (buffer.GetSize() == 0) {
GELOGE(MEMALLOC_FAILED, "alloc model attr task buffer failed!");
return MEMALLOC_FAILED;
}
// no need to check value
(void)model_task->SerializePartialToArray(buffer.GetData(), size);
ret = ge::AttrUtils::SetZeroCopyBytes(model, MODEL_ATTR_TASKS, std::move(buffer));
if (!ret) {
GELOGE(FAILED, "Copy task buffer failed!");
return FAILED;
}
}
return SUCCESS;
}

Status ModelHelper::ReleaseLocalModelData() noexcept {
Status result = SUCCESS;
if (model_addr_tmp_ != nullptr) {


+ 17
- 14
src/ge/common/helper/om_file_helper.cc View File

@@ -41,8 +41,9 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::Init(c

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::Init(uint8_t *model_data,
const uint32_t model_data_size) {
if (LoadModelPartitionTable(model_data, model_data_size) != SUCCESS) {
return FAILED;
Status status = LoadModelPartitionTable(model_data, model_data_size);
if (status != SUCCESS) {
return status;
}
is_inited_ = true;
return SUCCESS;
@@ -66,7 +67,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::GetMod
}

if (!found) {
if (type != ModelPartitionType::TBE_KERNELS) {
if (type != ModelPartitionType::TBE_KERNELS && type != ModelPartitionType::WEIGHTS_DATA) {
GELOGE(FAILED, "GetModelPartition:type:%d is not in partition_datas!", static_cast<int>(type));
return FAILED;
}
@@ -83,7 +84,9 @@ Status OmFileLoadHelper::CheckModelValid(const ge::ModelData &model) const {

// Model length too small
if (model.model_len < (sizeof(ModelFileHeader) + sizeof(ModelPartitionTable))) {
GELOGE(PARAM_INVALID, "Invalid model. length < sizeof(ModelFileHeader) + sizeof(ModelPartitionTable).");
GELOGE(PARAM_INVALID,
"Invalid model. length[%u] < sizeof(ModelFileHeader)[%zu] + sizeof(ModelPartitionTable)[%zu].",
model.model_len, sizeof(ModelFileHeader), sizeof(ModelPartitionTable));
return PARAM_INVALID;
}

@@ -93,9 +96,9 @@ Status OmFileLoadHelper::CheckModelValid(const ge::ModelData &model) const {
if ((model_header->length != model.model_len - sizeof(ModelFileHeader)) ||
(MODEL_FILE_MAGIC_NUM != model_header->magic)) {
GELOGE(PARAM_INVALID,
"Invalid model. file_header->length(%u) + sizeof(ModelFileHeader)(%zu) != model->model_len(%u) || "
"MODEL_FILE_MAGIC_NUM != file_header->magic",
model_header->length, sizeof(ModelFileHeader), model.model_len);
"Invalid model. file_header->length[%u] + sizeof(ModelFileHeader)[%zu] != model->model_len[%u] || "
"MODEL_FILE_MAGIC_NUM[%u] != file_header->magic[%u]",
model_header->length, sizeof(ModelFileHeader), model.model_len, MODEL_FILE_MAGIC_NUM, model_header->magic);
return PARAM_INVALID;
}
return SUCCESS;
@@ -112,16 +115,16 @@ Status OmFileLoadHelper::LoadModelPartitionTable(uint8_t *model_data, const uint
// Original model partition include graph-info
if ((partition_table->num != PARTITION_SIZE) && (partition_table->num != (PARTITION_SIZE - 1)) &&
(partition_table->num != 1)) {
GELOGE(PARAM_INVALID, "Invalid partition_table->num:%u", partition_table->num);
return PARAM_INVALID;
GELOGE(GE_EXEC_MODEL_PARTITION_NUM_INVALID, "Invalid partition_table->num:%u", partition_table->num);
return GE_EXEC_MODEL_PARTITION_NUM_INVALID;
}
size_t mem_offset = SIZE_OF_MODEL_PARTITION_TABLE(*partition_table);
GELOGI("ModelPartitionTable num :%u, ModelFileHeader length :%zu, ModelPartitionTable length :%zu",
partition_table->num, sizeof(ModelFileHeader), mem_offset);
if (model_data_size <= mem_offset) {
GELOGE(PARAM_INVALID, "invalid model data, partition_table->num:%u, model data size %u", partition_table->num,
model_data_size);
return PARAM_INVALID;
GELOGE(GE_EXEC_MODEL_DATA_SIZE_INVALID, "invalid model data, partition_table->num:%u, model data size %u",
partition_table->num, model_data_size);
return GE_EXEC_MODEL_DATA_SIZE_INVALID;
}
for (uint32_t i = 0; i < partition_table->num; i++) {
ModelPartition partition;
@@ -131,9 +134,9 @@ Status OmFileLoadHelper::LoadModelPartitionTable(uint8_t *model_data, const uint
context_.partition_datas_.push_back(partition);

if (partition.size > model_data_size || mem_offset > model_data_size - partition.size) {
GELOGE(PARAM_INVALID, "The partition size %zu is greater than the model data size %u.",
GELOGE(GE_EXEC_MODEL_DATA_SIZE_INVALID, "The partition size %zu is greater than the model data size %u.",
partition.size + mem_offset, model_data_size);
return PARAM_INVALID;
return GE_EXEC_MODEL_DATA_SIZE_INVALID;
}
mem_offset += partition.size;
GELOGI("Partition, type:%d, size:%u", static_cast<int>(partition.type), partition.size);


+ 1
- 1
src/ge/common/math/fp16_math.h View File

@@ -92,5 +92,5 @@ fp16_t max(fp16_t fp1, fp16_t fp2);
/// @brief Calculate the minimum fp16_t of fp1 and fp2
/// @return Returns minimum fp16_t of fp1 and fp2
fp16_t min(fp16_t fp1, fp16_t fp2);
}; // namespace ge
} // namespace ge
#endif // GE_COMMON_MATH_FP16_MATH_H_

+ 0
- 2
src/ge/common/math_util.h View File

@@ -27,7 +27,6 @@
#include "mmpa/mmpa_api.h"

namespace ge {

/**
* @ingroup domi_calibration
* @brief Initializes an input array to a specified value
@@ -67,7 +66,6 @@ Status NnSet(const int32_t n, const Dtype alpha, Dtype *output) {
}
return SUCCESS;
}

} // end namespace ge

#endif // GE_COMMON_MATH_UTIL_H_

+ 16
- 13
src/ge/common/model_parser/base.cc View File

@@ -35,15 +35,16 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelParserBase::LoadFro
ge::ModelData &model_data) {
std::string real_path = RealPath(model_path);
if (real_path.empty()) {
GELOGE(PARAM_INVALID, "Model file path '%s' is invalid", model_path);
return PARAM_INVALID;
GELOGE(GE_EXEC_MODEL_PATH_INVALID, "Model file path '%s' is invalid", model_path);
return GE_EXEC_MODEL_PATH_INVALID;
}

GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetFileLength(model_path) == -1, return FAILED, "File size not valid.");
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetFileLength(model_path) == -1, return GE_EXEC_READ_MODEL_FILE_FAILED,
"File size not valid.");

std::ifstream fs(real_path.c_str(), std::ifstream::binary);

GE_CHK_BOOL_RET_STATUS(fs.is_open(), FAILED, "Open file failed! path:%s", model_path);
GE_CHK_BOOL_RET_STATUS(fs.is_open(), GE_EXEC_READ_MODEL_FILE_FAILED, "Open file failed! path:%s", model_path);

// get length of file:
(void)fs.seekg(0, std::ifstream::end);
@@ -55,7 +56,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelParserBase::LoadFro

char *data = new (std::nothrow) char[len];
if (data == nullptr) {
GELOGE(MEMALLOC_FAILED, "Load model From file failed, bad memory allocation occur. (need:%ld)", len);
GELOGE(MEMALLOC_FAILED, "Load model From file failed, bad memory allocation occur. (need:%u)", len);
return MEMALLOC_FAILED;
}

@@ -79,31 +80,33 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelParserBase::ParseMo
GE_CHECK_NOTNULL(model.model_data);

// Model length too small
GE_CHK_BOOL_RET_STATUS(model.model_len >= sizeof(ModelFileHeader), PARAM_INVALID,
"Invalid model. length < sizeof(ModelFileHeader).");
GE_CHK_BOOL_RET_STATUS(model.model_len >= sizeof(ModelFileHeader), GE_EXEC_MODEL_DATA_SIZE_INVALID,
"Invalid model. Model data size %u must be greater than or equal to %zu.", model.model_len,
sizeof(ModelFileHeader));
// Get file header
auto file_header = reinterpret_cast<ModelFileHeader *>(model.model_data);
// Determine whether the file length and magic number match
GE_CHK_BOOL_RET_STATUS(
file_header->length == model.model_len - sizeof(ModelFileHeader) && file_header->magic == MODEL_FILE_MAGIC_NUM,
PARAM_INVALID,
"Invalid model. file_header->length + sizeof(ModelFileHeader) != model->model_len || MODEL_FILE_MAGIC_NUM != "
"file_header->magic");
GE_EXEC_MODEL_DATA_SIZE_INVALID,
"Invalid model. file_header->length[%u] + sizeof(ModelFileHeader)[%zu] != model->model_len[%u] || "
"MODEL_FILE_MAGIC_NUM[%u] != file_header->magic[%u]",
file_header->length, sizeof(ModelFileHeader), model.model_len, MODEL_FILE_MAGIC_NUM, file_header->magic);

Status res = SUCCESS;

// Get data address
uint8_t *data = reinterpret_cast<uint8_t *>(model.model_data) + sizeof(ModelFileHeader);
if (file_header->is_encrypt == ModelEncryptType::UNENCRYPTED) { // Unencrypted model
GE_CHK_BOOL_RET_STATUS(model.key.empty(), PARAM_INVALID,
GE_CHK_BOOL_RET_STATUS(model.key.empty(), GE_EXEC_MODEL_NOT_SUPPORT_ENCRYPTION,
"Invalid param. model is unencrypted, but key is not empty.");

model_data = data;
model_len = file_header->length;
GELOGI("Model_len is %u, model_file_head_len is %zu.", model_len, sizeof(ModelFileHeader));
} else {
GELOGE(PARAM_INVALID, "Invalid model. ModelEncryptType not supported.");
res = PARAM_INVALID;
GELOGE(GE_EXEC_MODEL_NOT_SUPPORT_ENCRYPTION, "Invalid model. ModelEncryptType not supported.");
res = GE_EXEC_MODEL_NOT_SUPPORT_ENCRYPTION;
}

return res;


+ 501
- 0
src/ge/common/model_parser/graph_parser_util.cc View File

@@ -0,0 +1,501 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "graph_parser_util.h"
#include <memory>
#include "common/auth/file_saver.h"
#include "common/convert/pb2json.h"
#include "common/debug/log.h"
#include "common/debug/memory_dumper.h"
#include "common/model_parser/base.h"
#include "common/model_saver.h"
#include "common/properties_manager.h"
#include "common/string_util.h"
#include "common/types.h"
#include "common/util.h"
#include "common/util/error_manager/error_manager.h"
#include "external/register/register_types.h"
#include "framework/common/debug/ge_log.h"
#include "framework/omg/parser/parser_inner_ctx.h"
#include "graph/compute_graph.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/optimize/common/params.h"
#include "graph/utils/type_utils.h"
#include "omg/omg_inner_types.h"
#include "omg/parser/model_parser.h"
#include "omg/parser/parser_factory.h"
#include "omg/parser/weights_parser.h"
#include "parser/common/pre_checker.h"
#include "proto/ge_ir.pb.h"
#include "register/op_registry.h"

namespace ge {
namespace {
// The function is incomplete. Currently, only l2_optimize and off_optimize are supported.
const char *const kInputShapeSample1 = "\"input_name1:n1,c1,h1,w1\"";
const char *const kInputShapeSample2 = "\"input_name1:1,3,224,224\"";
const char *const kSplitError1 = "size not equal to 2 split by \":\"";
const char *const kEmptyError = "can not be empty";
const char *const kFloatNumError = "exist float number";
const char *const kDigitError = "is not digit";
const char *const kOutputTypeSample = "correct sample is \"opname:index:dtype\"";
const char *const kOutputTypeSupport = "only support FP32, FP16, UINT8";
const char *const kOutputTypeError = "The multiple out nodes set in output_type must be found in out_nodes.";

// Split a "name:shape" entry at its LAST ':' into {name, shape}.
// Returns an empty vector when the input contains no ':'.
vector<string> SplitInputShape(const std::string &input_shape) {
  vector<string> shape_pair_vec;
  const size_t colon_pos = input_shape.rfind(':');
  if (colon_pos == std::string::npos) {
    return shape_pair_vec;
  }
  shape_pair_vec.emplace_back(input_shape.substr(0, colon_pos));
  shape_pair_vec.emplace_back(input_shape.substr(colon_pos + 1));
  return shape_pair_vec;
}

// Data types accepted by --output_type on the command line, mapped to the
// corresponding graph DataType values.
static std::map<std::string, ge::DataType> output_type_str_to_datatype = {
  {"FP32", ge::DT_FLOAT}, {"FP16", ge::DT_FLOAT16}, {"UINT8", ge::DT_UINT8}};

// Validate that |s| is exactly "true" or "false"; any other value is
// reported via ATC error E10033 and logged, and the check fails.
static bool CheckInputTrueOrFalse(const std::string &s, const std::string &atc_param) {
  if ((s != "true") && (s != "false")) {
    ErrorManager::GetInstance().ATCReportErrMessage("E10033", {"parameter", "value"}, {atc_param, s});
    GELOGE(PARAM_INVALID, "Input parameter[--%s]'s value[%s] must be true or false.", atc_param.c_str(), s.c_str());
    return false;
  }
  return true;
}

// Check that every character of |str| is a decimal digit.
// NOTE(review): an empty string passes this check; callers (StringToInt)
// rely on the subsequent stoi() throwing std::invalid_argument to reject it
// — confirm before tightening.
bool CheckDigitStr(std::string &str) {
  for (char c : str) {
    // Cast to unsigned char: passing a negative char value to isdigit() is
    // undefined behavior.
    if (!isdigit(static_cast<unsigned char>(c))) {
      GELOGE(domi::FAILED, "value[%s] is not positive integer", str.c_str());
      return false;
    }
  }
  return true;
}

// Parse a decimal digit string into an int32 value.
// On failure, emits an ATC error — E10001 (not a positive integer),
// E10014 (stoi invalid_argument) or E10013 (stoi out_of_range) — and
// returns PARAM_INVALID.
Status StringToInt(std::string &str, int32_t &value) {
  try {
    // Pre-check digits so a clearer E10001 is reported before stoi throws.
    if (!CheckDigitStr(str)) {
      GELOGE(PARAM_INVALID, "Invalid of digit string: %s ", str.c_str());
      ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
                                                      {"--output_type", str, "is not positive integer"});
      return PARAM_INVALID;
    }
    value = stoi(str);
  } catch (std::invalid_argument &) {
    GELOGE(PARAM_INVALID, "Invalid of digit string: %s, catch invalid_argument.", str.c_str());
    ErrorManager::GetInstance().ATCReportErrMessage("E10014", {"parameter", "value"}, {"output_type", str});
    return PARAM_INVALID;
  } catch (std::out_of_range &) {
    GELOGE(PARAM_INVALID, "Invalid of digit string: %s, catch out_of_range.", str.c_str());
    ErrorManager::GetInstance().ATCReportErrMessage("E10013", {"parameter", "value"}, {"output_type", str});
    return PARAM_INVALID;
  }
  return SUCCESS;
}

// Ensure every "node:index" entry collected from --output_type also appears
// in the user-specified --out_nodes list; the first mismatch is reported via
// ATC error E10001 and fails the verification.
Status VerifyOutputTypeAndOutNodes(std::vector<std::string> &out_type_vec) {
  std::vector<std::pair<std::string, int32_t>> user_out_nodes = domi::GetContext().user_out_nodes;
  std::set<std::string> out_nodes_info;
  for (const auto &user_node : user_out_nodes) {
    // Normalize each out_nodes entry to "name:index" for matching.
    out_nodes_info.emplace(user_node.first + ":" + to_string(user_node.second));
  }
  for (const auto &out_type : out_type_vec) {
    if (out_nodes_info.count(out_type) == 0) {
      ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
                                                      {"--output_type", out_type, kOutputTypeError});
      GELOGE(domi::FAILED, "Invalid value for --output_type[%s], %s.", out_type.c_str(), kOutputTypeError);
      return domi::FAILED;
    }
  }
  return domi::SUCCESS;
}

// Parse the --output_type option. Two forms are accepted:
//   * a bare dtype string ("FP32"/"FP16"/"UINT8") meaning all output nodes;
//   * a ';'-separated list of "opname:index:dtype" triples.
// For the triple form, fills out_type_index_map (node name -> output
// indices) and out_type_dt_map (node name -> data types), then verifies each
// "node:index" pair is also present in --out_nodes. Parse failures are
// reported via ATC error E10001 and return domi::FAILED.
Status ParseOutputType(const std::string &output_type, std::map<std::string, vector<uint32_t>> &out_type_index_map,
                       std::map<std::string, vector<ge::DataType>> &out_type_dt_map) {
  // No ':' means the bare-dtype form applying to every out node.
  if (output_type.find(':') == std::string::npos) {
    GELOGI("output_type is not multiple nodes, means all out nodes");
    auto it = output_type_str_to_datatype.find(output_type);
    if (it == output_type_str_to_datatype.end()) {
      ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
                                                      {"--output_type", output_type, kOutputTypeSupport});
      GELOGE(PARAM_INVALID, "Invalid value for --output_type[%s], %s.", output_type.c_str(), kOutputTypeSupport);
      return domi::FAILED;
    }
    return domi::SUCCESS;
  }
  std::vector<std::string> out_type_vec;
  vector<string> nodes_v = StringUtils::Split(output_type, ';');
  for (const string &node : nodes_v) {
    vector<string> node_index_type_v = StringUtils::Split(node, ':');
    if (node_index_type_v.size() != 3) {  // The size must be 3: opname:index:dtype.
      ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
                                                      {"--output_type", node, kOutputTypeSample});
      GELOGE(PARAM_INVALID, "Invalid value for --output_type[%s], %s.", node.c_str(), kOutputTypeSample);
      return domi::FAILED;
    }
    ge::DataType tmp_dt;
    std::string node_name = StringUtils::Trim(node_index_type_v[0]);
    std::string index_str = StringUtils::Trim(node_index_type_v[1]);
    int32_t index;
    if (StringToInt(index_str, index) != SUCCESS) {
      GELOGE(PARAM_INVALID, "This str must be digit string, while the actual input is %s.", index_str.c_str());
      return domi::FAILED;
    }
    std::string dt_value = StringUtils::Trim(node_index_type_v[2]);
    auto it = output_type_str_to_datatype.find(dt_value);
    if (it == output_type_str_to_datatype.end()) {
      ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
                                                      {"--output_type", dt_value, kOutputTypeSupport});
      GELOGE(ge::PARAM_INVALID, "Invalid value for --output_type[%s], %s.", dt_value.c_str(), kOutputTypeSupport);
      return domi::FAILED;
    } else {
      tmp_dt = it->second;
    }
    out_type_vec.push_back(node_name + ":" + index_str);
    // A node may be listed multiple times; indices accumulate per node name.
    auto it_index = out_type_index_map.find(node_name);
    if (it_index == out_type_index_map.end()) {
      vector<uint32_t> tmp_vec;
      tmp_vec.push_back(index);
      out_type_index_map.emplace(node_name, tmp_vec);
    } else {
      it_index->second.push_back(index);
    }

    // Data types accumulate in the same order as the indices above.
    auto it_dt = out_type_dt_map.find(node_name);
    if (it_dt == out_type_dt_map.end()) {
      vector<ge::DataType> tmp_vec;
      tmp_vec.push_back(tmp_dt);
      out_type_dt_map.emplace(node_name, tmp_vec);
    } else {
      it_dt->second.push_back(tmp_dt);
    }
  }
  return VerifyOutputTypeAndOutNodes(out_type_vec);
}

// Validate that |index| is a legal output index for |op_desc|
// (0 <= index < number of outputs); otherwise log and report ATC error
// E10003 and fail.
Status CheckOutNode(ge::OpDescPtr op_desc, int32_t index) {
  const int32_t out_size = op_desc->GetOutputsSize();
  if (index >= 0 && index < out_size) {
    return domi::SUCCESS;
  }
  GELOGE(domi::FAILED,
         "out_node [%s] output index:%d must be smaller "
         "than node output size:%d and can not be negative!",
         op_desc->GetName().c_str(), index, out_size);
  std::string fail_reason = "output index:" + to_string(index) +
                            " must be smaller than output size:" + to_string(out_size) + " and can not be negative!";
  ErrorManager::GetInstance().ATCReportErrMessage("E10003", {"parameter", "value", "reason"},
                                                  {"out_nodes", op_desc->GetName(), fail_reason});
  return domi::FAILED;
}

// Collect (node, output index) pairs contributed by leaf node |node|.
// An ordinary node contributes one pair per output; a NETOUTPUT node instead
// contributes the producing node/anchor behind each of its input anchors.
Status GetOutputLeaf(NodePtr node, std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info) {
  ge::OpDescPtr tmpDescPtr = node->GetOpDesc();
  if (tmpDescPtr == nullptr) {
    GELOGE(domi::FAILED, "Get outnode op desc fail.");
    return domi::FAILED;
  }
  size_t size = tmpDescPtr->GetOutputsSize();
  if (node->GetType() != NETOUTPUT) {
    for (size_t index = 0; index < size; ++index) {
      output_nodes_info.push_back(std::make_pair(node, index));
    }
  } else {
    // NETOUTPUT: walk each input anchor back to its peer output anchor.
    const auto in_anchors = node->GetAllInDataAnchors();
    for (auto in_anchor : in_anchors) {
      auto out_anchor = in_anchor->GetPeerOutAnchor();
      if (out_anchor == nullptr) {
        GELOGE(domi::FAILED, "Get leaf node op desc fail.");
        return domi::FAILED;
      }
      auto out_node = out_anchor->GetOwnerNode();
      output_nodes_info.push_back(std::make_pair(out_node, out_anchor->GetIdx()));
    }
  }
  return SUCCESS;
}

// Build the "name:index[:top_name]" label for every output node.
// TF graphs record no top names, so entries are "name:index"; caffe graphs
// append the matching top name recorded in the parser context (falling back
// to "name:index" with a warning when a top name is missing).
void GetOutputNodesNameAndIndex(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info,
                                std::vector<std::string> &output_nodes_name) {
  output_nodes_name.clear();
  if (domi::GetContext().out_top_names.empty()) {
    // tf process, no top name.
    // Iterate by const reference: copying each pair<NodePtr, int32_t> would
    // churn the shared_ptr refcount on every iteration.
    for (const auto &output_node_info : output_nodes_info) {
      std::string node_name = output_node_info.first->GetName();
      int32_t index = output_node_info.second;
      output_nodes_name.push_back(node_name + ":" + std::to_string(index));
    }
    return;
  }
  // caffe process, need add top name after node_name:index
  for (size_t i = 0; i < output_nodes_info.size(); ++i) {
    std::string node_name = output_nodes_info[i].first->GetName();
    int32_t index = output_nodes_info[i].second;
    if (i < domi::GetContext().out_top_names.size()) {
      output_nodes_name.push_back(node_name + ":" + std::to_string(index) + ":" + domi::GetContext().out_top_names[i]);
    } else {
      GELOGW("Get top name of node [%s] fail.", node_name.c_str());
      output_nodes_name.push_back(node_name + ":" + std::to_string(index));
    }
  }
}
} // namespace

// Parse the comma-separated true/false list from --is_output_adjust_hw_layout
// and fill the parser context's output_formats: "true" -> NC1HWC0,
// "false" -> ND. Any other token fails with PARAM_INVALID. An empty option
// string is a no-op.
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ParseOutputFp16NodesFormat(const string &is_output_fp16) {
  if (is_output_fp16.empty()) {
    return SUCCESS;
  }

  vector<domiTensorFormat_t> &output_formats = domi::GetContext().output_formats;
  output_formats.clear();
  vector<string> node_format_vec = StringUtils::Split(is_output_fp16, ',');
  for (auto &is_fp16 : node_format_vec) {
    // NOTE(review): the return value of Trim() is discarded here — if Trim
    // returns the trimmed copy rather than trimming in place, this line has
    // no effect; verify against StringUtils' contract.
    StringUtils::Trim(is_fp16);
    if (!CheckInputTrueOrFalse(is_fp16, "is_output_adjust_hw_layout")) {
      GELOGE(PARAM_INVALID, "Invalid Param, is_output_adjust_hw_layout only support true/false: but is [%s]",
             is_output_fp16.c_str());
      return PARAM_INVALID;
    }
    if (is_fp16 == "false") {
      output_formats.push_back(DOMI_TENSOR_ND);
    } else if (is_fp16 == "true") {
      output_formats.push_back(domi::DOMI_TENSOR_NC1HWC0);
    }
  }
  return SUCCESS;
}

// Record the graph's output nodes in the compute graph and parser context.
// Outputs come from the user's --out_nodes list when given (validated, with
// optional NC1HWC0 format and per-output dtype attributes applied from
// --output_type); otherwise every leaf node (has data inputs, no data
// outputs) is collected via GetOutputLeaf.
// NOTE(review): parameter |output| is never referenced in this function —
// confirm whether it is kept only for interface compatibility.
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status SetOutputNodeInfo(ge::Graph &graph,
                                                                          const std::string &output_type,
                                                                          const std::string &output) {
  ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph);
  GE_CHECK_NOTNULL(compute_graph);

  std::vector<std::pair<std::string, int32_t>> user_out_nodes = domi::GetContext().user_out_nodes;
  std::vector<domiTensorFormat_t> output_formats = domi::GetContext().output_formats;
  std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes_info;
  std::vector<std::string> output_nodes_name;
  std::map<std::string, vector<uint32_t>> out_type_index_map;
  std::map<std::string, vector<ge::DataType>> out_type_dt_map;
  if (!output_type.empty()) {
    if (ParseOutputType(output_type, out_type_index_map, out_type_dt_map) != SUCCESS) {
      GELOGE(domi::FAILED, "Parse output_type failed.");
      return domi::FAILED;
    }
  }

  // User declared outputs
  for (uint32_t i = 0; i < user_out_nodes.size(); ++i) {
    ge::NodePtr out_node = compute_graph->FindNode(user_out_nodes[i].first);
    if (out_node == nullptr) {
      GELOGE(domi::FAILED, "Can not find src node (%s) in graph.", user_out_nodes[i].first.c_str());
      return domi::FAILED;
    }
    auto op_desc = out_node->GetOpDesc();
    GE_CHECK_NOTNULL(op_desc);
    if (CheckOutNode(op_desc, user_out_nodes[i].second) != SUCCESS) {
      GELOGE(domi::FAILED, "Check out node (%s) fail.", user_out_nodes[i].first.c_str());
      return domi::FAILED;
    }
    // output_formats entries align positionally with user_out_nodes entries.
    if (i < output_formats.size()) {
      if (output_formats[i] == domi::DOMI_TENSOR_NC1HWC0) {
        GELOGI("The output node [%s] should be set NC1HWC0", user_out_nodes[i].first.c_str());
        if (!ge::AttrUtils::SetBool(op_desc, "output_set_fp16_nc1hwc0", true)) {
          GELOGW("The output node [%s] set NC1HWC0 failed", user_out_nodes[i].first.c_str());
        }
      }
    }
    // Apply --output_type dtypes/indices as node attributes when present.
    auto it_index = out_type_index_map.find(user_out_nodes[i].first);
    auto it_dt = out_type_dt_map.find(user_out_nodes[i].first);
    if ((it_index != out_type_index_map.end()) && (it_dt != out_type_dt_map.end())) {
      GELOGI("The output node [%s] need to be set output_type", user_out_nodes[i].first.c_str());
      (void)ge::AttrUtils::SetListDataType(op_desc, "_output_dt_list", it_dt->second);
      (void)ge::AttrUtils::SetListInt(op_desc, "_output_dt_index", it_index->second);
    }
    output_nodes_info.push_back(std::make_pair(out_node, user_out_nodes[i].second));
  }
  // default output node (leaf)
  if (user_out_nodes.empty()) {
    for (ge::NodePtr node : compute_graph->GetDirectNode()) {
      if (!node->GetInDataNodes().empty() && node->GetOutDataNodes().empty()) {
        Status ret = GetOutputLeaf(node, output_nodes_info);
        GE_CHK_BOOL_RET_STATUS(ret == SUCCESS, ret, "find leaf fail.");
      }
    }
  }
  GetOutputNodesNameAndIndex(output_nodes_info, output_nodes_name);
  compute_graph->SetGraphOutNodesInfo(output_nodes_info);
  domi::GetContext().net_out_nodes = output_nodes_name;
  return domi::SUCCESS;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ParseInputShape(
const string &input_shape, unordered_map<string, vector<int64_t>> &shape_map,
vector<pair<string, vector<int64_t>>> &user_shape_map, bool is_dynamic_input) {
vector<string> shape_vec = StringUtils::Split(input_shape, ';');
const int DEFAULT_SHAPE_PAIR_SIZE = 2;
for (const auto &shape : shape_vec) {
vector<string> shape_pair_vec = SplitInputShape(shape);
if (shape_pair_vec.size() != DEFAULT_SHAPE_PAIR_SIZE) {
ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"},
{shape, kSplitError1, kInputShapeSample1});
GELOGW("Parse input parameter [--input_shape]'s shape[%s] failed, reason: %s, correct sample is %s.",
shape.c_str(), kSplitError1, kInputShapeSample1);
return false;
}
if (shape_pair_vec[1].empty()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"},
{shape, kEmptyError, kInputShapeSample1});
GELOGW("Parse input parameter [--input_shape]'s shape[%s] failed, reason: %s, correct sample is %s.",
shape.c_str(), kEmptyError, kInputShapeSample1);
return false;
}

vector<string> shape_value_strs = StringUtils::Split(shape_pair_vec[1], ',');
vector<int64_t> shape_values;
for (auto &shape_value_str : shape_value_strs) {
// stoul: The method may throw an exception: invalid_argument/out_of_range
if (std::string::npos != shape_value_str.find('.')) {
ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"},
{shape, kFloatNumError, kInputShapeSample2});
GELOGW("Parse input parameter [--input_shape]'s shape[%s] failed, reason: %s, correct sample is %s.",
shape.c_str(), kFloatNumError, kInputShapeSample2);
return false;
}

long left_result = 0;
try {
left_result = stol(StringUtils::Trim(shape_value_str));
if (!shape_value_str.empty() && (shape_value_str.front() == '-')) {
// The value maybe dynamic shape [-1], need substr it and verify isdigit.
shape_value_str = shape_value_str.substr(1);
}
for (char c : shape_value_str) {
if (!isdigit(c)) {
ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"},
{shape, kDigitError, kInputShapeSample2});
GELOGE(PARAM_INVALID, "--input_shape's shape value[%s] is not digit", shape_value_str.c_str());
return false;
}
}
} catch (const std::out_of_range &) {
ErrorManager::GetInstance().ATCReportErrMessage("E10013", {"parameter", "value"},
{"input_shape", shape_value_str});
GELOGW("Input parameter[--input_shape]’s value[%s] cause out of range execption!", shape_value_str.c_str());
return false;
} catch (const std::invalid_argument &) {
ErrorManager::GetInstance().ATCReportErrMessage("E10014", {"parameter", "value"},
{"input_shape", shape_value_str});
GELOGW("Input parameter[--input_shape]’s value[%s] cause invalid argument!", shape_value_str.c_str());
return false;
} catch (...) {
ErrorManager::GetInstance().ATCReportErrMessage("E10015", {"parameter", "value"},
{"input_shape", shape_value_str});
GELOGW("Input parameter[--input_shape]’s value[%s] cause unkown execption!", shape_value_str.c_str());
return false;
}
int64_t result = left_result;
// - 1 is not currently supported
if (!is_dynamic_input && result <= 0) {
ErrorManager::GetInstance().ATCReportErrMessage("E10011", {"shape", "result"}, {shape, std::to_string(result)});
GELOGW(
"Input parameter[--input_shape]’s shape value[%s] is invalid, "
"expect positive integer, but value is %ld.",
shape.c_str(), result);
return false;
}
shape_values.push_back(result);
}

shape_map.emplace(make_pair(StringUtils::Trim(shape_pair_vec[0]), shape_values));
user_shape_map.push_back(make_pair(StringUtils::Trim(shape_pair_vec[0]), shape_values));
}

return true;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ParseOutputNodes(const string &out_nodes) {
try {
// parse output node
if (!out_nodes.empty()) {
domi::GetContext().out_nodes_map.clear();
domi::GetContext().user_out_nodes.clear();

vector<string> nodes_v = StringUtils::Split(out_nodes, ';');
for (const string &node : nodes_v) {
vector<string> key_value_v = StringUtils::Split(node, ':');
if (key_value_v.size() != 2) { // The size must be 2.
ErrorManager::GetInstance().ATCReportErrMessage(
"E10001", {"parameter", "value", "reason"},
{"--out_nodes", node, "the correct format is \"node_name1:0;node_name1:1;node_name2:0\""});
GELOGE(PARAM_INVALID,
"The input format of --out_nodes is invalid, the correct format is "
"\"node_name1:0;node_name1:1;node_name2:0\", while the actual input is %s.",
node.c_str());
return PARAM_INVALID;
}
auto iter = domi::GetContext().out_nodes_map.find(key_value_v[0]);
// stoi: The method may throw an exception: invalid_argument/out_of_range
if (!CheckDigitStr(key_value_v[1])) {
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
{"--out_nodes", out_nodes, "is not positive integer"});
GELOGE(PARAM_INVALID, "This str must be digit string, while the actual input is %s", out_nodes.c_str());
return PARAM_INVALID;
}
int32_t index = stoi(StringUtils::Trim(key_value_v[1]));
if (iter != domi::GetContext().out_nodes_map.end()) {
iter->second.emplace_back(index);
} else {
std::vector<int32_t> index_v;
index_v.emplace_back(index);
domi::GetContext().out_nodes_map.emplace(key_value_v[0], index_v);
}
domi::GetContext().user_out_nodes.push_back(std::make_pair(key_value_v[0], index));
}
}
} catch (std::invalid_argument &) {
GELOGE(PARAM_INVALID, "Invalid of out_nodes: %s ", out_nodes.c_str());
ErrorManager::GetInstance().ATCReportErrMessage("E10014", {"parameter", "value"}, {"out_nodes", out_nodes});
return PARAM_INVALID;
} catch (std::out_of_range &) {
GELOGE(PARAM_INVALID, "Invalid of out_nodes: %s ", out_nodes.c_str());
ErrorManager::GetInstance().ATCReportErrMessage("E10013", {"parameter", "value"}, {"out_nodes", out_nodes});
return PARAM_INVALID;
}
return SUCCESS;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ParseOpConf(const char *op_conf) {
if (op_conf != nullptr && *op_conf != '\0') {
// divided by ":"
PropertiesManager::Instance().SetPropertyDelimiter(OP_CONF_DELIMITER);
// Parsing the op_conf configuration item file
if (!PropertiesManager::Instance().Init(op_conf)) {
GELOGE(FAILED, "op_name_map init failed!");
return FAILED;
}
// Return map and put it into ATC global variable
domi::GetContext().op_conf_map = PropertiesManager::Instance().GetPropertyMap();
}
return SUCCESS;
}
} // namespace ge

Some files were not shown because too many files changed in this diff

Loading…
Cancel
Save