@@ -16,6 +16,7 @@ | |||
cmake_minimum_required(VERSION 3.14) | |||
project (GraphEngine[CXX]) | |||
set(CMAKE_CXX_STANDARD 14) | |||
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0) | |||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}) | |||
set(GE_SOURCE_DIR ${CMAKE_CURRENT_LIST_DIR}) | |||
@@ -71,6 +72,7 @@ elseif(DEFINED ENV{D_LINK_PATH}) | |||
find_library(register libregister.so ${GE_LIB_PATH}) | |||
find_library(hccl libhccl.so ${GE_LIB_PATH}) | |||
find_library(resource libresource.so ${GE_LIB_PATH}) | |||
find_library(error_manager liberror_manager.so ${GE_LIB_PATH}) | |||
else() | |||
# Ascend mode | |||
if(DEFINED ENV{ASCEND_CUSTOM_PATH}) | |||
@@ -88,6 +90,7 @@ else() | |||
find_library(runtime libruntime.so ${ASCEND_RUNTIME_DIR}) | |||
find_library(register libregister.so ${ASCEND_RUNTIME_DIR}) | |||
find_library(resource libresource.so ${ASCEND_RUNTIME_DIR}) | |||
find_library(error_manager liberror_manager.so ${ASCEND_RUNTIME_DIR}) | |||
endif() | |||
# add compile flags | |||
@@ -44,6 +44,9 @@ class GraphOptimizer { | |||
// optimize original graph, using in graph preparation stage | |||
virtual Status OptimizeOriginalGraph(ComputeGraph &graph) = 0; | |||
// optimize original graph, using for conversion operator insert in graph preparation stage | |||
virtual Status OptimizeOriginalGraphJudgeInsert(ComputeGraph &graph) { return SUCCESS; } | |||
// optimize fused graph | |||
virtual Status OptimizeFusedGraph(ComputeGraph &graph) = 0; | |||
@@ -0,0 +1,36 @@ | |||
/** | |||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#ifndef COMPRESS_H | |||
#define COMPRESS_H | |||
#include <uchar.h> | |||
enum CmpStatus { RET_SUCCESS = 0, RET_ERROR = -1 }; | |||
struct CompressConfig { | |||
size_t inputSize; // length of data to compress | |||
size_t engineNum; // how many decompress engines | |||
size_t maxRatio; // how much size of a basic compression block, only 64 supported now (8x: 64 4x: 32) | |||
size_t channel; // channels of L2 or DDR. For load balance | |||
size_t fractalSize; // size of compressing block | |||
bool isTight; // whether compose compressed data tightly | |||
}; | |||
CmpStatus CompressWeights(char* input, const CompressConfig& compressConfig, char* indexs, char* output, | |||
size_t& compressedLength); | |||
#endif // COMPRESS_H |
@@ -0,0 +1,83 @@ | |||
/** | |||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#ifndef ERROR_MANAGER_H_ | |||
#define ERROR_MANAGER_H_ | |||
#include <map> | |||
#include <string> | |||
#include <vector> | |||
class ErrorManager { | |||
public: | |||
/// | |||
/// @brief Obtain ErrorManager instance | |||
/// @return ErrorManager instance | |||
/// | |||
static ErrorManager &GetInstance(); | |||
/// | |||
/// @brief init | |||
/// @param [in] path current so path | |||
/// @return int 0(success) -1(fail) | |||
/// | |||
int Init(std::string path); | |||
/// | |||
/// @brief Report error message | |||
/// @param [in] errCode error code | |||
/// @param [in] mapArgs parameter map | |||
/// @return int 0(success) -1(fail) | |||
/// | |||
int ReportErrMessage(std::string error_code, const std::map<std::string, std::string> &args_map); | |||
/// @brief output error message | |||
/// @param [in] handle print handle | |||
/// @return int 0(success) -1(fail) | |||
/// | |||
int OutputErrMessage(int handle); | |||
/// @brief Report error message | |||
/// @param [in] vector parameter key, vector parameter value | |||
/// | |||
void ATCReportErrMessage(std::string error_code, const std::vector<std::string> &key = {}, | |||
const std::vector<std::string> &value = {}); | |||
private: | |||
struct ErrorInfo { | |||
std::string error_id; | |||
std::string error_message; | |||
std::vector<std::string> arglist; | |||
}; | |||
ErrorManager() {} | |||
~ErrorManager() {} | |||
ErrorManager(const ErrorManager &) = delete; | |||
ErrorManager(ErrorManager &&) = delete; | |||
ErrorManager &operator=(const ErrorManager &) = delete; | |||
ErrorManager &operator=(ErrorManager &&) = delete; | |||
int ParseJsonFile(std::string path); | |||
int ReadJsonFile(const std::string &file_path, void *handle); | |||
bool is_init_ = false; | |||
std::map<std::string, ErrorInfo> error_map_; | |||
std::vector<std::string> error_message_evc_; | |||
}; | |||
#endif // ERROR_MANAGER_H_ |
@@ -65,6 +65,8 @@ class PlatformInfoManager { | |||
void ParseUBOfAICoreSpec(map<string, string> &aiCoreSpecMap, PlatformInfo &platformInfoTemp); | |||
void ParseUnzipOfAICoreSpec(map<string, string> &aiCoreSpecMap, PlatformInfo &platformInfoTemp); | |||
void ParseAICoreSpec(map<string, string> &aiCoreSpecMap, PlatformInfo &platformInfoTemp); | |||
void ParseBufferOfAICoreMemoryRates(map<string, string> &aiCoreMemoryRatesMap, PlatformInfo &platformInfoTemp); | |||
@@ -65,6 +65,10 @@ typedef struct tagAiCoreSpec { | |||
uint64_t ubbankNum; | |||
uint64_t ubburstInOneBlock; | |||
uint64_t ubbankGroupNum; | |||
uint32_t unzipEngines; | |||
uint32_t unzipMaxRatios; | |||
uint32_t unzipChannels; | |||
uint8_t unzipIsTight; | |||
} AiCoreSpec; | |||
typedef struct tagAiCoreMemoryRates { | |||
@@ -82,14 +82,12 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Session { | |||
/// @brief run graph in the session with specific session id asynchronously | |||
/// @param [in] graphId: graph id | |||
/// @param [in] inputs: input data | |||
/// @param [out] outputs: output data | |||
/// @param [out] callback: callback while runing graph has been finished. | |||
/// The callback function will not be checked. | |||
/// Please ensure that the implementation of the function is trusted. | |||
/// @return Status result of function | |||
/// | |||
Status RunGraphAsync(uint32_t graphId, const std::vector<ge::TensorInfo> &inputs, | |||
std::vector<ge::TensorInfo> &outputs, std::function<void(Status)> callback); | |||
Status RunGraphAsync(uint32_t graphId, const std::vector<ge::InputTensorInfo> &inputs, RunAsyncCallback callback); | |||
/// | |||
/// @ingroup ge_graph | |||
@@ -21,6 +21,8 @@ | |||
#include <string> | |||
#include <vector> | |||
#include <set> | |||
#include <functional> | |||
#include <memory> | |||
namespace ge { | |||
// Option key: graph run mode | |||
@@ -40,6 +42,12 @@ const char *const OPTION_EXEC_EXTERN_PLUGIN_PATH = "ge.soLoadPath"; | |||
const char *const OPTION_EXEC_ENABLE_DUMP = "ge.exec.enableDump"; | |||
const char *const OPTION_EXEC_DUMP_PATH = "ge.exec.dumpPath"; | |||
const char *const OPTION_EXEC_DUMP_STEP = "ge.exec.dumpStep"; | |||
const char *const OPTION_EXEC_DUMP_MODE = "ge.exec.dumpMode"; | |||
const char *const OPTION_EXEC_ENABLE_INCRE_BUILD = "ge.exec.enableIncreBuild"; | |||
const char *const OPTION_EXEC_INCRE_BUILD_CACHE_PATH = "ge.exec.increBuildCachePath"; | |||
// profiling flag | |||
const char *const OPTION_EXEC_PROFILING_MODE = "ge.exec.profilingMode"; | |||
const char *const OPTION_EXEC_PROFILING_OPTIONS = "ge.exec.profilingOptions"; | |||
// Hccl flag, if ge.exec.hcclFlag =1, it means load plugin for opskernel, else:ge.exec.hcclFlag =0 | |||
const char *const OPTION_EXEC_HCCL_FLAG = "ge.exec.hcclFlag"; | |||
const char *const OPTION_EXEC_ATOMIC_FLAG = "ge.exec.enable_atomic"; | |||
@@ -173,6 +181,9 @@ const std::string AICORE_NUM = "ge.aicoreNum"; | |||
// Configure L1FUSION | |||
const std::string L1_FUSION = "ge.l1Fusion"; | |||
// Configure l1,l2,and others optimize option | |||
const std::string BUFFER_OPTIMIZE = "ge.bufferOptimize"; | |||
// Configure Small Channel flag | |||
const std::string ENABLE_SMALL_CHANNEL = "ge.enableSmallChannel"; | |||
@@ -188,6 +199,9 @@ const std::string SAVE_ORIGINAL_MODEL = "ge.saveOriginalModel"; | |||
// Save original model file name | |||
const std::string ORIGINAL_MODEL_FILE = "ge.originalModelFile"; | |||
// FE enable quant optimize | |||
const std::string QUANT_OPTIMIZE = "ge.quantOptimize"; | |||
const char *const OPTION_GE_MAX_DUMP_FILE_NUM = "ge.maxDumpFileNum"; | |||
const char *const OPTION_GE_MAX_DUMP_FILE_SIZE = "ge.maxDumpFileSize"; | |||
const char *const OPTION_GE_MAX_DUMP_OP_NUM = "ge.maxDumpOpNum"; | |||
@@ -196,36 +210,49 @@ const char *const OPTION_GE_MAX_DUMP_OP_NUM = "ge.maxDumpOpNum"; | |||
// Its value should be "0" or "1", default value is "1" | |||
const char *const ENABLE_PRINT_OP_PASS = "ge.enablePrintOpPass"; | |||
// Configure whether to use single stream. | |||
// Its value should be "true" or "false", default value is "false" | |||
const char *const ENABLE_SINGLE_STREAM = "ge.enableSingleStream"; | |||
// Graph run mode | |||
enum GraphRunMode { PREDICTION = 0, TRAIN }; | |||
// Data description | |||
struct DataDesc { | |||
void *data = nullptr; // data address | |||
uint32_t length = 0; // data size | |||
bool isDataSupportMemShare = false; | |||
// Input/Output tensor info | |||
struct InputTensorInfo { | |||
uint32_t data_type; // data type | |||
std::vector<int64_t> dims; // shape description | |||
void *data; // tensor data | |||
int64_t length; // tensor length | |||
}; | |||
// Input/Output shape description | |||
struct ShapeDesc { | |||
int64_t num = 0; | |||
int64_t channel = 0; | |||
int64_t height = 0; | |||
int64_t width = 0; | |||
std::vector<int64_t> dims; | |||
struct OutputTensorInfo { | |||
uint32_t data_type; // data type | |||
std::vector<int64_t> dims; // shape description | |||
std::unique_ptr<uint8_t[]> data; // tensor data | |||
int64_t length; // tensor length | |||
OutputTensorInfo() : data_type(0), dims({}), data(nullptr), length(0) {} | |||
OutputTensorInfo(OutputTensorInfo &&out) | |||
: data_type(out.data_type), dims(out.dims), data(std::move(out.data)), length(out.length) {} | |||
OutputTensorInfo &operator=(OutputTensorInfo &&out) { | |||
if (this != &out) { | |||
data_type = out.data_type; | |||
dims = out.dims; | |||
data = std::move(out.data); | |||
length = out.length; | |||
} | |||
return *this; | |||
} | |||
OutputTensorInfo(const OutputTensorInfo &) = delete; | |||
OutputTensorInfo &operator=(const OutputTensorInfo &) = delete; | |||
}; | |||
// Input/Output tensor info | |||
struct TensorInfo { | |||
uint32_t dataType; // data type | |||
DataDesc data; // tensor data | |||
ShapeDesc shapeInfo; // tensor shape | |||
}; | |||
using Status = uint32_t; | |||
using RunAsyncCallback = std::function<void(Status, std::vector<ge::OutputTensorInfo> &)>; | |||
// for ir build | |||
namespace ir_option { | |||
static const char *const INPUT_FORMAT = "input_format"; | |||
static const char *const INPUT_SHAPE = "input_shape"; | |||
static const char *const OP_NAME_MAP = "op_name_map"; | |||
static const char *const DYNAMIC_BATCH_SIZE = kDynamicBatchSize; | |||
static const char *const DYNAMIC_IMAGE_SIZE = kDynamicImageSize; | |||
static const char *const INSERT_OP_FILE = ge::INSERT_OP_FILE.c_str(); | |||
@@ -235,13 +262,15 @@ static const char *const HEAD_STREAM = ge::HEAD_STREAM.c_str(); | |||
static const char *const AUTO_TUNE_MODE = ge::AUTO_TUNE_MODE.c_str(); | |||
static const char *const CORE_TYPE = ge::CORE_TYPE.c_str(); | |||
static const char *const SOC_VERSION = ge::SOC_VERSION.c_str(); | |||
static const char *const ENABLE_SINGLE_STREAM = ge::ENABLE_SINGLE_STREAM; | |||
// for interface: aclgrphBuildModel | |||
const std::set<std::string> ir_builder_suppported_options = { | |||
INPUT_FORMAT, INPUT_SHAPE, OP_NAME_MAP, DYNAMIC_BATCH_SIZE, | |||
DYNAMIC_IMAGE_SIZE, INSERT_OP_FILE, PRECISION_MODE, EXEC_DISABLE_REUSED_MEMORY, | |||
AUTO_TUNE_MODE}; | |||
const std::set<std::string> ir_builder_suppported_options = {INPUT_FORMAT, INPUT_SHAPE, DYNAMIC_BATCH_SIZE, | |||
DYNAMIC_IMAGE_SIZE, INSERT_OP_FILE}; | |||
// for interface: aclgrphBuildInitialize | |||
const std::set<std::string> global_options = {HEAD_STREAM, CORE_TYPE, SOC_VERSION}; | |||
const std::set<std::string> global_options = { | |||
HEAD_STREAM, CORE_TYPE, SOC_VERSION, PRECISION_MODE, EXEC_DISABLE_REUSED_MEMORY, | |||
AUTO_TUNE_MODE, ENABLE_SINGLE_STREAM}; | |||
} // namespace ir_option | |||
} // namespace ge | |||
@@ -55,12 +55,16 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Graph { | |||
graphStatus FindOpByName(const string &name, ge::Operator &op) const; | |||
graphStatus FindOpByType(const string &type, std::vector<ge::Operator> &ops) const; | |||
graphStatus GetAllOpName(std::vector<string> &op_name) const; | |||
graphStatus SaveToFile(const string &file_name) const; | |||
graphStatus LoadFromFile(const string &file_name); | |||
const std::string &GetName() const; | |||
/// | |||
/// Set is need train iteration. | |||
/// If set true, it means this graph need to be run iteration some | |||
@@ -69,7 +69,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InferenceContext { | |||
static std::unique_ptr<InferenceContext> Create(); | |||
private: | |||
InferenceContext(std::unique_ptr<InferenceContextImpl> &impl); | |||
explicit InferenceContext(std::unique_ptr<InferenceContextImpl> &impl); | |||
std::shared_ptr<InferenceContextImpl> inference_context_impl_; | |||
}; | |||
} // namespace ge | |||
@@ -44,11 +44,16 @@ | |||
namespace ge { | |||
class OperatorImpl; | |||
class NamedAttrs; | |||
class Graph; | |||
class AttrValue; | |||
using SubgraphBuilder = std::function<Graph(const std::string &name)>; | |||
using OperatorImplPtr = std::shared_ptr<OperatorImpl>; | |||
class Graph; | |||
using GraphBuilderCallback = std::function<Graph()>; | |||
class OpIO; | |||
using OutHandler = std::shared_ptr<OpIO>; | |||
using InHandler = std::shared_ptr<OpIO>; | |||
@@ -69,6 +74,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { | |||
using OpBool = bool; | |||
using OpTensor = Tensor; | |||
using OpType = ge::DataType; | |||
using OpNamedAttrs = ge::NamedAttrs; | |||
using OpListInt = std::vector<int64_t>; | |||
using OpListFloat = std::vector<float>; | |||
using OpListString = std::vector<string>; | |||
@@ -77,6 +83,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { | |||
using OpBytes = std::vector<uint8_t>; | |||
using OpListListInt = std::vector<std::vector<int64_t>>; | |||
using OpListType = std::vector<ge::DataType>; | |||
using OpListNamedAttrs = std::vector<ge::NamedAttrs>; | |||
Operator() {} | |||
@@ -132,6 +139,12 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { | |||
void SetInferenceContext(const InferenceContextPtr &inference_context); | |||
InferenceContextPtr GetInferenceContext() const; | |||
void SetGraphBuilder(const GraphBuilderCallback &builder); | |||
graphStatus GetGraphBuilder(GraphBuilderCallback &builder) const; | |||
void AddSubgraphName(const string &name); | |||
string GetSubgraphName(int index) const; | |||
graphStatus VerifyAllAttr(bool disable_common_verifier = false); | |||
size_t GetInputsSize() const; | |||
@@ -190,8 +203,21 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { | |||
Operator &SetAttr(const string &name, const ge::DataType &attr_value); | |||
graphStatus GetAttr(const string &name, ge::DataType &attr_value) const; | |||
// func type | |||
Operator &SetAttr(const string &name, const ge::NamedAttrs &attr_value); | |||
graphStatus GetAttr(const string &name, ge::NamedAttrs &attr_value) const; | |||
Operator &SetAttr(const string &name, const std::vector<ge::NamedAttrs> &attr_value); | |||
graphStatus GetAttr(const string &name, std::vector<ge::NamedAttrs> &attr_value) const; | |||
void BreakConnect() const; | |||
size_t GetSubgraphNamesCount() const; | |||
std::vector<std::string> GetSubgraphNames() const; | |||
SubgraphBuilder GetSubgraphBuilder(const string &name) const; | |||
Graph GetSubgraph(const string &name) const; | |||
SubgraphBuilder GetDynamicSubgraphBuilder(const string &name, uint32_t index) const; | |||
Graph GetDynamicSubgraph(const string &name, uint32_t index) const; | |||
protected: | |||
void AttrRegister(const string &name, float attr_value); | |||
void AttrRegister(const string &name, const std::vector<float> &attr_value); | |||
@@ -207,6 +233,8 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { | |||
void AttrRegister(const string &name, const std::vector<std::vector<int64_t>> &attr_value); | |||
void AttrRegister(const string &name, const std::vector<ge::DataType> &attr_value); | |||
void AttrRegister(const string &name, const ge::DataType &attr_value); | |||
void AttrRegister(const string &name, const ge::NamedAttrs &attr_value); | |||
void AttrRegister(const string &name, const std::vector<ge::NamedAttrs> &attr_value); | |||
explicit Operator(OperatorImplPtr &&op_impl); | |||
@@ -224,6 +252,8 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { | |||
void DynamicInputRegister(const string &name, const unsigned int num, bool is_push_back = true); | |||
void DynamicInputRegisterByIndex(const string &name, const unsigned int num, size_t index); | |||
void DynamicOutputRegister(const string &name, const unsigned int num, bool is_push_back = true); | |||
void RequiredAttrRegister(const string &name); | |||
@@ -235,6 +265,10 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { | |||
Operator &SetInput(const string &dst_name, uint32_t dst_index, const Operator &src_oprt, const string &name); | |||
void SubgraphRegister(const std::string &name, bool dynamic); | |||
void SubgraphCountRegister(const std::string &name, uint32_t count); | |||
void SetSubgraphBuilder(const std::string &name, uint32_t index, const SubgraphBuilder &builder); | |||
private: | |||
Operator &SetInput(const string &dst_name, const OutHandler &out_handler); | |||
@@ -22,10 +22,11 @@ | |||
#include <string> | |||
#include <vector> | |||
#include "./operator.h" | |||
#include "./operator_factory.h" | |||
#include "./tensor.h" | |||
#include "./types.h" | |||
#include "graph/operator.h" | |||
#include "graph/operator_factory.h" | |||
#include "graph/tensor.h" | |||
#include "graph/types.h" | |||
#include "graph/graph.h" | |||
namespace ge { | |||
using std::function; | |||
@@ -46,6 +47,10 @@ class OpReg { | |||
OpReg &OUTPUT() { return *this; } | |||
OpReg &GRAPH() { return *this; } | |||
OpReg &DYNAMIC_GRAPH() { return *this; } | |||
OpReg &INFER_SHAPE_AND_TYPE() { return *this; } | |||
}; | |||
@@ -191,6 +196,10 @@ class OpReg { | |||
Operator::DynamicInputRegister(#x, num, isPushBack); \ | |||
return *this; \ | |||
} \ | |||
_THIS_TYPE &create_dynamic_input_byindex_##x(unsigned int num, size_t index) { \ | |||
Operator::DynamicInputRegisterByIndex(#x, num, index); \ | |||
return *this; \ | |||
} \ | |||
TensorDesc get_dynamic_input_desc_##x(unsigned int index) const { return Operator::GetDynamicInputDesc(#x, index); } \ | |||
graphStatus update_dynamic_input_desc_##x(unsigned int index, const TensorDesc &tensorDesc) { \ | |||
return Operator::UpdateDynamicInputDesc(#x, index, tensorDesc); \ | |||
@@ -229,6 +238,51 @@ class OpReg { | |||
void __dy_output_##x() { \ | |||
(void)OpReg() | |||
#define GRAPH(x) \ | |||
N(); \ | |||
__graph_##x(); \ | |||
} \ | |||
\ | |||
public: \ | |||
static const string name_graph_##x() { return #x; } \ | |||
SubgraphBuilder get_subgraph_builder_##x() const { return Operator::GetSubgraphBuilder(#x); } \ | |||
_THIS_TYPE &set_subgraph_builder_##x(const SubgraphBuilder &v) { \ | |||
Operator::SetSubgraphBuilder(#x, 0, v); \ | |||
return *this; \ | |||
} \ | |||
Graph get_subgraph_##x() const { return Operator::GetSubgraph(#x); } \ | |||
\ | |||
private: \ | |||
void __graph_##x() { \ | |||
Operator::SubgraphRegister(#x, false); \ | |||
Operator::SubgraphCountRegister(#x, 1); \ | |||
(void)OpReg() | |||
#define DYNAMIC_GRAPH(x) \ | |||
N(); \ | |||
__graph_##x(); \ | |||
} \ | |||
\ | |||
public: \ | |||
static const string name_graph_##x() { return #x; } \ | |||
_THIS_TYPE &create_dynamic_subgraph_##x(unsigned int num) { \ | |||
Operator::SubgraphCountRegister(#x, num); \ | |||
return *this; \ | |||
} \ | |||
SubgraphBuilder get_dynamic_subgraph_builder_##x(unsigned int index) const { \ | |||
return Operator::GetDynamicSubgraphBuilder(#x, index); \ | |||
} \ | |||
Graph get_dynamic_subgraph_##x(unsigned int index) const { return Operator::GetDynamicSubgraph(#x, index); } \ | |||
_THIS_TYPE &set_dynamic_subgraph_builder_##x(unsigned int index, const SubgraphBuilder &v) { \ | |||
Operator::SetSubgraphBuilder(#x, index, v); \ | |||
return *this; \ | |||
} \ | |||
\ | |||
private: \ | |||
void __graph_##x() { \ | |||
Operator::SubgraphRegister(#x, true); \ | |||
(void)OpReg() | |||
#define PASTE(g_register, y) g_register##y | |||
#define __OP_END_IMPL__(x, y) \ | |||
N(); \ | |||
@@ -21,6 +21,7 @@ | |||
#include <memory> | |||
#include <string> | |||
#include <vector> | |||
#include <utility> | |||
#include "./ge_error_codes.h" | |||
#include "./types.h" | |||
@@ -62,6 +63,11 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TensorDesc { | |||
void Update(const Shape &shape, Format format = FORMAT_ND, DataType dt = DT_FLOAT); | |||
Shape GetShape() const; | |||
void SetShape(const Shape &shape); | |||
// set shape with -2, it stand for unknown shape | |||
graphStatus SetUnknownDimNumShape(); | |||
// for unknown shape | |||
graphStatus SetShapeRange(const std::vector<std::pair<int64_t, int64_t>> &range); | |||
graphStatus GetShapeRange(std::vector<std::pair<int64_t, int64_t>> &range) const; | |||
Format GetFormat() const; | |||
void SetFormat(Format format); | |||
@@ -23,7 +23,9 @@ | |||
namespace ge { | |||
static const int64_t UNKNOWN_DIM = -1; | |||
static const int64_t UNKNOWN_DIM_NUM = -2; | |||
static const std::vector<int64_t> UNKNOWN_SHAPE = {0}; | |||
static const std::vector<int64_t> UNKNOWN_RANK = {-2}; | |||
#ifdef HOST_VISIBILITY | |||
#define GE_FUNC_HOST_VISIBILITY __attribute__((visibility("default"))) | |||
@@ -140,10 +142,19 @@ enum Format { | |||
FORMAT_NC, | |||
FORMAT_DHWNC, | |||
FORMAT_FRACTAL_Z_3D_TRANSPOSE, // 3D filter(transpose) input tensor format | |||
FORMAT_FRACTAL_ZN_LSTM, | |||
FORMAT_RESERVED, | |||
FORMAT_ALL | |||
}; | |||
// for unknown shape op type | |||
enum UnknowShapeOpType { | |||
DEPEND_IN_SHAPE = 1, // op out shape get by input shape | |||
DEPEND_CONST_VALUE = 2, // op out shape get by const op value | |||
DEPEND_SHAPE_RANGE = 3, // op out shape get by range | |||
DEPEND_COMPUTE = 4 // op out shape get by totally computing | |||
}; | |||
struct TensorDescInfo { | |||
Format format_ = FORMAT_RESERVED; // tbe op register support format | |||
DataType dataType_ = DT_UNDEFINED; // tbe op register support datatype | |||
@@ -58,12 +58,18 @@ Status AutoMappingFn(const google::protobuf::Message *op_src, ge::Operator &op); | |||
Status AutoMappingFnDynamic(const google::protobuf::Message *op_src, ge::Operator &op, | |||
std::map<std::string, std::pair<std::string, std::string>> dynamic_name_attr_value, | |||
int in_pos = -1, int out_pos = -1); | |||
Status AutoMappingSubgraphIndex(const ge::Graph &graph, const std::function<int(int data_index)> &input, | |||
const std::function<int(int netoutput_index)> &output); | |||
Status AutoMappingSubgraphIndex(const ge::Graph &graph, | |||
const std::function<Status(int data_index, int &parent_input_index)> &input, | |||
const std::function<Status(int netoutput_index, int &parent_output_index)> &output); | |||
using google::protobuf::Message; | |||
class OpRegistrationDataImpl; | |||
using ParseParamFunc = std::function<domi::Status(const google::protobuf::Message *, ge::Operator &)>; | |||
using FusionParseParamFunc = | |||
std::function<domi::Status(const std::vector<const google::protobuf::Message *>, ge::Operator &)>; | |||
using ParseSubgraphFunc = std::function<Status(const std::string &subgraph_name, const ge::Graph &graph)>; | |||
class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData { | |||
public: | |||
@@ -81,6 +87,8 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData { | |||
OpRegistrationData &FusionParseParamsFn(const FusionParseParamFunc &fusionParseParamFn); | |||
OpRegistrationData &ParseSubgraphPostFn(const ParseSubgraphFunc &subgraph_post_fn); | |||
OpRegistrationData &ImplyType(const domi::ImplyType &imply_type); | |||
OpRegistrationData &DelInputWithCond(int inputIdx, const std::string &attrName, bool attrValue); | |||
@@ -93,6 +101,7 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData { | |||
domi::FrameworkType GetFrameworkType() const; | |||
ParseParamFunc GetParseParamFn() const; | |||
FusionParseParamFunc GetFusionParseParamFn() const; | |||
ParseSubgraphFunc GetParseSubgraphPostFn() const; | |||
private: | |||
std::shared_ptr<OpRegistrationDataImpl> impl_; | |||
@@ -116,27 +125,5 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpReceiver { | |||
namespace ge { | |||
using OpRegistrationData = domi::OpRegistrationData; | |||
using OpReceiver = domi::OpReceiver; | |||
class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY HostCpuOp { | |||
public: | |||
HostCpuOp() = default; | |||
virtual ~HostCpuOp() = default; | |||
virtual graphStatus Compute(Operator &op, const std::map<std::string, const Tensor> &inputs, | |||
std::map<std::string, Tensor> &outputs) = 0; | |||
}; | |||
class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY HostCpuOpRegistrar { | |||
public: | |||
HostCpuOpRegistrar(const char *op_type, HostCpuOp *(*create_fn)()); | |||
}; | |||
#define REGISTER_HOST_CPU_OP_BUILDER(name, op) REGISTER_HOST_CPU_OP_BUILDER_UNIQ_HELPER(__COUNTER__, name, op) | |||
#define REGISTER_HOST_CPU_OP_BUILDER_UNIQ_HELPER(ctr, name, op) REGISTER_HOST_CPU_OP_BUILDER_UNIQ(ctr, name, op) | |||
#define REGISTER_HOST_CPU_OP_BUILDER_UNIQ(ctr, name, op) \ | |||
static ::ge::HostCpuOpRegistrar register_host_cpu_op##ctr __attribute__((unused)) = \ | |||
::ge::HostCpuOpRegistrar(name, []() -> ::ge::HostCpuOp * { return new (std::nothrow) op(); }) | |||
} // namespace ge | |||
#endif // INC_EXTERNAL_REGISTER_REGISTER_H_ |
@@ -51,24 +51,24 @@ inline pid_t GetTid() { | |||
return tid; | |||
} | |||
#define GE_TIMESTAMP_START(stage) uint64_t startUsec_##stage = domi::GetCurrentTimestap() | |||
#define GE_TIMESTAMP_START(stage) uint64_t startUsec_##stage = ge::GetCurrentTimestap() | |||
#define GE_TIMESTAMP_END(stage, stage_name) \ | |||
do { \ | |||
uint64_t endUsec_##stage = domi::GetCurrentTimestap(); \ | |||
uint64_t endUsec_##stage = ge::GetCurrentTimestap(); \ | |||
GEEVENT("[GEPERFTRACE] The time cost of %s is [%lu] micro second.", (stage_name), \ | |||
(endUsec_##stage - startUsec_##stage)); \ | |||
} while (0); | |||
#define GE_TIMESTAMP_CALLNUM_START(stage) \ | |||
uint64_t startUsec_##stage = domi::GetCurrentTimestap(); \ | |||
uint64_t call_num_of##stage = 0; \ | |||
#define GE_TIMESTAMP_CALLNUM_START(stage) \ | |||
uint64_t startUsec_##stage = ge::GetCurrentTimestap(); \ | |||
uint64_t call_num_of##stage = 0; \ | |||
uint64_t time_of##stage = 0 | |||
#define GE_TIMESTAMP_RESTART(stage) (startUsec_##stage = domi::GetCurrentTimestap()) | |||
#define GE_TIMESTAMP_RESTART(stage) (startUsec_##stage = ge::GetCurrentTimestap()) | |||
#define GE_TIMESTAMP_ADD(stage) \ | |||
time_of##stage += domi::GetCurrentTimestap() - startUsec_##stage; \ | |||
#define GE_TIMESTAMP_ADD(stage) \ | |||
time_of##stage += ge::GetCurrentTimestap() - startUsec_##stage; \ | |||
call_num_of##stage++ | |||
#define GE_TIMESTAMP_CALLNUM_END(stage, stage_name) \ | |||
@@ -22,7 +22,6 @@ | |||
#include "cce/cce_def.hpp" | |||
#include "common/string_util.h" | |||
#include "common/util.h" | |||
#include "dlog/log.h" | |||
#include "framework/common/debug/ge_log.h" | |||
#include "ge/ge_api_error_codes.h" | |||
@@ -30,7 +29,7 @@ using cce::CC_STATUS_SUCCESS; | |||
using cce::ccStatus_t; | |||
#if !defined(__ANDROID__) && !defined(ANDROID) | |||
#define DOMI_LOGE(...) DAV_LOGE("DOMI", __VA_ARGS__) | |||
#define DOMI_LOGE(...) GE_LOG_ERROR(GE_MODULE_NAME, ge::FAILED, __VA_ARGS__) | |||
#else | |||
#include <android/log.h> | |||
#if defined(BUILD_VERSION_PERF) | |||
@@ -103,17 +102,17 @@ using cce::ccStatus_t; | |||
} while (0); | |||
// If expr is not true, print the log and return the specified status | |||
#define GE_CHK_BOOL_RET_STATUS(expr, _status, ...) \ | |||
do { \ | |||
bool b = (expr); \ | |||
if (!b) { \ | |||
std::string msg; \ | |||
(void)msg.append(domi::StringUtils::FormatString(__VA_ARGS__)); \ | |||
(void)msg.append( \ | |||
domi::StringUtils::FormatString(" Error Code:0x%X(%s)", _status, GET_ERRORNO_STR(_status).c_str())); \ | |||
DOMI_LOGE("%s", msg.c_str()); \ | |||
return _status; \ | |||
} \ | |||
#define GE_CHK_BOOL_RET_STATUS(expr, _status, ...) \ | |||
do { \ | |||
bool b = (expr); \ | |||
if (!b) { \ | |||
std::string msg; \ | |||
(void)msg.append(ge::StringUtils::FormatString(__VA_ARGS__)); \ | |||
(void)msg.append( \ | |||
ge::StringUtils::FormatString(" Error Code:0x%X(%s)", _status, GET_ERRORNO_STR(_status).c_str())); \ | |||
DOMI_LOGE("%s", msg.c_str()); \ | |||
return _status; \ | |||
} \ | |||
} while (0); | |||
// If expr is not true, print the log and return the specified status | |||
@@ -152,7 +152,6 @@ GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_RUN_GRAPH_INVALID, 11, "Get computeGraph by g | |||
GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_INSERT_DYN_OP_FAILED, 12, "Graph which insert dynamic op failed."); // 1343242252 | |||
GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_PREPROCESS_FAILED, 13, "Graph preprocess failed."); // 1343242253 | |||
GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_GRAPH_FUSION_FAILED, 14, "Graph fusion failed."); // 1343242254 | |||
GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_TINY_CAL_CHECK_FAILED, 15, "Check tiny calibration failed."); // 1343242255 | |||
GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_CALIBRATION_FAILED, 16, "Calibration failed."); // 1343242256 | |||
GE_ERRORNO_GRAPH(GE_GRAPH_SUBGRAPH_NUM_ZERO, 17, "Graph partition success, but subGraph num is 0."); // 1343242257 | |||
GE_ERRORNO_GRAPH(GE_GRAPH_SUBGRAPH_ENGINENAME_REPEATED, 18, "Graph subGraph engine name is repeated."); // 1343242258 | |||
@@ -25,6 +25,7 @@ | |||
#include "common/fmk_error_codes.h" | |||
#include "ge/ge_api_error_codes.h" | |||
#include "external/graph/types.h" | |||
#include "external/ge/ge_api_types.h" | |||
namespace ge { | |||
enum RuntimeType { HOST = 0, DEVICE = 1 }; | |||
@@ -130,7 +131,8 @@ class ModelListener { | |||
/// @param [in] data_index Index of the input_data | |||
/// @param [in] resultCode Execution results | |||
/// | |||
virtual Status OnComputeDone(uint32_t model_id, uint32_t data_index, uint32_t result_code) = 0; | |||
virtual Status OnComputeDone(uint32_t model_id, uint32_t data_index, uint32_t result_code, | |||
std::vector<ge::OutputTensorInfo> &outputs) = 0; | |||
}; | |||
// OMM configuration item | |||
@@ -147,6 +149,8 @@ struct Options { | |||
std::string rankTableFile; | |||
int32_t ge_hccl_flag = 0; | |||
int32_t physical_device_id; | |||
std::string profiling_mode; | |||
std::string profiling_options; | |||
}; | |||
// Profiling info of task | |||
@@ -20,7 +20,7 @@ | |||
#include <gflags/gflags.h> | |||
#include <string> | |||
namespace domi { | |||
namespace ge { | |||
class GflagsUtils { | |||
public: | |||
static bool IsSetCommandTrue(const char *name) { | |||
@@ -66,6 +66,6 @@ class GflagsUtils { | |||
} | |||
} | |||
}; | |||
} // namespace domi | |||
} // namespace ge | |||
#endif // INC_FRAMEWORK_COMMON_GFLAGS_UTIL_H_ |
@@ -26,7 +26,7 @@ | |||
#include "graph/model.h" | |||
#include "model/ge_model.h" | |||
namespace domi { | |||
namespace ge { | |||
class ModelHelper { | |||
public: | |||
ModelHelper() = default; | |||
@@ -38,7 +38,7 @@ class ModelHelper { | |||
Status LoadModel(const ge::ModelData& model_data); | |||
Status GetModelBufferData(ge::ModelBufferData& model); | |||
ModelFileHeader* GetFileHeader() { return file_header_; } | |||
const ModelFileHeader* GetFileHeader() const { return file_header_; } | |||
GeModelPtr GetGeModel(); | |||
void SetSaveMode(bool val) { is_offline_ = val; } | |||
@@ -65,9 +65,8 @@ class ModelHelper { | |||
Status LoadTask(OmFileLoadHelper& om_load_helper); | |||
Status LoadTBEKernelStore(OmFileLoadHelper& om_load_helper); | |||
Status ReleaseLocalModelData() noexcept; | |||
Status SaveModelPartition(std::shared_ptr<OmFileSaveHelper>& om_file_save_helper, ModelPartitionType type, | |||
const uint8_t* data, size_t size); | |||
}; | |||
} // namespace domi | |||
} // namespace ge | |||
#endif // INC_FRAMEWORK_COMMON_HELPER_MODEL_HELPER_H_ |
@@ -26,8 +26,10 @@ | |||
#include "framework/common/ge_types.h" | |||
using ProcParam = struct PROC_PARAM; | |||
using std::string; | |||
using std::vector; | |||
namespace domi { | |||
namespace ge { | |||
struct ModelPartition { | |||
ModelPartitionType type; | |||
uint8_t *data = 0; | |||
@@ -88,5 +90,5 @@ class OmFileSaveHelper { | |||
ModelFileHeader model_header_; | |||
OmFileContext context_; | |||
}; | |||
} // namespace domi | |||
} // namespace ge | |||
#endif // INC_FRAMEWORK_COMMON_HELPER_OM_FILE_HELPER_H_ |
@@ -30,7 +30,7 @@ | |||
using std::vector; | |||
namespace domi { | |||
namespace ge { | |||
// Size of RC memory alignment, 2M | |||
constexpr size_t ALIGN_SIZE = 2097152; | |||
@@ -118,6 +118,6 @@ class L2CacheOptimize { | |||
bool Cross(const RCMemoryBlock &l_block, const RCMemoryBlock &r_block); | |||
bool Connect(const RCMemoryBlock &l_block, const RCMemoryBlock &r_block); | |||
}; | |||
} // namespace domi | |||
} // namespace ge | |||
#endif // INC_FRAMEWORK_COMMON_L2_CACHE_OPTIMIZE_H_ |
@@ -1,810 +0,0 @@ | |||
/** | |||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#ifndef INC_FRAMEWORK_COMMON_OP_ATTR_DEFINE_H_ | |||
#define INC_FRAMEWORK_COMMON_OP_ATTR_DEFINE_H_ | |||
#include <string> | |||
#include "framework/common/fmk_types.h" | |||
namespace domi { | |||
// Public Attribute | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_NAME; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_TYPE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_WEIGHT_NAME; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_IS_QUANTIZE_FACTOR; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_ALPHA; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_BETA; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_PADMODE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_PADMODES; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_MODE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FILTER; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_BIAS; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_BIAS_TERM; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_HAS_BIAS_VALUE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_PAD; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_PADS; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_PAD_SIZE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_PAD_MODE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_SCALE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_WINDOWS; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_GLOBAL_POOLING; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_CEIL_MODE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_STRIDE_SIZE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_RELU_FLAG; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_ALGO; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FORMAT; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FILTER_FORMAT; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_LRN_K; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_LRN_NORM_REGION; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_LRN_LOCAL_SIZE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_LRN_ALPHA; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_LRN_BETA; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_AXIS; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_BROADCAST; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_OUTPUT_NUM; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_TIDX; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_TPADDINGS; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_IMG_H; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_IMG_W; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NET_H; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NET_W; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_TMULTIPLES; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_MULTIPLES; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_T; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_N; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_TSHAPE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_NAN_OPT; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_AIPP; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string NEW_AIPP_CONV_OP; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_SESSION_GRAPH_ID; | |||
static const std::string ATTR_NAME_INPUT_DATATYPE = "input_datatype"; | |||
static const std::string ATTR_NAME_OUTPUT_DATATYPE = "output_datatype"; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_MULTISHAPE_BATCHLIST; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_MULTISHAPE_BATCHLIST_SIZE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_BATCH_NUM; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_INPUT_FORMAT; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_OUTPUT_FORMAT; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FRAMEWORK_NODE_DEF; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FRAMEWORK_OP_DEF; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FRAMEWORK_FWK_TYPE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FRAMEWORK_FUNC_DEF; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_INPUT_TENSOR_DESC; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_OUTPUT_TENSOR_DESC; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_INFERRED_FORMAT; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_PRED_PERMUTE_DELETED; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_IGNORE_PRED_FORMAT; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_WEIGHTS; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_BROACAST_REAL_DIM_CNT; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_DIM_ALIGN; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_RTSWITCH_RECV_EVENT_ID; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_AUTOMIC_ADD_START; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE; | |||
// To be deleted | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_TO_BE_DELETED; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PERMUTE_RESHAPE_FUSION; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PERMUTE_RESHAPE_FUSION_CONV_PROPOSAL; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PERMUTE_RESHAPE_FUSION_CONV_DECODEBBOX; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PERMUTE_RESHAPE_FUSION_BOX_TYPE_NUM; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_MBOX_LOC_FUSION; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_MBOX_CONF_FUSION; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_MBOX_OCR_FUSION; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_MBOX_FUSION_BOX_TYPE_NUM; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_RESHAPE_SLICE_CONCAT_FUSION; | |||
// Refinedet | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string REFINEDET_MBOX_LOC_FUSION; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string REFINEDET_MBOX_CONF_FUSION; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string REFINEDET_MBOX_FUSION_BOX_TYPE_NUM; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIORBOX_CONCAT; | |||
// _Arg | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_INDEX; | |||
// _RetVal | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RETVAL_ATTR_NAME_INDEX; | |||
// Data | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DATA_ATTR_NAME_DATA_TYPE; | |||
// Send | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SEND_ATTR_EVENT_ID; | |||
// Recv | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RECV_ATTR_EVENT_ID; | |||
// convolution | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_COEF; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_STRIDE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_STRIDES; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_DILATION; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_DILATIONS; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_MODE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_ALGO; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_GROUP; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_PAD_MODE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_PAD; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_STRIDE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_DILATION; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_NUM_OUTPUT; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_KERNEL; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_FILTER; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_BIAS; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_RELU_FLAG; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_ADJ; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_TARGET_SHAPE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_BEFORE_PAD; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_HAS_BIAS; | |||
// Pooling | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_MODE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_NAN_OPT; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_PAD_MODE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_GLOBAL_POOLING; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_WINDOW; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_PAD; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_STRIDE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_CEIL_MODE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_DATA_MODE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_BEFORE_PAD; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_NAME_ALGO; | |||
// Eltwise | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ELTWISE_ATTR_MODE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ELTWISE_ATTR_COEFF; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ELTWISE_ATTR_WEIGHT; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ELTWISE_ATTR_RELU_FLAG; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ELTWISE_ATTR_ALPHA; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ELTWISE_ATTR_BETA; | |||
// BatchNorm | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_MODE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_EPSILON; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_USE_GLOBAL_STATS; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_MOVING_AVERAGE_FRACTION; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_ESTIMATED_MEAN; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_ESTIMATED_VARIANCE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_SCALE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_BIAS; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_DATA_FORMAT; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_IS_TRAINING; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_IS_TRAINING_FUSION; | |||
// Huberloss | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HUBER_LOSS_ATTR_DELTA; | |||
// SSDRealDivTileMul | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_REAL_DIV_TILE_MUL_ATTR_TILE_PARA; | |||
// SSDSumMulRealDivMean | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string | |||
SSD_SUM_MUL_REALDIV_MEAN_ATTR_REDUCTION_INDICES; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_AXIS; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_MEAN_PARA; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_HAS_SUM; | |||
/// ConcatFive2Four | |||
/// ConcatFour2Five | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_CLASS_NUM; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_FEATURE_MAP_SIZE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string TRANS_FOR_LOSS_MODE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_BOX_TYPE_NUM; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_FEATURE_MAP_HIGH; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_FEATURE_MAP_WIDTH; | |||
// Scale | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SCALE_ATTR_SCALE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SCALE_ATTR_BIAS; | |||
// FullConnection | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FULL_CONNECTION_ATTR_FILTER; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FULL_CONNECTION_ATTR_BIAS; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FULL_CONNECTION_ATTR_NUM_OUTPUT; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FULL_CONNECTION_ATTR_RELU_FLAG; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FULL_ATTR_NAME_ALGO; | |||
// SoftmaxOpParams | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SOFTMAX_ATTR_ALGO; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SOFTMAX_ATTR_MODE; | |||
// SparseSoftmaxCrossEntropy | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SPARSE_SOFTMAX_CROSS_ENTROPY_ATTR_MODE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SPARSE_SOFTMAX_CROSS_ENTROPY_IS_GRAD; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SOFTMAX_CROSS_ENTROPY_LABELSMOOTHING; | |||
// Activation | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ACTIVATION_ATTR_MODE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ACTIVATION_ATTR_COEF; | |||
// Concat | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONCAT_ATTR_NAME_AXIS; | |||
// Const | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONST_ATTR_NAME_DATA_TRANSTYPE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONST_ATTR_NAME_OUTPUT_FORMAT; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONST_ATTR_NAME_OUTPUT_TYPE; | |||
// Roipooling | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ROIPOOLING_ATTR_NAME_POOLED_H; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ROIPOOLING_ATTR_NAME_POOLED_W; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ROIPOOLING_ATTR_NAME_SPATIAL_SCALE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ROIPOOLING_ATTR_NAME_RIO_POOLING_MODE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ROIPOOLING_ATTR_NAME_POOLING_MODE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ROIPOOLING_ATTR_NAME_SAMPLING_RATIO; | |||
// DetectionOutput | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_NUM_CLASSES; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_OCR_NUM_CLASSES; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_NMS_THRESHOLD; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_TOP_K; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_CONFIDENCE_THRESHOLD; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_IMG_H; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_IMG_W; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_BATCH_SIZE; | |||
// Ssd DetectionOutput | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_ETA; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_SHARED_LOCATION; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_BACKGROUND_LABEL_ID; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_CODE_TYPE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string | |||
DETECTIONOUTPUT_ATTR_VARIANCE_ENCODED_IN_TARGET; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_KEEP_TOP_K; | |||
// Refinedet DetectionOutput | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_OBJECTNESS_SCORE; | |||
// yolo DetectionOutput | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_ClASSES; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_BIASES; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_RELATIVE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_OBJECTNESS_THRESHOLD; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_CLASS_THRESHOLD; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_POST_TOP_K; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_IOU_THRESHOLD_DECAY; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_COOR_SCALE_FACTOR; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_YOLO_VERSION; | |||
// DetectionPostprocess | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POSTPROCESS_ATTR_NAME_CLS_NUM; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POSTPROCESS_ATTR_NAME_CONF_THRESH; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POSTPROCESS_ATTR_NAME_NMS_THRESH; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POSTPROCESS_ATTR_POST_NMS_TOPN; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POSTPROCESS_ATTR_NAME_BBOX_REG_WEIGHT; | |||
// Spatialtransfrom | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SPTIALTF_ATTR_NAME_OUTPUT_H; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SPTIALTF_ATTR_NAME_OUTPUT_W; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SPTIALTF_ATTR_NAME_BORDER_VALUE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SPTIALTF_ATTR_NAME_AFFINE_TRANSFORM; | |||
// Proposal | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_NAME_FEAT_STRIDE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_NAME_BASE_SIZE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_NAME_MIN_SIZE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_NAME_RATIO; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_NAME_SCALE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_NAME_PRE_NMS_TOPN; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_NAME_POST_NMS_TOPN; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_NAME_NMS_THRESH; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_NAME_TOP_SIZE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_IMG_H; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_IMG_W; | |||
// Softmax | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SOFTMAX_ATTR_AXIS; | |||
// Permute | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PERMUTE_ATTR_ORDER; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PERMUTE_ATTR_PERM; | |||
// SSD Normalize | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSDNORMALIZE_ATTR_ACCROSS_SPATIAL; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSDNORMALIZE_ATTR_CHANNEL_SHARED; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSDNORMALIZE_ATTR_EPS; | |||
// Flatten | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FLATTEN_ATTR_AXIS; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FLATTEN_ATTR_END_AXIS; | |||
// SsdPRIORBOX | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_FLIP; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_CLIP; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_IMG_H; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_IMG_W; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_STEP_H; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_STEP_W; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_OFFSET; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_MIN_SIZE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_MAX_SIZE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_MIN_SIZE_NUM; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_MAX_SIZE_NUM; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_ASPECT_RATIO; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_ASPECT_RATIO_NUM; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_VARIANCE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_VARIANCE_NUM; | |||
// RefinedetPRIORBOX | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE; | |||
// PRelu | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PRELU_ATTR_CHANNEL_SHARED; | |||
// Psroi pooling | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PSROIPOOLING_ATTR_SPATIAL_SCALE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PSROIPOOLING_ATTR_OUTPUT_DIM; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PSROIPOOLING_ATTR_GROUP_SIZE; | |||
// Power | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POWER_ATTR_NAME_POWER; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POWER_ATTR_NAME_SCALE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POWER_ATTR_NAME_SHIFT; | |||
// Log | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LOG_ATTR_NAME_SCALE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LOG_ATTR_NAME_SHIFT; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LOG_ATTR_NAME_BASE; | |||
// Pack | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PACK_ATTR_NAME_NUM; | |||
// Dynamic stitch | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DYNAMIC_STITCH_ATTR_NAME_NUM; | |||
// Unpack | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string UNPACK_ATTR_NAME_NUM; | |||
// Gathernd | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string GATHERND_ATTR_NAME_TINDICES; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string GATHERND_ATTR_NAME_TPARAMS; | |||
// Argmax | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ARGMAX_ATTR_NAME_TOPK; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ARGMAX_ATTR_NAME_REDUCESIZE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ARGMAX_ATTR_NAME_REDUCESTRIDE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ARGMAX_ATTR_NAME_OUTMAX; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ARGMAX_ATTR_NAME_AXIS; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ARGMAX_ATTR_NAME_AXISTYPE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ARGMAX_ATTR_NAME_KEEPDIMS; | |||
// Upsample | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string UPSAMPLE_ATTR_NAME_SCALE_H; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string UPSAMPLE_ATTR_NAME_SCALE_W; | |||
// Relu | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_NEGATIVE_SLOPE; | |||
// FreeSpaceExtract | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FREESPACEEXTRACT_ATTR_NAME_ORG_HEIGHT; | |||
// split | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SPLIT_ATTR_NAME_SLICE_POINT; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SPLIT_ATTR_NAME_SIZE_SPLIT; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SPLIT_ATTR_NAME_NUM_SPLIT; | |||
// Tvm | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string TVM_ATTR_NAME_MAGIC; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string TVM_ATTR_NAME_BLOCKDIM; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string TVM_ATTR_NAME_METADATA; | |||
// Squeeze | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SQUEEZE_ATTR_AXIS; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SQUEEZE_ATTR_DIMS; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SQUEEZE_OP_NAME; | |||
// Stride slice | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string STRIDE_SLICE_ATTR_BEGIN_MASK; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string STRIDE_SLICE_ATTR_END_MASK; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string STRIDE_SLICE_ATTR_ELLIPSIS_MASK; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string STRIDE_SLICE_ATTR_NEW_AXIS_MASK; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string STRIDE_SLICE_ATTR_SHRINK_AXIS_MASK; | |||
// Slice | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SLICE_ATTR_NAME_BEGINS; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SLICE_ATTR_NAME_SIZES; | |||
// Roialign | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ROIALIGN_ATTR_SPATIAL_SCALE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ROIALIGN_ATTR_SAMPLING_RATIO; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ROIALIGN_ATTR_NAME_POOLED_H; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ROIALIGN_ATTR_NAME_POOLED_W; | |||
// Generate_rpn_proposal | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string GENERATE_RPN_PROPOSAL_ATTR_PRE_NMS_TOPK; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string GENERATE_RPN_PROPOSAL_ATTR_POST_NMS_TOPK; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string GENERATE_RPN_PROPOSAL_ATTR_RPN_MINI_SIZE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string | |||
GENERATE_RPN_PROPOSAL_ATTR_RPN_PROPOSAL_NMS_THRESH; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string | |||
GENERATE_RPN_PROPOSAL_ATTR_RPN_PROPOSAL_FILTER_THRESH; | |||
// Decode_bbox | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DECODE_BBOX_ATTR_DECODECLIP; | |||
// Cast | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CAST_ATTR_DSTT; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CAST_ATTR_SRCT; | |||
// Fastrcnnn predications | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FASTRCNN_PREDICTIONS_ATTR_TOPK; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FASTRCNN_PREDICTIONS_ATTR_SCORE_THRESHOLD; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FASTRCNN_PREDICTIONS_ATTR_NMS_THRESHOLD; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FASTRCNN_PREDICTIONS_ATTR_NUM_CLASSES; | |||
// REORG | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string REORG_ATTR_STRIDE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string REORG_ATTR_REVERSE; | |||
// MERGE | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MERGE_DEAD_INDEX; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MERGE_PRENODE_FLAG; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string TO_BE_OUTPUT; | |||
static const std::string NOT_NET_OUTPUT = "not_net_output"; | |||
// Concatv2 | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONCAT_V2_ATTR_TIDX; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONCAT_V2_ATTR_N; | |||
// SUM | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SUM_ATTR_TIDX; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SUM_ATTR_AXIS; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SUM_ATTR_KEEP_DIMS; | |||
// ResizeBilinear | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESIZE_BILINEAR_ATTR_MODE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESIZE_BILINEAR_ATTR_ALIGN_CORNERS; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESIZE_BILINEAR_ATTR_HEIGHT; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESIZE_BILINEAR_ATTR_WIDTH; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESIZE_BILINEAR_ATTR_ZOOM_FACTOR; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESIZE_BILINEAR_ATTR_SHRINK_FACTOR; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESIZE_BILINEAR_ATTR_PAD_BEGIN; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESIZE_BILINEAR_ATTR_PAD_END; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESIZE_BILINEAR_ATTR_ALPHA; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESIZE_BILINEAR_ATTR_BETA; | |||
// RetinaNet | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RETINANET_FILTER_BACKGROUND_TRUE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RETINANET_ANCHOR_FUSION; | |||
// MatMul | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MATMUL_TRANSPOSE_X; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MATMUL_TRANSPOSE_W; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MATMUL_HAS_BIAS; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MATMUL_ATTR_IS_TRAINING; | |||
// Flatten | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FLATTEN_START_AXIS; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FLATTEN_END_AXIS; | |||
// Reshape | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESHAPE_ATTR_AXIS; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESHAPE_ATTR_NUM_AXES; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESHAPE_ATTR_FORMAT; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESHAPE_ATTR_SHAPE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESHAPE_ATTR_ALPHA; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESHAPE_ATTR_BETA; | |||
// Frameoworkop | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string T_IN_DATATYPE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string T_OUT_DATATYPE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_OUT_N; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_OUT_C; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_OUT_H; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_OUT_W; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_PAD_DEPTH_CONV; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_PAD_CONV; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_BEFORE_PAD; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ANN_MEAN_KEEPDIMS; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PAD_ATTR_PADDINGDS; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PAD_ATTR_CONSTANT_VALUE; | |||
// ConvGradFilter | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_GRAD_FILTER_OUTPUT_SHAPE; | |||
// ConvGradInput | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_GRAD_INPUT_OUTPUT_SHAPE; | |||
// Rnn | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RNN_MODE_STATIC; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MUTI_RNN; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CELL_MODE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CNN_RNN; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LSTM_CELL; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string GRU_CELL; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RNN_HT; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RNN_XT_HT; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RNN_BATCH_SIZE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LSTM_CELL_CLIP; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LSTM_PROJ_CLIP; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LSTM_ACTIVATE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LSTM_OUT_MAP; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LSTM_OUT_MODE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LSTM_STATE_OUT_MODE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LSTM_TIME_MAJOR; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LSTM_IS_INPUT_PRE_PROCESS; | |||
// Upsample | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string UPSAMPLE_ATTR_NAME_SCALE; | |||
// PadV2 | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PADV2_ATTR_NAME_MODE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PADV2_ATTR_NAME_PADS; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PADV2_ATTR_NAME_T; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PADV2_ATTR_NAME_PAD_FORMAT; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PADV2_ATTR_NAME_CONST_VALUE; | |||
// MirrorPad | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MIRRORPAD_ATTR_NAME_MODE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MIRRORPAD_ATTR_NAME_PADS; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MIRRORPAD_ATTR_NAME_PAD_FORMAT; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MIRRORPAD_ATTR_NAME_CONST_VALUE; | |||
// Filler | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FILLER_TYPE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FILLER_VALUE; | |||
// Shufflechannel | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SHUFFLE_CHANNEL_GROUP; | |||
// TopKV2 | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string TOPKV2_ATTR_K; | |||
// Calibaration | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string STRIDE_H_INDEX; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string STRIDE_W_INDEX; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PAD_TOP_INDEX; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PAD_BOTTOM_INDEX; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PAD_RIGHT_INDEX; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PAD_LEFT_INDEX; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_IS_CONST; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_GROUP; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_DILATION_SIZE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_EPSILON; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_POOLING_MODE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_CLASS_NUM; | |||
// model | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_TARGET_TYPE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_STREAM_NUM; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_EVENT_NUM; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_LABEL_NUM; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_MEMORY_SIZE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_WEIGHT_SIZE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_TASK_GEN_BASE_ADDR; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_TASK_GEN_WEIGHT_ADDR; | |||
// Public Attribute | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_IMPLY_TYPE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_BYTE_SIZE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FUSION_INFERENCE_ID; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FUSION_OPDEF; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FUSION_SCOPE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_OPATTR; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_RELUFLAG; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_SEQLEN_INDEX; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_X_INDEX; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_CONT_INDEX; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_XSTATIC_INDEX; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string TARGET_TYPE_MINI; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string TARGET_TYPE_TINY; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string TARGET_TYPE_LITE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_STREAM_LABEL; | |||
// L2_normalize | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string L2_NORMALIZE_ATTR_AXIS; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string L2_NORMALIZE_ATTR_EPS; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOL_PARAMA_ATTR_WINDOW; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOL_PARAMA_ATTR_CEIL_MODE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOL_PARAMA_ATTR_DATA_MODE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOL_PARAMA_ATTR_GLOBAL_POOLING; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOL_PARAMA_ATTR_NAN_OP; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOL_PARAMA_ATTR_PAD_MOD; | |||
// HCOM | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_ROOT_RANK; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_REDUCE_TYPE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_RANK_SIZE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_REDUCTION; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_GROUP; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_SR_TAG; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_SRC_RANK; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_DEST_RANK; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_FUSION; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_SHAPE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_DATA_TYPE; | |||
// Log time stamp | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LOG_TIME_STAMP_LOGID; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LOG_TIME_STAMP_NOTIFY; | |||
// SpaceToDepth/DepthToSpace | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_BLOCK_SIZE; | |||
// SparseSoftmaxCrossEntropyWithLogits | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SPARSE_SOFT_MAX_ATTR_TLABLES; | |||
// MaxPoolGradWithArgmax | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MAX_POOL_GRAD_OUTPUT_SHAPE; | |||
// AvgPoolGrad | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string AVG_POOL_GRAD_OUTPUT_SHAPE; | |||
// Pad | |||
extern const std::string ATTR_PAD_FORMAT; | |||
// Varible | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_FORMAT; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_NAME; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_FRACTALZ_FORMAT; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_4D_FORMAT; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_5D_FORMAT; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_DATA_TYPE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_VAR_IN_NAME; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_VAR_IN_INDEX; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_VAR_OUT_INDEX; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_SHAPE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HALF_VAR_NAME_END; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_CONTAINER; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_SHARED_NAME; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_DTYPE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_ADDR_OFFSET; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_VAR_IN_INDEX_KEY; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_VAR_OUT_INDEX_KEY; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_SRC_VAR_NAME; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_VAR_IS_SAVE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_VAR_IS_RESTORE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_VAR_IS_BROADCAST; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string REF_VAR_SRC_VAR_NAME; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string REF_VAR_PRE_PEER_OUT_INDEX; | |||
// Assign | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ASSIGN_VALIDATE_SHAPE; | |||
// ShapeN | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SHAPEN_ATTR_N; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SHAPEN_ATTR_IN_TYPE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SHAPEN_ATTR_OUT_TYPE; | |||
// Space2bacth batch2space | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCH_SPACE_ATTR_BLOCK; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCH_SPACE_ATTR_PADDING; | |||
// Depth_to_space space_to_depth | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE; | |||
// FakeQuantWithMinMaxVars | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FakeQuantWithMinMaxVars_ATTR_MAX; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FakeQuantWithMinMaxVars_ATTR_MIN; | |||
// Mobilenet_ssd_conv_fusion | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_BOXPREDICTOR_BOXES_FUSION; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_BOXPREDICTOR_SCORES_FUSION; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_BOXPREDICTOR_FUSION_BOX_TYPE_NUM; | |||
// Lsh project | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LSH_PROJ_TYPE; | |||
// Control flow | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_ITERATORS_PER_LOOP; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_TRUE_BRANCH_STREAM; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FLOW_CTRL_NODE_FLAG; | |||
// GatherV2 attr def | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string GATHERV2_ATTR_NAME_TAXIS; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string GATHERV2_ATTR_NAME_TINDICES; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string GATHERV2_ATTR_NAME_TPARAMS; | |||
// Reshape attr def | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESHAPE_ATTR_NAME_INPUT_DESC; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESHAPE_ATTR_NAME_OUTPUT_DESC; | |||
// Axis attr def | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_AXIS_ORG_OP; | |||
// The node link with SparseSoftmaxCrossEntropyWithLogits | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_LINK_WITH_SPARE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_NET_OUTPUT_FORMAT; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_NET_OUTPUT_DATATYPE; | |||
// For constant folding | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NO_NEED_CONSTANT_FOLDING; | |||
} // namespace domi | |||
#endif // INC_FRAMEWORK_COMMON_OP_ATTR_DEFINE_H_ |
@@ -21,11 +21,17 @@ | |||
#include <unordered_map> | |||
#include <string> | |||
#include "common/op/attr_define.h" | |||
#include "common/types.h" | |||
#include "graph/debug/ge_attr_define.h" | |||
#include "proto/om.pb.h" | |||
namespace domi { | |||
using domi::AttrDef; | |||
using domi::AttrDef_ListValue; | |||
using domi::ModelDef; | |||
using domi::NamedAttrs; | |||
using domi::OpDef; | |||
namespace ge { | |||
using AttrDefMap = ::google::protobuf::Map<::std::string, ::domi::AttrDef>; | |||
using AttrDefPair = ::google::protobuf::MapPair<std::string, domi::AttrDef>; | |||
@@ -150,6 +156,6 @@ bool GetAttrDefListValue(const std::string &key, int idx, int32_t *value, const | |||
bool GetAttrDefListValue(const std::string &key, int idx, uint32_t *value, const AttrDefMap &attr); | |||
bool GetAttrDefListValue(const std::string &key, int idx, float *value, const AttrDefMap &attr); | |||
bool GetAttrDefListValue(const std::string &key, int idx, double *value, const AttrDefMap &attr); | |||
} // namespace domi | |||
} // namespace ge | |||
#endif // INC_FRAMEWORK_COMMON_OP_ATTR_VALUE_UTIL_H_ |
@@ -62,6 +62,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t FOR_LIMIT | |||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t FOR_DELTA_INPUT; | |||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t FOR_DATA_INPUT; | |||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const int NORMAL_TENSOR_SIZE; | |||
class OpUtils { | |||
public: | |||
/// | |||
@@ -22,7 +22,7 @@ | |||
#include <math.h> | |||
#include <stdint.h> | |||
namespace domi { | |||
namespace ge { | |||
// general | |||
const float DEFAULT_ALPHA_VALUE = 1.0; | |||
const float DEFAULT_BETA_VALUE = 0.0; | |||
@@ -421,5 +421,5 @@ const uint32_t MULTI_SHAPE_INPUT_NUM = 2; | |||
// Shufflechannel | |||
const uint32_t SHUFFLECHANNEL_DEFAULT_GROUP = 1; | |||
} // namespace domi | |||
} // namespace ge | |||
#endif // INC_FRAMEWORK_COMMON_OP_OP_PARSER_UTIL_H_ |
@@ -20,7 +20,7 @@ | |||
#include <set> | |||
#include <string> | |||
namespace domi { | |||
namespace ge { | |||
class OpTypeContainer { | |||
public: | |||
static OpTypeContainer *Instance() { | |||
@@ -57,6 +57,6 @@ class OpTypeRegistrar { | |||
const OpTypeRegistrar g_##var_name##_reg(str_name); | |||
#define IS_OPTYPE_EXISTING(str_name) (OpTypeContainer::Instance()->IsExisting(str_name)) | |||
} // namespace domi | |||
} // namespace ge | |||
#endif // INC_FRAMEWORK_COMMON_OP_TYPES_H_ |
@@ -25,10 +25,10 @@ | |||
/// MAKE_GUARD([&] { Release Resource 1 }) | |||
/// Acquire Resource 2 | |||
// MAKE_GUARD([&] { Release Resource 2 }) | |||
#define GE_MAKE_GUARD(var, callback) domi::ScopeGuard make_guard_##var(callback) | |||
#define GE_MAKE_GUARD(var, callback) ScopeGuard make_guard_##var(callback) | |||
#define GE_DISMISS_GUARD(var) make_guard_##var.Dismiss() | |||
namespace domi { | |||
namespace ge { | |||
class ScopeGuard { | |||
public: | |||
// Noncopyable | |||
@@ -55,6 +55,6 @@ class ScopeGuard { | |||
std::function<void()> on_exit_scope_; | |||
bool dismissed_; | |||
}; | |||
} // namespace domi | |||
} // namespace ge | |||
#endif // INC_FRAMEWORK_COMMON_SCOPE_GUARD_H_ |
@@ -25,7 +25,7 @@ | |||
#include <string> | |||
#include <vector> | |||
namespace domi { | |||
namespace ge { | |||
class StringUtils { | |||
public: | |||
static std::string &Ltrim(std::string &s) { | |||
@@ -151,6 +151,6 @@ class StringUtils { | |||
return ret > 0 ? buffer : ""; | |||
} | |||
}; | |||
} // namespace domi | |||
} // namespace ge | |||
#endif // INC_FRAMEWORK_COMMON_STRING_UTIL_H_ |
@@ -26,6 +26,7 @@ | |||
#include <string> | |||
#include <utility> | |||
#include <vector> | |||
#include "framework/common/fmk_error_codes.h" | |||
#include "framework/common/fmk_types.h" | |||
#include "framework/common/op_types.h" | |||
@@ -46,9 +47,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_A | |||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_STATUS; | |||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_LAYER; | |||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_FILE_PATH; | |||
} // namespace ge | |||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_MODE; | |||
namespace domi { | |||
// Supported public properties name | |||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROP_OME_START_TIME; // Start time | |||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROP_OME_DUMP_PATH; // Dump path | |||
@@ -68,14 +68,6 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROFIL | |||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::map<std::string, std::string> PROFILE_COMPONENT_MAP; | |||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROFILE_CONFIG; | |||
/// @brief Data structure definition related to task sinking | |||
/// Build model | |||
enum BuildMode { | |||
GEN_TASK_WITHOUT_L2FUSION = 3, // Carrying task data (L2 convergence function disabled) | |||
GEN_TASK_WITHOUT_FUSION = 4, // Carrying task data (all convergence functions disabled) | |||
GEN_TASK_WITH_FUSION = 5 // Carrying task data (with UB/L1/L2 enabled for all convergence functions) | |||
}; | |||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string MODEL_ATTR_TASKS; | |||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string MODEL_ATTR_TASK_GEN_BASE_ADDR; | |||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string MODEL_ATTR_TASK_GEN_WEIGHT_ADDR; | |||
@@ -341,8 +333,9 @@ REGISTER_OPTYPE_DECLARE(END, "End"); | |||
REGISTER_OPTYPE_DECLARE(BASICLSTMCELL, "BasicLSTMCell"); | |||
REGISTER_OPTYPE_DECLARE(GETNEXT, "GetNext"); | |||
REGISTER_OPTYPE_DECLARE(INITDATA, "InitData"); | |||
REGISTER_OPTYPE_DECLARE(TRANSSHAPE, "TransShape") | |||
/***************ANN dedicated operator *************************/ | |||
// ANN dedicated operator | |||
REGISTER_OPTYPE_DECLARE(ANN_MEAN, "AnnMean"); | |||
REGISTER_OPTYPE_DECLARE(ANN_CONVOLUTION, "AnnConvolution"); | |||
REGISTER_OPTYPE_DECLARE(ANN_DEPCONVOLUTION, "AnnDepthConv"); | |||
@@ -359,7 +352,7 @@ REGISTER_OPTYPE_DECLARE(ANN_QUANTIZE, "AnnQuant"); | |||
REGISTER_OPTYPE_DECLARE(ANN_PAD, "AnnPad"); | |||
REGISTER_OPTYPE_DECLARE(ANN_RESIZE_BILINEAR, "AnnResizeBilinear"); | |||
/********************Training operator ***********************/ | |||
// Training operator | |||
REGISTER_OPTYPE_DECLARE(GATHERV2, "GatherV2"); | |||
REGISTER_OPTYPE_DECLARE(CONVGRADFILTER, "Conv2DBackpropFilter"); | |||
REGISTER_OPTYPE_DECLARE(CONV2D, "Conv2D"); | |||
@@ -438,11 +431,13 @@ REGISTER_OPTYPE_DECLARE(HCOMRECEIVE, "HcomReceive"); | |||
REGISTER_OPTYPE_DECLARE(VARASSIGN, "VarAssign"); | |||
REGISTER_OPTYPE_DECLARE(VARISINITIALIZEDOP, "VarIsInitializedOp"); | |||
REGISTER_OPTYPE_DECLARE(LogTimeStamp, "LogTimeStamp"); | |||
REGISTER_OPTYPE_DECLARE(PARALLELCONCATSTART, "_ParallelConcatStart"); | |||
REGISTER_OPTYPE_DECLARE(CONSTANTOP, "Constant"); | |||
REGISTER_OPTYPE_DECLARE(STREAMSWITCH, "StreamSwitch"); | |||
REGISTER_OPTYPE_DECLARE(STREAMSWITCHN, "StreamSwitchN"); | |||
REGISTER_OPTYPE_DECLARE(STREAMACTIVE, "StreamActive"); | |||
REGISTER_OPTYPE_DECLARE(MEMCPYASYNC, "MemcpyAsync"); | |||
REGISTER_OPTYPE_DECLARE(MEMCPYADDRASYNC, "MemcpyAddrAsync"); | |||
REGISTER_OPTYPE_DECLARE(STREAMMERGE, "StreamMerge"); | |||
REGISTER_OPTYPE_DECLARE(ENDGRAPH, "EndGraph"); | |||
REGISTER_OPTYPE_DECLARE(SEND, "Send"); | |||
@@ -450,6 +445,7 @@ REGISTER_OPTYPE_DECLARE(RECV, "Recv"); | |||
REGISTER_OPTYPE_DECLARE(LABELSET, "LabelSet"); | |||
REGISTER_OPTYPE_DECLARE(LABELGOTO, "LabelGoto"); | |||
REGISTER_OPTYPE_DECLARE(LABELGOTOEX, "LabelGotoEx"); | |||
REGISTER_OPTYPE_DECLARE(LABELSWITCH, "LabelSwitch"); | |||
REGISTER_OPTYPE_DECLARE(LABELSWITCHBYINDEX, "LabelSwitchByIndex"); | |||
@@ -828,9 +824,6 @@ static constexpr int32_t PARTITION_TYPE_TASK_INFO = 2; | |||
// number of partitions in the current model | |||
static constexpr uint32_t PARTITION_SIZE = 4; | |||
#define SIZE_OF_MODEL_PARTITION_TABLE(table) \ | |||
(sizeof(domi::ModelPartitionTable) + sizeof(domi::ModelPartitionMemInfo) * (table).num) | |||
enum ModelPartitionType { MODEL_DEF = 0, WEIGHTS_DATA, TASK_INFO, TBE_KERNELS }; | |||
struct ModelPartitionMemInfo { | |||
@@ -844,6 +837,8 @@ struct ModelPartitionTable { | |||
ModelPartitionMemInfo partition[0]; | |||
}; | |||
#define SIZE_OF_MODEL_PARTITION_TABLE(table) (sizeof(ModelPartitionTable) + sizeof(ModelPartitionMemInfo) * (table).num) | |||
static constexpr int32_t PTHREAD_CREAT_SUCCESS = 0; // pthread_creat success | |||
// Filter format | |||
@@ -975,8 +970,8 @@ typedef enum tagDomiNanPropagation { | |||
// mode of cropandresize | |||
typedef enum tagDomiCropAndResizeMode { | |||
DOMI_RESIZE_METHOD_BILINEAR = 0, /**< resize bilinear */ | |||
DOMI_RESIZE_METHOD_NEAREST, /**< resize nearest */ | |||
DOMI_RESIZE_METHOD_BILINEAR = 0, // resize bilinear | |||
DOMI_RESIZE_METHOD_NEAREST, // resize nearest | |||
DOMI_RESIZE_RESERVED | |||
} domiCropAndResizeMode_t; | |||
@@ -1063,6 +1058,15 @@ struct BasicInfo { | |||
uint32_t total_size; // total memory size | |||
}; | |||
#pragma pack() // Cancels single-byte alignment | |||
} // namespace ge | |||
namespace domi { | |||
/// @brief Data structure definition related to task sinking | |||
enum BuildMode { | |||
GEN_TASK_WITHOUT_L2FUSION = 3, // Carrying task data (L2 convergence function disabled) | |||
GEN_TASK_WITHOUT_FUSION = 4, // Carrying task data (all convergence functions disabled) | |||
GEN_TASK_WITH_FUSION = 5 // Carrying task data (with UB/L1/L2 enabled for all convergence functions) | |||
}; | |||
} // namespace domi | |||
#endif // INC_FRAMEWORK_COMMON_TYPES_H_ |
@@ -30,12 +30,12 @@ | |||
#include "framework/common/ge_inner_error_codes.h" | |||
#include "mmpa/mmpa_api.h" | |||
#define GE_CHECK_POSITIVE_SIZE_RANGE(size) \ | |||
do { \ | |||
if (size <= 0) { \ | |||
DOMI_LOGE(param[#size] is not a positive number); \ | |||
return PARAM_INVALID; \ | |||
} \ | |||
#define GE_CHECK_POSITIVE_SIZE_RANGE(size) \ | |||
do { \ | |||
if (size <= 0) { \ | |||
DOMI_LOGE("param[%s] is not a positive number", #size); \ | |||
return PARAM_INVALID; \ | |||
} \ | |||
} while (0) | |||
#define CHECK_FALSE_EXEC(expr, exec_expr, ...) \ | |||
@@ -44,7 +44,7 @@ | |||
if (!b) { \ | |||
exec_expr; \ | |||
} \ | |||
}; | |||
} | |||
// new ge marco | |||
// Encapsulate common resource releases | |||
@@ -113,101 +113,101 @@ | |||
} while (0) | |||
// Check if the parameter is null. If yes, return PARAM_INVALID and record the error | |||
#define GE_CHECK_NOTNULL(val) \ | |||
do { \ | |||
if (val == nullptr) { \ | |||
DOMI_LOGE(param[#val] must not be null.); \ | |||
return ge::PARAM_INVALID; \ | |||
} \ | |||
#define GE_CHECK_NOTNULL(val) \ | |||
do { \ | |||
if (val == nullptr) { \ | |||
DOMI_LOGE("param[%s] must not be null.", #val); \ | |||
return ge::PARAM_INVALID; \ | |||
} \ | |||
} while (0) | |||
// Check if the parameter is null. If yes, just return and record the error | |||
#define GE_CHECK_NOTNULL_JUST_RETURN(val) \ | |||
do { \ | |||
if (val == nullptr) { \ | |||
DOMI_LOGE(param[#val] must not be null.); \ | |||
return; \ | |||
} \ | |||
#define GE_CHECK_NOTNULL_JUST_RETURN(val) \ | |||
do { \ | |||
if (val == nullptr) { \ | |||
DOMI_LOGE("param[%s] must not be null.", #val); \ | |||
return; \ | |||
} \ | |||
} while (0) | |||
// Check whether the parameter is null. If so, execute the exec_expr expression and record the error log | |||
#define GE_CHECK_NOTNULL_EXEC(val, exec_expr) \ | |||
do { \ | |||
if (val == nullptr) { \ | |||
DOMI_LOGE(param[#val] must not be null.); \ | |||
exec_expr; \ | |||
} \ | |||
#define GE_CHECK_NOTNULL_EXEC(val, exec_expr) \ | |||
do { \ | |||
if (val == nullptr) { \ | |||
DOMI_LOGE("param[%s] must not be null.", #val); \ | |||
exec_expr; \ | |||
} \ | |||
} while (0) | |||
// Check whether the parameter is null. If yes, return directly and record the error log | |||
#define GE_RT_VOID_CHECK_NOTNULL(val) \ | |||
do { \ | |||
if (val == nullptr) { \ | |||
DOMI_LOGE(param[#val] must not be null.); \ | |||
return; \ | |||
} \ | |||
#define GE_RT_VOID_CHECK_NOTNULL(val) \ | |||
do { \ | |||
if (val == nullptr) { \ | |||
DOMI_LOGE("param[%s] must not be null.", #val); \ | |||
return; \ | |||
} \ | |||
} while (0) | |||
// Check if the parameter is null. If yes, return false and record the error log | |||
#define GE_RT_FALSE_CHECK_NOTNULL(val) \ | |||
do { \ | |||
if (val == nullptr) { \ | |||
DOMI_LOGE(param[#val] must not be null.); \ | |||
return false; \ | |||
} \ | |||
#define GE_RT_FALSE_CHECK_NOTNULL(val) \ | |||
do { \ | |||
if (val == nullptr) { \ | |||
DOMI_LOGE("param[%s] must not be null.", #val); \ | |||
return false; \ | |||
} \ | |||
} while (0) | |||
// Check if the parameter is out of bounds | |||
#define GE_CHECK_SIZE(size) \ | |||
do { \ | |||
if (size == 0) { \ | |||
DOMI_LOGE(param[#size] is out of range); \ | |||
return ge::PARAM_INVALID; \ | |||
} \ | |||
#define GE_CHECK_SIZE(size) \ | |||
do { \ | |||
if (size == 0) { \ | |||
DOMI_LOGE("param[%s] is out of range", #size); \ | |||
return ge::PARAM_INVALID; \ | |||
} \ | |||
} while (0) | |||
// Check if the container is empty | |||
#define GE_CHECK_VECTOR_NOT_EMPTY(vector) \ | |||
do { \ | |||
if (vector.empty()) { \ | |||
DOMI_LOGE(param[#vector] is empty !); \ | |||
return ge::FAILED; \ | |||
} \ | |||
#define GE_CHECK_VECTOR_NOT_EMPTY(vector) \ | |||
do { \ | |||
if (vector.empty()) { \ | |||
DOMI_LOGE("param[%s] is empty!", #vector); \ | |||
return ge::FAILED; \ | |||
} \ | |||
} while (0) | |||
// Check if the value on the left is greater than or equal to the value on the right | |||
#define GE_CHECK_GE(lhs, rhs) \ | |||
do { \ | |||
if (lhs < rhs) { \ | |||
DOMI_LOGE(param[#lhs] is less than[#rhs]); \ | |||
return ge::PARAM_INVALID; \ | |||
} \ | |||
#define GE_CHECK_GE(lhs, rhs) \ | |||
do { \ | |||
if (lhs < rhs) { \ | |||
DOMI_LOGE("param[%s] is less than[%s]", #lhs, #rhs); \ | |||
return ge::PARAM_INVALID; \ | |||
} \ | |||
} while (0) | |||
// Check if the value on the left is less than or equal to the value on the right | |||
#define GE_CHECK_LE(lhs, rhs) \ | |||
do { \ | |||
if (lhs > rhs) { \ | |||
DOMI_LOGE(param[#lhs] is greater than[#rhs]); \ | |||
return ge::PARAM_INVALID; \ | |||
} \ | |||
#define GE_CHECK_LE(lhs, rhs) \ | |||
do { \ | |||
if (lhs > rhs) { \ | |||
DOMI_LOGE("param[%s] is greater than[%s]", #lhs, #rhs); \ | |||
return ge::PARAM_INVALID; \ | |||
} \ | |||
} while (0) | |||
#define GE_DELETE_NEW_SINGLE(var) \ | |||
{ \ | |||
do { \ | |||
if (var != nullptr) { \ | |||
delete var; \ | |||
var = nullptr; \ | |||
} \ | |||
}; | |||
} while (0) | |||
#define GE_DELETE_NEW_ARRAY(var) \ | |||
{ \ | |||
do { \ | |||
if (var != nullptr) { \ | |||
delete[] var; \ | |||
var = nullptr; \ | |||
} \ | |||
}; | |||
} while (0) | |||
/** | |||
* @ingroup domi_common | |||
@@ -220,7 +220,7 @@ static constexpr int32_t OM_PROTO_VERSION = 2; | |||
*/ | |||
#define CEIL(N, n) (((N) + (n)-1) / (n)) | |||
namespace domi { | |||
namespace ge { | |||
using google::protobuf::Message; | |||
/// | |||
@@ -373,7 +373,7 @@ std::string RealPath(const char *path); | |||
/// @param [in] file_path path of input file | |||
/// @param [out] result | |||
/// | |||
bool CheckInputPathValid(const std::string &file_path); | |||
bool CheckInputPathValid(const std::string &file_path, const std::string &atc_param = ""); | |||
/// | |||
/// @ingroup domi_common | |||
@@ -381,7 +381,7 @@ bool CheckInputPathValid(const std::string &file_path); | |||
/// @param [in] file_path path of output file | |||
/// @param [out] result | |||
/// | |||
bool CheckOutputPathValid(const std::string &file_path); | |||
bool CheckOutputPathValid(const std::string &file_path, const std::string &atc_param = ""); | |||
/// | |||
/// @ingroup domi_common | |||
@@ -390,6 +390,6 @@ bool CheckOutputPathValid(const std::string &file_path); | |||
/// @param [out] result | |||
/// | |||
bool ValidateStr(const std::string &filePath, const std::string &mode); | |||
} // namespace domi | |||
} // namespace ge | |||
#endif // INC_FRAMEWORK_COMMON_UTIL_H_ |
@@ -47,6 +47,8 @@ class GeGenerator { | |||
Status GenerateOnlineModel(const Graph &graph, const vector<GeTensor> &inputs, ge::ModelBufferData &model); | |||
Status GenerateInfershapeGraph(const Graph &graph); | |||
/// | |||
/// @ingroup ge | |||
/// @brief: Build single OP in Model. | |||
@@ -33,7 +33,7 @@ class MemoryAssigner { | |||
MemoryAssigner &operator=(const MemoryAssigner &) = delete; | |||
Status AssignMemory(bool is_loop_graph, size_t &mem_offset); | |||
Status AssignMemory(bool is_loop_graph, size_t &mem_offset, size_t &zero_copy_mem_size); | |||
private: | |||
ge::ComputeGraphPtr compute_graph_; | |||
@@ -28,21 +28,27 @@ | |||
#include "framework/common/types.h" | |||
#include "register/register_fmk_types.h" | |||
using domi::DOMI_TENSOR_ND; | |||
using domi::DOMI_TENSOR_RESERVED; | |||
using domi::domiTensorFormat_t; | |||
using domi::FMK_TYPE_RESERVED; | |||
using domi::FrameworkType; | |||
using std::map; | |||
using std::string; | |||
using std::unordered_map; | |||
using std::vector; | |||
namespace domi { | |||
namespace ge { | |||
/** | |||
* @ingroup domi_omg | |||
* @brief run model | |||
*/ | |||
enum RunMode { | |||
GEN_OM_MODEL = 0, // generate offline model file | |||
MODEL_TO_JSON = 1, // convert to JSON file | |||
ONLY_PRE_CHECK = 3, // only for pre-check | |||
PBTXT_TO_JSON = 5 // pbtxt to json | |||
GEN_OM_MODEL = 0, // generate offline model file | |||
MODEL_TO_JSON = 1, // convert to JSON file | |||
MODEL_TO_JSON_WITH_SHAPE = 2, // convert to json file with shape | |||
ONLY_PRE_CHECK = 3, // only for pre-check | |||
PBTXT_TO_JSON = 5 // pbtxt to json | |||
}; | |||
/// | |||
@@ -93,7 +99,7 @@ struct OmgContext { | |||
std::string ddk_version; | |||
// preferential format used by the entire network | |||
domiTensorFormat_t net_format = DOMI_TENSOR_RESERVED; | |||
FrameworkType type = FMK_TYPE_RESERVED; | |||
domi::FrameworkType type = domi::FMK_TYPE_RESERVED; | |||
RunMode run_mode = ONLY_PRE_CHECK; | |||
bool train_flag = false; | |||
// whether to use FP16 high precision | |||
@@ -102,23 +108,25 @@ struct OmgContext { | |||
std::string output_type; | |||
// Save the name of the entire network: Some special operators are used to determine a network. Some operators in the | |||
// network require special processing based on the specific network. | |||
// e.g:faster-rcnn, the FirstStageProcessor module is determined as the Faster-R-CNN network based on the scope | |||
// fusion. Then, the conv+reshape operators in the FirstStageBoxPredictor/BoxEncodingPredictor scope are combined. The | |||
// convolution kernel rearrangement reshape operator needs to be deleted for the convolution kernel. | |||
// network require special processing based on the specific network. e.g:faster-rcnn, the FirstStageProcessor module | |||
// is determined as the Faster-R-CNN network based on the scope fusion. Then, the conv+reshape operators in the | |||
// FirstStageBoxPredictor/BoxEncodingPredictor scope are combined. The convolution kernel rearrangement reshape | |||
// operator needs to be deleted for the convolution kernel. | |||
std::string net_name; | |||
// Whether to use dynamic batch size or dynamic image size | |||
bool is_dynamic_input = false; | |||
std::string dynamic_batch_size; | |||
std::string dynamic_image_size; | |||
}; | |||
} // namespace ge | |||
namespace domi { | |||
/** | |||
* @ingroup domi_omg | |||
* @brief get OMG context | |||
* @return OmgContext context | |||
*/ | |||
OmgContext &GetContext(); | |||
ge::OmgContext &GetContext(); | |||
struct TEBinInfo { | |||
// It is obsolete. It will be automatically obtained from the binfilename field of the JSON file later. | |||
@@ -26,7 +26,7 @@ | |||
#include "common/string_util.h" | |||
#include "framework/common/debug/ge_log.h" | |||
namespace domi { | |||
namespace ge { | |||
class PlatformVersionManager { | |||
public: | |||
PlatformVersionManager() = delete; | |||
@@ -40,6 +40,6 @@ class PlatformVersionManager { | |||
return SUCCESS; | |||
} | |||
}; // class PlatformManager | |||
} // namespace domi | |||
} // namespace ge | |||
#endif // INC_FRAMEWORK_OMG_VERSION_H_ |
@@ -86,16 +86,16 @@ class _GeSerializable { | |||
} | |||
template <class T, class... Args> | |||
static void SaveItem(GeAttrValue::NamedAttrs &namedAttrs, string itemName, T &item, Args &... args) { | |||
static void SaveItem(GeAttrValue::NAMED_ATTRS &namedAttrs, string itemName, T &item, Args &... args) { | |||
GeAttrValue itemVal = SaveItemAsAttrValue(item); | |||
(void)namedAttrs.SetAttr(itemName, itemVal); | |||
SaveItem(namedAttrs, args...); | |||
} | |||
static void SaveItem(GeAttrValue::NamedAttrs &namedAttrs __attribute__((__unused__))) {} | |||
static void SaveItem(GeAttrValue::NAMED_ATTRS &namedAttrs __attribute__((__unused__))) {} | |||
template <class T, class... Args> | |||
static graphStatus LoadItem(GeAttrValue::NamedAttrs &namedAttrs, string itemName, T &item, Args &... args) { | |||
static graphStatus LoadItem(GeAttrValue::NAMED_ATTRS &namedAttrs, string itemName, T &item, Args &... args) { | |||
auto itemVal = namedAttrs.GetItem(itemName); | |||
auto status = LoadItemFromAttrValue(item, itemVal); | |||
if (status != GRAPH_SUCCESS) { | |||
@@ -104,7 +104,9 @@ class _GeSerializable { | |||
return LoadItem(namedAttrs, args...); | |||
} | |||
static graphStatus LoadItem(GeAttrValue::NamedAttrs &namedAttrs __attribute__((__unused__))) { return GRAPH_SUCCESS; } | |||
static graphStatus LoadItem(GeAttrValue::NAMED_ATTRS &namedAttrs __attribute__((__unused__))) { | |||
return GRAPH_SUCCESS; | |||
} | |||
}; | |||
#define _GE_FI(a) #a, a | |||
@@ -171,13 +173,13 @@ class _GeSerializable { | |||
\ | |||
private: \ | |||
ge::graphStatus Save(GeAttrValue &ar) const { \ | |||
GeAttrValue::NamedAttrs named_attrs; \ | |||
GeAttrValue::NAMED_ATTRS named_attrs; \ | |||
_GeSerializable::SaveItem(named_attrs, _GE_INVOKE_VAR_MACRO(__VA_ARGS__)); \ | |||
return ar.SetValue<GeAttrValue::NamedAttrs>(named_attrs); \ | |||
return ar.SetValue<GeAttrValue::NAMED_ATTRS>(named_attrs); \ | |||
} \ | |||
ge::graphStatus Load(const GeAttrValue &ar) { \ | |||
GeAttrValue::NamedAttrs named_attrs; \ | |||
ge::graphStatus status = ar.GetValue<GeAttrValue::NamedAttrs>(named_attrs); \ | |||
GeAttrValue::NAMED_ATTRS named_attrs; \ | |||
ge::graphStatus status = ar.GetValue<GeAttrValue::NAMED_ATTRS>(named_attrs); \ | |||
if (status != GRAPH_SUCCESS) { \ | |||
return status; \ | |||
} \ | |||
@@ -83,6 +83,7 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||
// AddNode with NodePtr | |||
NodePtr AddNode(NodePtr node); | |||
NodePtr AddNode(OpDescPtr op); | |||
NodePtr AddNode(OpDescPtr op, int64_t id); // for unserialize. | |||
NodePtr AddNodeFront(NodePtr node); | |||
NodePtr AddNodeFront(const OpDescPtr &op); | |||
NodePtr AddInputNode(NodePtr node); | |||
@@ -236,8 +237,9 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||
std::deque<NodePtr> &stack); | |||
graphStatus CollectBreadthOutNode(const NodePtr &node, std::map<NodePtr, uint32_t> &map_in_edge_num, | |||
std::map<string, NodePtr> &breadth_node_map); | |||
graphStatus TopologicalSortingSubgraph(); | |||
graphStatus TopologicalSortingGraph(); | |||
graphStatus SortNodes(std::vector<NodePtr> &stack, std::map<NodePtr, uint32_t> &mapInEdgeNum); | |||
Vistor<NodePtr> AllGraphNodes(std::vector<std::shared_ptr<ComputeGraph>> &subgraphs) const; | |||
size_t GetInEdgeSize(const NodePtr &node); | |||
size_t GetOutEdgeSize(const NodePtr &node); | |||
graphStatus RemoveExtraOutEdge(const NodePtr &node); | |||
@@ -32,6 +32,12 @@ namespace ge { | |||
#define GE_FUNC_DEV_VISIBILITY | |||
#endif | |||
// Public attribute | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_UNKNOWN_SHAPE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_UNKNOWN_SHAPE_TYPE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NAME; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TYPE; | |||
@@ -58,6 +64,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BIAS_TERM; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_HAS_BIAS_VALUE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PAD; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PADS; | |||
@@ -74,8 +82,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CEIL_MODE; | |||
// GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string | |||
// ATTR_NAME_WEIGHTS; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STRIDE_SIZE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RELUMODE; | |||
@@ -123,6 +130,13 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NAN_OPT; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AIPP; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string NEW_AIPP_CONV_OP; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SESSION_GRAPH_ID; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MULTISHAPE_BATCHLIST; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MULTISHAPE_BATCHLIST_SIZE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_BATCH_NUM; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_FORMAT; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_FORMAT; | |||
@@ -140,12 +154,24 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PRED_PERMUTE_DELETED; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IGNORE_PRED_FORMAT; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHTS; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHTS_DATA; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BROACAST_REAL_DIM_CNT; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DIM_ALIGN; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RTSWITCH_RECV_EVENT_ID; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_START; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHTS_DATA; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SESSION_GRAPH_ID; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_BATCH_NUM; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_START; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_LABEL; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG; | |||
// to be deleted | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_TO_BE_DELETED; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_RESHAPE_FUSION; | |||
@@ -158,15 +184,15 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_MBOX | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_MBOX_FUSION_BOX_TYPE_NUM; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_RESHAPE_SLICE_CONCAT_FUSION; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIORBOX_CONCAT; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_LOC_FUSION; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_CONF_FUSION; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_FUSION_BOX_TYPE_NUM; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIORBOX_CONCAT; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string NEED_INFER; | |||
// _Arg | |||
@@ -255,7 +281,29 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNOR | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_ESTIMATED_VARIANCE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_SCALE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_BIAS; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_DATA_FORMAT; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_IS_TRAINING; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_IS_TRAINING_FUSION; | |||
// Huberloss | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HUBER_LOSS_ATTR_DELTA; | |||
// SSDRealDivTileMul | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_REAL_DIV_TILE_MUL_ATTR_TILE_PARA; | |||
// SSDSumMulRealDivMean | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_REDUCTION_INDICES; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_AXIS; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_MEAN_PARA; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_HAS_SUM; | |||
/// ConcatFive2Four | |||
/// ConcatFour2Five | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_CLASS_NUM; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_FEATURE_MAP_SIZE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TRANS_FOR_LOSS_MODE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOX_TYPE_NUM; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_FEATURE_MAP_HIGH; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_FEATURE_MAP_WIDTH; | |||
// Scale | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_ATTR_SCALE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_ATTR_BIAS; | |||
@@ -292,7 +340,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_AT | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_ATTR_NAME_OUTPUT_FORMAT; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_ATTR_NAME_OUTPUT_TYPE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string L2_NORMALIZE_ATTR_EPS; | |||
// Roipooling | |||
@@ -305,6 +352,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIPOOLI | |||
// DetectionOutput | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_NUM_CLASSES; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_OCR_NUM_CLASSES; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_NMS_THRESHOLD; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_TOP_K; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_CONFIDENCE_THRESHOLD; | |||
@@ -363,6 +411,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SOFTMAX_ | |||
// Permute | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_ATTR_ORDER; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_ATTR_PERM; | |||
// SSD Normalize | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSDNORMALIZE_ATTR_ACCROSS_SPATIAL; | |||
@@ -403,9 +452,15 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_AT | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_ATTR_NAME_SCALE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_ATTR_NAME_SHIFT; | |||
// Log | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_ATTR_NAME_SCALE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_ATTR_NAME_SHIFT; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_ATTR_NAME_BASE; | |||
// Pack | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PACK_ATTR_NAME_NUM; | |||
// Dynamic stitch | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DYNAMIC_STITCH_ATTR_NAME_NUM; | |||
// Unpack | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UNPACK_ATTR_NAME_NUM; | |||
// Gathernd | |||
@@ -414,8 +469,16 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERND | |||
// Argmax | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_TOPK; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_REDUCESIZE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_REDUCESTRIDE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_OUTMAX; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_AXIS; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_AXISTYPE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_KEEPDIMS; | |||
// Upsample | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UPSAMPLE_ATTR_NAME_SCALE_H; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UPSAMPLE_ATTR_NAME_SCALE_W; | |||
// Relu | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NEGATIVE_SLOPE; | |||
@@ -486,6 +549,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REORG_AT | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MERGE_DEAD_INDEX; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MERGE_PRENODE_FLAG; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TO_BE_OUTPUT; | |||
static const std::string NOT_NET_OUTPUT = "not_net_output"; | |||
// ENTER | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ENTER_ATTR_FRAME_NAME; | |||
@@ -511,6 +575,9 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_B | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_ALPHA; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_BETA; | |||
// RetinaNet | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RETINANET_FILTER_BACKGROUND_TRUE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RETINANET_ANCHOR_FUSION; | |||
// MatMul | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MATMUL_TRANSPOSE_X; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MATMUL_TRANSPOSE_W; | |||
@@ -559,10 +626,30 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GRU_CELL | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_HT; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_XT_HT; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_BATCH_SIZE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_CELL_CLIP; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_PROJ_CLIP; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_ACTIVATE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_OUT_MAP; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_OUT_MODE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_STATE_OUT_MODE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_TIME_MAJOR; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_IS_INPUT_PRE_PROCESS; | |||
// Upsample | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UPSAMPLE_ATTR_NAME_SCALE; | |||
// PadV2 | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_MODE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_PADS; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_T; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_PAD_FORMAT; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_CONST_VALUE; | |||
// MirrorPad | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_MODE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_PADS; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_PAD_FORMAT; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_CONST_VALUE; | |||
// Filler | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FILLER_TYPE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FILLER_VALUE; | |||
@@ -583,36 +670,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PAD_LEFT | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_ALGO_ATTR; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_TYPE_ATTR; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_MODE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_VALUE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_OFFSET; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_VALUE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_OFFSET; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_VALUE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_OFFSET; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_VALUE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_OFFSET; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_MODE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_VALUE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_OFFSET; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_TYPE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_OFFSET; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_VALUE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_OFFSET; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_VALUE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_OFFSET; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_MODE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_VALUE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_OFFSET; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_VALUE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_OFFSET; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_VALUE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_OFFSET; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_VALUE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_OFFSET; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_CONST; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_GROUP; | |||
@@ -627,24 +684,20 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MOD | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_EVENT_NUM; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_HUGE_STREAM_LIST; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_LABEL_NUM; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_MEMORY_SIZE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_ZERO_COPY_MEMORY_SIZE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_WEIGHT_SIZE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_BASE_ADDR; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_WEIGHT_ADDR; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_VAR_ADDR; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_VAR_SIZE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_INDEX_OP_NAME; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_CORE_TYPE; | |||
// Public attribute | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IMPLY_TYPE; | |||
@@ -678,6 +731,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TARGET_T | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CONTINUOUS_INPUT; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CONTINUOUS_INPUT_ALLOC; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CONTINUOUS_OUTPUT; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_REFERENCE; | |||
@@ -696,6 +751,161 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ATOMIC_INDEX; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_VAR_ADDR; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_LABEL; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CONTINUOUS_STREAM_LABEL; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_VAR_SIZE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_INDEX_OP_NAME; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_CORE_TYPE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_MODE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_VALUE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_OFFSET; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_VALUE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_OFFSET; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_VALUE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_OFFSET; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_VALUE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_OFFSET; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_MODE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_VALUE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_OFFSET; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_TYPE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_OFFSET; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_VALUE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_OFFSET; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_VALUE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_OFFSET; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_MODE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_VALUE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_OFFSET; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_VALUE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_OFFSET; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_VALUE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_OFFSET; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_VALUE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_OFFSET; | |||
// L2_normalize | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string L2_NORMALIZE_ATTR_AXIS; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string L2_NORMALIZE_ATTR_EPS; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_WINDOW; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_CEIL_MODE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_DATA_MODE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_GLOBAL_POOLING; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_NAN_OP; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_PAD_MOD; | |||
// HCOM | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_ROOT_RANK; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_REDUCE_TYPE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_RANK_SIZE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_REDUCTION; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_GROUP; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SR_TAG; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SRC_RANK; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_DEST_RANK; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_FUSION; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SHAPE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_DATA_TYPE; | |||
// Log time stamp | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_TIME_STAMP_LOGID; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_TIME_STAMP_NOTIFY; | |||
// SpaceToDepth/DepthToSpace | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BLOCK_SIZE; | |||
// SparseSoftmaxCrossEntropyWithLogits | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPARSE_SOFT_MAX_ATTR_TLABLES; | |||
// MaxPoolGradWithArgmax | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MAX_POOL_GRAD_OUTPUT_SHAPE; | |||
// AvgPoolGrad | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string AVG_POOL_GRAD_OUTPUT_SHAPE; | |||
// Varible | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_FORMAT; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_NAME; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_FRACTALZ_FORMAT; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_4D_FORMAT; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_5D_FORMAT; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_DATA_TYPE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IN_NAME; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IN_INDEX; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_OUT_INDEX; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SHAPE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HALF_VAR_NAME_END; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_CONTAINER; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SHARED_NAME; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_DTYPE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_ADDR_OFFSET; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IN_INDEX_KEY; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_OUT_INDEX_KEY; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SRC_VAR_NAME; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_SAVE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_RESTORE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_BROADCAST; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_SRC_VAR_NAME; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_PRE_PEER_OUT_INDEX; | |||
// Assign | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ASSIGN_VALIDATE_SHAPE; | |||
// ShapeN | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHAPEN_ATTR_N; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHAPEN_ATTR_IN_TYPE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHAPEN_ATTR_OUT_TYPE; | |||
// Space2bacth batch2space | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCH_SPACE_ATTR_BLOCK; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCH_SPACE_ATTR_PADDING; | |||
// Depth_to_space space_to_depth | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE; | |||
// FakeQuantWithMinMaxVars | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FakeQuantWithMinMaxVars_ATTR_MAX; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FakeQuantWithMinMaxVars_ATTR_MIN; | |||
// Mobilenet_ssd_conv_fusion | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOXPREDICTOR_BOXES_FUSION; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOXPREDICTOR_SCORES_FUSION; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOXPREDICTOR_FUSION_BOX_TYPE_NUM; | |||
// Lsh project | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSH_PROJ_TYPE; | |||
// Control flow | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ITERATORS_PER_LOOP; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TRUE_BRANCH_STREAM; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FLOW_CTRL_NODE_FLAG; | |||
// GatherV2 attr def | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERV2_ATTR_NAME_TAXIS; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERV2_ATTR_NAME_TINDICES; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERV2_ATTR_NAME_TPARAMS; | |||
// Reshape attr def | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_NAME_INPUT_DESC; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_NAME_OUTPUT_DESC; | |||
// Axis attr def | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AXIS_ORG_OP; | |||
// The node link with SparseSoftmaxCrossEntropyWithLogits | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LINK_WITH_SPARE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NET_OUTPUT_FORMAT; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NET_OUTPUT_DATATYPE; | |||
// For constant folding | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NO_NEED_CONSTANT_FOLDING; | |||
// Used for mark the active label list to find stream of activated node | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ACTIVE_LABEL_LIST; | |||
@@ -708,7 +918,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||
// Control flow | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_SWITCH_COND; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TRUE_BRANCH_STREAM; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ACTIVE_STREAM_LIST; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCHN_PRED_VALUE; | |||
@@ -722,6 +931,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||
// Function Op | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PARENT_NODE_INDEX; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PARENT_CONST_TYPE; | |||
// Used for mark the active node is for loop, type:bool | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_LOOP_ACTIVE; | |||
@@ -752,6 +962,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NEE | |||
// For mutil-batch | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INSERT_BY_MBATCH; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MBATCH_ORIGIN_INPUT_DIMS; | |||
// For inserted op | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INSERTED_BY_GE; | |||
@@ -772,6 +983,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||
// used for l1 fusion and other fusion in future | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_GROUP_ID; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_GROUP_KEY; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_GROUP_KEY; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_VIRTUAL_OP; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_GROUP_TYPE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_MEM_TYPE_LIST; | |||
@@ -782,10 +994,44 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_FOR_L1_FUSION; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_N_BATCH_SPILT; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NO_TASK_AND_DUMP_NEEDED; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_BUFFER_FUSION; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L2_FUSION_GROUP_ID; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_FOR_L2_FUSION; | |||
// functional ops attr | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WHILE_COND; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WHILE_BODY; | |||
// used for label switch | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LABEL_SWITCH_INDEX; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LABEL_SWITCH_LIST; | |||
// Varible | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_SRC_VAR_NAME; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SRC_VAR_NAME; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_PRE_PEER_OUT_INDEX; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_BROADCAST; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_RESTORE; | |||
// HCOM | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_ROOT_RANK; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_REDUCE_TYPE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_RANK_SIZE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SHAPE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_DATA_TYPE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_DATATYPE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_DATATYPE; | |||
// used for LX tiling | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_L1_SPACE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_TYPE_LIST; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_VALID_INPUT_SHAPE_LIST_LIST; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_VALID_OUTPUT_SHAPE_LIST_LIST; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SLICE_INPUT_OFFSET_LIST_LIST; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SLICE_OUTPUT_OFFSET_LIST_LIST; | |||
// Dynamic stitch | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DYNAMIC_STITCH_ATTR_NAME_NUM; | |||
} // namespace ge | |||
#endif // INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_ |
@@ -22,7 +22,7 @@ | |||
#include <string> | |||
#include <vector> | |||
#include "graph/anchor.h" | |||
#include "detail/attributes_holder.h" | |||
#include "graph/detail/attributes_holder.h" | |||
#include "graph/ge_tensor.h" | |||
#include "graph/graph.h" | |||
#include "graph/node.h" | |||
@@ -77,6 +77,8 @@ class ModelSerializeImp { | |||
void SetProtobufOwner(const ProtoMsgOwner &bufferProtobufOnwer) { protobuf_owner_ = bufferProtobufOnwer; } | |||
private: | |||
bool RebuildOwnership(ComputeGraphPtr &compute_graph, std::map<std::string, ComputeGraphPtr> &subgraphs); | |||
std::vector<NodeNameGraphReq> graph_input_node_names_; | |||
std::vector<NodeNameGraphReq> graph_output_node_names_; | |||
std::vector<NodeNameNodeReq> node_input_node_names_; | |||
@@ -43,30 +43,31 @@ using ComputeGraphPtr = std::shared_ptr<ComputeGraph>; | |||
using ConstComputeGraphPtr = std::shared_ptr<const ComputeGraph>; | |||
class GeTensorDesc; | |||
class GeAttrValue; | |||
class GeAttrValueImp; | |||
class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeAttrValue { | |||
class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NamedAttrs : public AttrHolder { | |||
public: | |||
class NamedAttrs : public AttrHolder { | |||
public: | |||
NamedAttrs(); | |||
virtual ~NamedAttrs() = default; | |||
void SetName(const std::string &name); | |||
string GetName() const; | |||
GeAttrValue GetItem(const string &key) const; | |||
protected: | |||
ProtoAttrMapHelper MutableAttrMap() override; | |||
ConstProtoAttrMapHelper GetAttrMap() const override; | |||
private: | |||
// Create namedAttrs from protobuf obj | |||
NamedAttrs(const ProtoMsgOwner &owner, proto::NamedAttrs *protoMsg); | |||
GeIrProtoHelper<proto::NamedAttrs> named_attrs_; | |||
friend class GeAttrValueImp; | |||
}; | |||
NamedAttrs(); | |||
virtual ~NamedAttrs() = default; | |||
void SetName(const std::string &name); | |||
string GetName() const; | |||
GeAttrValue GetItem(const string &key) const; | |||
protected: | |||
ProtoAttrMapHelper MutableAttrMap() override; | |||
ConstProtoAttrMapHelper GetAttrMap() const override; | |||
private: | |||
// Create namedAttrs from protobuf obj | |||
NamedAttrs(const ProtoMsgOwner &owner, proto::NamedAttrs *protoMsg); | |||
GeIrProtoHelper<proto::NamedAttrs> named_attrs_; | |||
friend class GeAttrValueImp; | |||
friend class GeAttrValue; | |||
}; | |||
class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeAttrValue { | |||
public: | |||
using INT = int64_t; | |||
using FLOAT = float; | |||
using BOOL = bool; | |||
@@ -75,7 +76,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeAttrValue { | |||
using TENSOR_DESC = GeTensorDesc; | |||
using GRAPH = ComputeGraphPtr; | |||
using BYTES = Buffer; | |||
using NAMED_ATTRS = NamedAttrs; | |||
using NAMED_ATTRS = ge::NamedAttrs; | |||
using DATA_TYPE = ge::DataType; | |||
using LIST_INT = vector<INT>; | |||
@@ -90,6 +91,8 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeAttrValue { | |||
using LIST_LIST_INT = vector<vector<int64_t>>; | |||
using LIST_DATA_TYPE = vector<ge::DataType>; | |||
using NamedAttrs = ge::NamedAttrs; // for cce use (ge::GeAttrValue::NamedAttrs). | |||
enum ValueType { | |||
VT_NONE = 0, | |||
VT_STRING, | |||
@@ -87,6 +87,12 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc : public AttrH | |||
GeShape &MutableShape(); | |||
void SetShape(GeShape shape); | |||
// set shape with -2, it stand for unknown shape | |||
void SetUnknownDimNumShape(); | |||
// for unknown shape | |||
graphStatus SetShapeRange(const std::vector<std::pair<int64_t, int64_t>> &range); | |||
graphStatus GetShapeRange(std::vector<std::pair<int64_t, int64_t>> &range) const; | |||
GeShape GetOriginShape() const; | |||
void SetOriginShape(const GeShape &originShape); | |||
@@ -25,11 +25,7 @@ | |||
#include "graph/ge_attr_value.h" | |||
#include "graph/graph.h" | |||
namespace domi { | |||
class ModelHelper; | |||
} | |||
namespace ge { | |||
using domi::ModelHelper; | |||
using std::map; | |||
using std::string; | |||
using std::vector; | |||
@@ -50,6 +50,8 @@ class GeAttrValue; | |||
using ConstOpDesc = const OpDesc; | |||
enum SubgraphType { kStatic, kDynamic, kSubgraphTypeEnd }; | |||
class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||
public: | |||
template <class T> | |||
@@ -83,6 +85,8 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||
graphStatus AddInputDescForward(const string &name, const unsigned int num); | |||
graphStatus AddInputDescMiddle(const string &name, const unsigned int num, size_t index); | |||
graphStatus AddOutputDescForward(const string &name, const unsigned int num); | |||
graphStatus AddOptionalInputDesc(const string &name, const GeTensorDesc &input_desc); | |||
@@ -141,6 +145,8 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||
graphStatus AddDynamicInputDesc(const string &name, const unsigned int num, bool isPushBack = true); | |||
graphStatus AddDynamicInputDescByIndex(const string &name, const unsigned int num, size_t index); | |||
graphStatus AddDynamicOutputDesc(const string &name, const unsigned int num, bool isPushBack = true); | |||
bool IsOptionalInput(const string &name) const; | |||
@@ -214,6 +220,9 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||
void SetIsInputConst(const vector<bool> &is_input_const); | |||
vector<bool> GetIsInputConst() const; | |||
void SetOpInferDepends(const vector<string> &depend_names); | |||
vector<string> GetOpInferDepends() const; | |||
string GetInputNameByIndex(uint32_t index) const; | |||
int GetInputIndexByName(const string &name) const; | |||
@@ -236,12 +245,23 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||
std::string GetOpEngineName() const; | |||
void RegisterSubgraphIrName(const std::string &name, SubgraphType type); | |||
const std::map<std::string, SubgraphType> &GetSubgraphIrNames() const; | |||
SubgraphType GetSubgraphTypeByIrName(const std::string &name) const; | |||
graphStatus AddSubgraphName(const std::string &name); | |||
const std::map<std::string, uint32_t> &GetSubgraphNameIndexes() const; | |||
std::string GetSubgraphInstanceName(uint32_t index) const; | |||
const std::vector<std::string> &GetSubgraphInstanceNames() const; | |||
void AddSubgraphInstanceName(std::string name); | |||
/// Does not provide functions `AddSubgraphInstance` or `AppendSubgraphInstance`, | |||
/// because this kind of functions will only append a new subgraph instance name | |||
/// at the tail of `subgraph_instance_names_` and ignore the synchronous change of `subgraph_names_to_index_`. | |||
/// If we want to append a new subgraph instance name, the function `AddSubgraphName` should be called first. | |||
/// \param index | |||
/// \param name | |||
/// \return | |||
graphStatus SetSubgraphInstanceName(uint32_t index, const std::string &name); | |||
void RemoveSubgraphInstanceName(const std::string &name); | |||
protected: | |||
@@ -256,7 +276,23 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||
GeIrProtoHelper<ge::proto::OpDef> op_def_; | |||
std::vector<std::string> subgraph_instance_names_; | |||
// subgraph names to index, for a `if` operator: | |||
// then_branch: 0 | |||
// else_branch: 1 | |||
// or for a `case` node: | |||
// branches0: 0 | |||
// branches1: 1 | |||
// branches2: 2 | |||
std::map<std::string, uint32_t> subgraph_names_to_index_; | |||
// subgraph ir names to type, for a `if` operator: | |||
// then_branch: static | |||
// else_branch: dynamic | |||
// or for a `case` op: | |||
// branches: dynamic | |||
std::map<std::string, SubgraphType> subgraph_ir_names_to_type_; | |||
vector<GeTensorDescPtr> inputs_desc_{}; | |||
vector<GeTensorDescPtr> outputs_desc_{}; | |||
map<string, uint32_t> output_name_idx_{}; | |||
@@ -0,0 +1,79 @@ | |||
/** | |||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#ifndef COMMON_GRAPH_REF_RELATION_H_ | |||
#define COMMON_GRAPH_REF_RELATION_H_ | |||
#include <deque> | |||
#include <string> | |||
#include <unordered_map> | |||
#include <vector> | |||
#include "graph/compute_graph.h" | |||
#include "graph/types.h" | |||
#include "graph/ge_error_codes.h" | |||
#include "node.h" | |||
namespace ge { | |||
enum InOutFlag { | |||
NODE_IN = 0, // input flag | |||
NODE_OUT = 1, // output flag | |||
}; | |||
struct RefCell { | |||
std::string node_name; | |||
ge::NodePtr node = nullptr; | |||
InOutFlag in_out = NODE_IN; | |||
int in_out_idx = 0; | |||
bool operator==(const RefCell &c) const { | |||
return node_name == c.node_name && node == c.node && in_out == c.in_out && in_out_idx == c.in_out_idx; | |||
} | |||
RefCell() = default; | |||
RefCell(std::string name, ge::NodePtr node_ptr, InOutFlag in_out_flag, int idx) { | |||
node_name = name; | |||
node = node_ptr; | |||
in_out = in_out_flag; | |||
in_out_idx = idx; | |||
}; | |||
~RefCell() = default; | |||
}; | |||
struct RefCellHash { | |||
size_t operator()(const RefCell &c) const { | |||
unsigned long number = reinterpret_cast<unsigned long>(reinterpret_cast<uintptr_t>(c.node.get())); | |||
string tmp = c.node_name + std::to_string(c.in_out) + std::to_string(c.in_out_idx) + std::to_string(number); | |||
return std::hash<string>()(tmp); | |||
} | |||
}; | |||
class RefRelations { | |||
public: | |||
graphStatus LookUpRefRelations(const RefCell &key, std::unordered_set<RefCell, RefCellHash> &result); | |||
graphStatus BuildRefRelations(ge::ComputeGraph &root_graph); | |||
graphStatus Clear(); | |||
RefRelations(); | |||
~RefRelations() = default; | |||
public: | |||
class Impl; | |||
std::shared_ptr<Impl> impl_ = nullptr; | |||
}; | |||
} // namespace ge | |||
#endif // COMMON_GRAPH_REF_RELATION_H_ |
@@ -14,8 +14,8 @@ | |||
* limitations under the License. | |||
*/ | |||
#ifndef INC_EXTERNAL_GRAPH_USR_TYPES_H_ | |||
#define INC_EXTERNAL_GRAPH_USR_TYPES_H_ | |||
#ifndef INC_GRAPH_USR_TYPES_H_ | |||
#define INC_GRAPH_USR_TYPES_H_ | |||
#include <atomic> | |||
#include <memory> | |||
@@ -130,4 +130,4 @@ struct UsrQuantizeFactorParams { | |||
#undef USR_TYPE_BYTES_DEC | |||
} // namespace ge | |||
#endif // INC_EXTERNAL_GRAPH_USR_TYPES_H_ | |||
#endif // INC_GRAPH_USR_TYPES_H_ |
@@ -62,9 +62,9 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrUtils { | |||
static bool SetListGraph(AttrHolderAdapter &&obj, const string &name, const vector<ComputeGraphPtr> &value); | |||
static bool SetBytes(AttrHolderAdapter &&obj, const string &name, const GeAttrValue::BYTES &value); | |||
static bool SetListBytes(AttrHolderAdapter &&obj, const string &name, const vector<GeAttrValue::BYTES> &value); | |||
static bool SetNamedAttrs(AttrHolderAdapter &&obj, const string &name, const GeAttrValue::NamedAttrs &value); | |||
static bool SetNamedAttrs(AttrHolderAdapter &&obj, const string &name, const GeAttrValue::NAMED_ATTRS &value); | |||
static bool SetListNamedAttrs(AttrHolderAdapter &&obj, const string &name, | |||
const vector<GeAttrValue::NamedAttrs> &value); | |||
const vector<GeAttrValue::NAMED_ATTRS> &value); | |||
static bool SetListOpDesc(AttrHolderAdapter &&obj, const string &name, const vector<ConstOpDescPtr> &value); | |||
static bool SetListOpDesc(AttrHolderAdapter &&obj, const string &name, const vector<OpDescPtr> &value); | |||
@@ -91,9 +91,9 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrUtils { | |||
static bool GetListGraph(ConstAttrHolderAdapter &&obj, const string &name, vector<ComputeGraphPtr> &value); | |||
static bool GetBytes(ConstAttrHolderAdapter &&obj, const string &name, GeAttrValue::BYTES &value); | |||
static bool GetListBytes(ConstAttrHolderAdapter &&obj, const string &name, vector<GeAttrValue::BYTES> &value); | |||
static bool GetNamedAttrs(ConstAttrHolderAdapter &&obj, const string &name, GeAttrValue::NamedAttrs &value); | |||
static bool GetNamedAttrs(ConstAttrHolderAdapter &&obj, const string &name, GeAttrValue::NAMED_ATTRS &value); | |||
static bool GetListNamedAttrs(ConstAttrHolderAdapter &&obj, const string &name, | |||
vector<GeAttrValue::NamedAttrs> &value); | |||
vector<GeAttrValue::NAMED_ATTRS> &value); | |||
static bool GetListOpDesc(ConstAttrHolderAdapter &&obj, const string &name, vector<OpDescPtr> &value); | |||
// Value will be moved | |||
static bool SetZeroCopyBytes(AttrHolderAdapter &&obj, const string &name, Buffer &&buffer); | |||
@@ -95,12 +95,35 @@ | |||
}; | |||
namespace ge { | |||
enum IOType { kIn, kOut }; | |||
struct NodeIndexIO { | |||
NodeIndexIO(ge::NodePtr node, uint32_t index, IOType io_type) | |||
: node(std::move(node)), index(index), io_type(io_type) {} | |||
NodeIndexIO(ge::NodePtr node, int index, IOType io_type) | |||
: node(std::move(node)), index(static_cast<uint32_t>(index)), io_type(io_type) {} | |||
~NodeIndexIO() {} | |||
NodePtr node = nullptr; | |||
uint32_t index = 0; | |||
IOType io_type = kOut; | |||
std::string ToString() const { | |||
if ((node == nullptr) || (node->GetOwnerComputeGraph() == nullptr)) { | |||
return ""; | |||
} | |||
return node->GetName() + (io_type == kOut ? "_out_" : "_in_") + std::to_string(index); | |||
} | |||
}; | |||
class GraphUtils { | |||
public: | |||
static ComputeGraphPtr GetComputeGraph(const Graph &graph); | |||
static Graph CreateGraphFromComputeGraph(const ComputeGraphPtr compute_graph); | |||
static graphStatus RecoverGraphOperators(const Graph &graph); | |||
static ComputeGraphPtr CreateGraphFromOperator(const string &name, const std::vector<Operator> &inputs); | |||
static graphStatus AddEdge(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst); | |||
@@ -262,6 +285,108 @@ class GraphUtils { | |||
static graphStatus MoveOutCtrlEdges(NodePtr &src_node, NodePtr &dst_node); | |||
static ComputeGraphPtr FindRootGraph(ComputeGraphPtr graph); | |||
static graphStatus TopologicalSortingByName(const ge::ComputeGraphPtr &compute_graph, vector<NodePtr> &node_vec); | |||
/// | |||
/// Get reference-mapping of all data_anchors in graph | |||
/// @param [in] graph | |||
/// @param [out] symbol_to_anchors | |||
/// @param [out] anchor_to_symbol | |||
/// @return success: GRAPH_SUCESS | |||
/// | |||
static graphStatus GetRefMapping(const ComputeGraphPtr &graph, | |||
std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors, | |||
std::map<std::string, std::string> &anchor_to_symbol); | |||
private: | |||
/// | |||
/// Get reference-mapping for in_data_anchors of node | |||
/// @param [in] node | |||
/// @param [out] symbol_to_anchors | |||
/// @param [out] anchor_to_symbol | |||
/// @return success: GRAPH_SUCESS | |||
/// | |||
static graphStatus HandleInAnchorMapping(const NodePtr &node, | |||
std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors, | |||
std::map<std::string, std::string> &anchor_to_symbol); | |||
/// | |||
/// Get reference-mapping for out_data_anchors of node | |||
/// @param [in] node | |||
/// @param [out] symbol_to_anchors | |||
/// @param [out] anchor_to_symbol | |||
/// @return success: GRAPH_SUCESS | |||
/// | |||
static graphStatus HandleOutAnchorMapping(const NodePtr &node, | |||
std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors, | |||
std::map<std::string, std::string> &anchor_to_symbol); | |||
/// | |||
/// Handle input of subgraph | |||
/// @param [in] node | |||
/// @param [out] symbol_to_anchors | |||
/// @param [out] anchor_to_symbol | |||
/// @return success: GRAPH_SUCESS | |||
/// | |||
static graphStatus HandleSubgraphInput(const NodePtr &node, | |||
std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors, | |||
std::map<std::string, std::string> &anchor_to_symbol); | |||
/// | |||
/// Handle input of Merge op | |||
/// @param [in] node | |||
/// @param [out] symbol_to_anchors | |||
/// @param [out] anchor_to_symbol | |||
/// @return success: GRAPH_SUCESS | |||
/// | |||
static graphStatus HandleMergeInput(const NodePtr &node, | |||
std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors, | |||
std::map<std::string, std::string> &anchor_to_symbol); | |||
/// | |||
/// Handle output of subgraph | |||
/// @param [in] node | |||
/// @param [out] symbol_to_anchors | |||
/// @param [out] anchor_to_symbol | |||
/// @return success: GRAPH_SUCESS | |||
/// | |||
static graphStatus HandleSubgraphOutput(const NodePtr &node, | |||
std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors, | |||
std::map<std::string, std::string> &anchor_to_symbol); | |||
/// | |||
/// Union ref-mapping | |||
/// @param [in] exist_node_info1 | |||
/// @param [in] exist_node_info2 | |||
/// @param [out] symbol_to_anchors | |||
/// @param [out] anchor_to_symbol | |||
/// @param [out] symbol | |||
/// @return success: GRAPH_SUCESS | |||
/// | |||
static graphStatus UnionSymbolMapping(const NodeIndexIO &exist_node_info1, const NodeIndexIO &exist_node_info2, | |||
std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors, | |||
std::map<std::string, std::string> &anchor_to_symbol, std::string &symbol); | |||
/// | |||
/// Update symbol mapping with a new reference pair | |||
/// @param [in] cur_node_info | |||
/// @param [in] exist_node_info | |||
/// @param [out] symbol_to_anchors | |||
/// @param [out] anchor_to_symbol | |||
/// @return success: GRAPH_SUCESS | |||
/// | |||
static graphStatus UpdateRefMapping(const NodeIndexIO &cur_node_info, const NodeIndexIO &exist_node_info, | |||
std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors, | |||
std::map<std::string, std::string> &anchor_to_symbol); | |||
/// | |||
/// Check if out_data_anchor is reference of input | |||
/// @param [in] out_data_anchor | |||
/// @param [out] reuse_in_index | |||
/// @return bool | |||
/// | |||
static bool IsRefFromInput(const OutDataAnchorPtr &out_data_anchor, int32_t &reuse_in_index); | |||
}; | |||
class ComputeGraphBuilder { | |||
@@ -441,12 +566,12 @@ class CompleteGraphBuilder : public ComputeGraphBuilder { | |||
private: | |||
/// | |||
/// @brief Build inputs | |||
/// @brief Add data nodes | |||
/// @param [out] error_code | |||
/// @param [out] error_msg | |||
/// @return void | |||
/// | |||
void BuildInputs(graphStatus &error_code, std::string &error_msg); | |||
void AddDataNodes(graphStatus &error_code, std::string &error_msg); | |||
/// | |||
/// @brief Add data node | |||
@@ -455,41 +580,15 @@ class CompleteGraphBuilder : public ComputeGraphBuilder { | |||
/// @param [out] error_msg | |||
/// @return void | |||
/// | |||
NodePtr AddDateNode(uint32_t index, graphStatus &error_code, std::string &error_msg); | |||
NodePtr AddDataNode(uint32_t index, graphStatus &error_code, std::string &error_msg); | |||
/// | |||
/// @brief Build outputs | |||
/// @brief Add RetVal nodes | |||
/// @param [out] error_code | |||
/// @param [out] error_msg | |||
/// @return void | |||
/// | |||
void BuildOutputs(graphStatus &error_code, std::string &error_msg); | |||
/// | |||
/// @brief Add NetOutput node | |||
/// @param [out] error_code | |||
/// @param [out] error_msg | |||
/// @return NodePtr | |||
/// | |||
NodePtr AddNetOutputNode(graphStatus &error_code, std::string &error_msg); | |||
/// | |||
/// @brief Add input/output tensor for NetOutput node | |||
/// @param [in] out_nodes_info | |||
/// @param [out] net_output_desc | |||
/// @return graphStatus | |||
/// | |||
graphStatus BuildInOutForNetOutput(const std::vector<std::pair<NodePtr, int32_t>> &out_nodes_info, | |||
OpDescPtr &net_output_desc); | |||
/// | |||
/// @brief Add edge for NetOutput node | |||
/// @param [in] out_nodes_info | |||
/// @param [out] net_output_node | |||
/// @return graphStatus | |||
/// | |||
graphStatus AddEdgeForNetOutput(const std::vector<std::pair<NodePtr, int32_t>> &out_nodes_info, | |||
const NodePtr &net_output_node); | |||
void AddRetValNodes(graphStatus &error_code, std::string &error_msg); | |||
std::string name_; | |||
NodePtr parent_node_; | |||
@@ -55,11 +55,44 @@ class NodeUtils { | |||
static GeTensorDesc GetInputDesc(const Node &node, uint32_t index); | |||
static graphStatus UpdateOutputShape(const Node &node, uint32_t index, const GeShape &shape); | |||
static graphStatus UpdateInputShape(const Node &node, uint32_t index, const GeShape &shape); | |||
// check node whether unknown shape.If node shape contain -1 or -2,out param "is_unknow" will be true; | |||
// for func op, it will check subgraph yet, if some node shape of subgraph contain -1 or -2, | |||
// the out param "is_unknow" will be true too | |||
static graphStatus GetNodeUnknownShapeStatus(const Node &node, bool &is_unknow); | |||
static std::string GetNodeType(const Node &node); | |||
static ComputeGraphPtr GetSubgraph(const Node &node, uint32_t index); | |||
static graphStatus AddSubgraph(Node &node, const ComputeGraphPtr &subgraph); | |||
static graphStatus SetSubgraph(Node &node, uint32_t index, const ComputeGraphPtr &subgraph); | |||
/// | |||
/// Check if node is input of subgraph | |||
/// @param [in] node | |||
/// @return bool | |||
/// | |||
static bool IsSubgraphInput(const NodePtr &node); | |||
/// | |||
/// Check if node is output of subgraph | |||
/// @param [in] node | |||
/// @return bool | |||
/// | |||
static bool IsSubgraphOutput(const NodePtr &node); | |||
/// | |||
/// @brief Get subgraph original input node. | |||
/// @param [in] node | |||
/// @return Node | |||
/// | |||
static NodePtr GetParentInput(const NodePtr &node); | |||
/// | |||
/// @brief Get subgraph input is constant. | |||
/// @param [in] node | |||
/// @param [out] string | |||
/// @return bool | |||
/// | |||
static bool GetConstOpType(const NodePtr &in_node, std::string &op_type); | |||
private: | |||
static std::map<NodePtr, std::vector<uint32_t>> map_send_info_; | |||
@@ -81,6 +81,9 @@ class OpDescUtils { | |||
static OpDescPtr CreateConstOp(const GeTensorPtr& tensor_ptr); | |||
static graphStatus SetSubgraphInstanceName(const std::string& subgraph_name, | |||
const std::string& subgraph_instance_name, OpDescPtr& op_desc); | |||
private: | |||
static GeTensorPtr MutableWeights(ge::OpDesc& op_desc); | |||
static GeTensorPtr MutableWeights(ge::OpDescPtr op_desc); | |||
@@ -104,6 +107,14 @@ class OpDescBuilder { | |||
/// | |||
OpDescBuilder& AddInput(const std::string& name); | |||
/// | |||
/// @brief Add input | |||
/// @param [in] name | |||
/// @param [in] tensor | |||
/// @return OpDescBuilder | |||
/// | |||
OpDescBuilder& AddInput(const std::string& name, const GeTensorDesc& tensor); | |||
/// | |||
/// @brief Add dynamic input | |||
/// @param [in] name | |||
@@ -112,6 +123,15 @@ class OpDescBuilder { | |||
/// | |||
OpDescBuilder& AddDynamicInput(const std::string& name, uint32_t num); | |||
/// | |||
/// @brief Add dynamic input | |||
/// @param [in] name | |||
/// @param [in] num | |||
/// @param [in] tensor | |||
/// @return OpDescBuilder | |||
/// | |||
OpDescBuilder& AddDynamicInput(const std::string& name, uint32_t num, const GeTensorDesc& tensor); | |||
/// | |||
/// @brief Add output | |||
/// @param [in] name | |||
@@ -119,6 +139,14 @@ class OpDescBuilder { | |||
/// | |||
OpDescBuilder& AddOutput(const std::string& name); | |||
/// | |||
/// @brief Add output | |||
/// @param [in] name | |||
/// @param [in] tensor | |||
/// @return OpDescBuilder | |||
/// | |||
OpDescBuilder& AddOutput(const std::string& name, const GeTensorDesc& tensor); | |||
/// | |||
/// @brief Add dynamic output | |||
/// @param [in] name | |||
@@ -127,6 +155,15 @@ class OpDescBuilder { | |||
/// | |||
OpDescBuilder& AddDynamicOutput(const std::string& name, uint32_t num); | |||
/// | |||
/// @brief Add dynamic output | |||
/// @param [in] name | |||
/// @param [in] num | |||
/// @param [in] tensor | |||
/// @return OpDescBuilder | |||
/// | |||
OpDescBuilder& AddDynamicOutput(const std::string& name, uint32_t num, const GeTensorDesc& tensor); | |||
/// | |||
/// @brief Build op_desc | |||
/// @return OpDescPtr | |||
@@ -136,8 +173,8 @@ class OpDescBuilder { | |||
private: | |||
std::string name_; | |||
std::string type_; | |||
std::vector<std::string> inputs_; | |||
std::vector<std::string> outputs_; | |||
std::vector<std::pair<std::string, GeTensorDesc>> inputs_; | |||
std::vector<std::pair<std::string, GeTensorDesc>> outputs_; | |||
}; | |||
} // namespace ge | |||
@@ -34,13 +34,12 @@ ge_protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | |||
ge_protobuf_generate(ge PROTO_ONNX_SRCS PROTO_ONNX_HDRS ${ONNX_PROTO_LIST}) | |||
# need to remove dependencies on pb files later | |||
file(GLOB_RECURSE SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
"*.cc" | |||
"utils/*.cc" | |||
"opsproto/*.cc" | |||
"detail/*.cc" | |||
"debug/*.cc" | |||
"op_imp.cc" | |||
"option/*.cc" | |||
) | |||
@@ -53,7 +53,6 @@ void Anchor::UnlinkAll() noexcept { | |||
if (Unlink(peer_anchor_ptr) != GRAPH_SUCCESS) { | |||
GELOGW("unlink peer_anchor_ptr failed."); | |||
} | |||
} while (!peer_anchors_.empty()); | |||
} | |||
} | |||
@@ -42,8 +42,11 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::ComputeGraph(const | |||
: name_(name), nodes_(), input_nodes_(), sub_graph_(), is_valid_flag_(false), need_iteration_(false) { | |||
attrs_.InitDefault(); | |||
} | |||
ComputeGraph::~ComputeGraph() {} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY string ComputeGraph::GetName() const { return name_; } | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetName(const string &name) { name_ = name; } | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t ComputeGraph::GetAllNodesSize() const { | |||
@@ -53,24 +56,50 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t ComputeGraph::GetAllNodesS | |||
} | |||
return s; | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor<NodePtr> ComputeGraph::GetAllNodes() const { | |||
vector<NodePtr> all_nodes(nodes_.size()); | |||
(void)std::copy(nodes_.begin(), nodes_.end(), all_nodes.begin()); | |||
for (const auto &sub_graph : sub_graph_) { | |||
if (sub_graph == nullptr) { | |||
GELOGW("sub graph is nullptr"); | |||
if (sub_graph_.empty()) { | |||
return Vistor<NodePtr>(shared_from_this(), nodes_); | |||
} | |||
std::vector<std::shared_ptr<ComputeGraph>> subgraphs; | |||
return AllGraphNodes(subgraphs); | |||
} | |||
ComputeGraph::Vistor<NodePtr> ComputeGraph::AllGraphNodes(std::vector<std::shared_ptr<ComputeGraph>> &subgraphs) const { | |||
std::vector<NodePtr> all_nodes; | |||
std::deque<NodePtr> candidates; | |||
candidates.insert(candidates.begin(), nodes_.begin(), nodes_.end()); | |||
while (!candidates.empty()) { | |||
NodePtr node = candidates.front(); | |||
all_nodes.emplace_back(node); | |||
candidates.pop_front(); | |||
OpDescPtr op_desc = node->GetOpDesc(); | |||
if (op_desc == nullptr) { | |||
continue; | |||
} | |||
for (const auto &node : sub_graph->GetAllNodes()) { | |||
all_nodes.push_back(node); | |||
const auto &subgraph_names = op_desc->GetSubgraphInstanceNames(); | |||
for (auto name_iter = subgraph_names.rbegin(); name_iter != subgraph_names.rend(); ++name_iter) { | |||
auto subgraph = GetSubgraph(*name_iter); | |||
if (subgraph != nullptr) { | |||
subgraphs.emplace_back(subgraph); | |||
candidates.insert(candidates.begin(), subgraph->nodes_.begin(), subgraph->nodes_.end()); | |||
} | |||
} | |||
} | |||
return Vistor<NodePtr>(shared_from_this(), all_nodes); | |||
} | |||
size_t ComputeGraph::GetDirectNodesSize() const { return nodes_.size(); } | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor<NodePtr> ComputeGraph::GetDirectNode() const { | |||
return Vistor<NodePtr>(shared_from_this(), nodes_); | |||
} | |||
ComputeGraph::Vistor<NodePtr> ComputeGraph::GetInputNodes() const { | |||
return Vistor<NodePtr>(shared_from_this(), input_nodes_); | |||
} | |||
@@ -82,6 +111,7 @@ ComputeGraph::Vistor<NodePtr> ComputeGraph::GetOutputNodes() const { | |||
} | |||
return Vistor<NodePtr>(shared_from_this(), result); | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr ComputeGraph::FindNode(const std::string &name) const { | |||
for (const auto &node : nodes_) { | |||
if (node == nullptr) { | |||
@@ -203,10 +233,6 @@ NodePtr ComputeGraph::AddNodeFront(NodePtr node) { | |||
return nullptr; | |||
} | |||
node->GetOpDesc()->SetId(nodes_.size()); | |||
if (nodes_[0] == nullptr) { | |||
GELOGE(GRAPH_FAILED, "nodes_ size or nodes_[0] is nullptr"); | |||
return nullptr; | |||
} | |||
if (nodes_.size() > 0 && nodes_[0]->GetType() == DATA) { | |||
(void)nodes_.insert(nodes_.begin() + 1, node); | |||
} else { | |||
@@ -248,6 +274,20 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr ComputeGraph::AddNode(OpD | |||
GE_IF_BOOL_EXEC(node_ptr->Init() != GRAPH_SUCCESS, GELOGE(GRAPH_FAILED, "node init fail."); return nullptr); | |||
return AddNode(node_ptr); | |||
} | |||
NodePtr ComputeGraph::AddNode(OpDescPtr op, int64_t id) { // for unserialize. | |||
if (op == nullptr) { | |||
GELOGE(GRAPH_FAILED, "The OpDesc ptr should be not null."); | |||
return nullptr; | |||
} | |||
op->SetId(id); | |||
NodePtr node = shared_ptr<Node>(new (std::nothrow) Node(op, shared_from_this())); | |||
GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "node_ptr is NULL!!!"); return nullptr); | |||
GE_IF_BOOL_EXEC(node->Init() != GRAPH_SUCCESS, GELOGE(GRAPH_FAILED, "node init fail."); return nullptr); | |||
nodes_.push_back(node); | |||
return node; | |||
} | |||
NodePtr ComputeGraph::AddInputNode(NodePtr node) { | |||
if (node == nullptr) { | |||
GELOGE(GRAPH_FAILED, "The node ptr should be not null."); | |||
@@ -259,6 +299,7 @@ NodePtr ComputeGraph::AddInputNode(NodePtr node) { | |||
} | |||
return node; | |||
} | |||
NodePtr ComputeGraph::AddOutputNode(NodePtr node) { | |||
if (node == nullptr || node->GetOpDesc() == nullptr) { | |||
GELOGE(GRAPH_FAILED, "The node ptr or opdesc should be not null."); | |||
@@ -336,6 +377,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::RemoveN | |||
} | |||
return GRAPH_FAILED; | |||
} | |||
// Used in sub_graph scenes | |||
graphStatus ComputeGraph::RemoveInputNode(const NodePtr &node) { | |||
if (node == nullptr) { | |||
@@ -372,20 +414,24 @@ graphStatus ComputeGraph::RemoveOutputNode(const NodePtr &node) { | |||
GE_IF_BOOL_EXEC(find_node == false, return GRAPH_FAILED); | |||
return GRAPH_SUCCESS; | |||
} | |||
std::shared_ptr<ComputeGraph> ComputeGraph::AddSubGraph(std::shared_ptr<ComputeGraph> sub_graph) { | |||
if (sub_graph == nullptr) { | |||
GELOGE(GRAPH_FAILED, "The graph ptr should be not null."); | |||
return nullptr; | |||
} | |||
sub_graph_.push_back(sub_graph); | |||
names_to_subgraph_[sub_graph->GetName()] = sub_graph; | |||
return sub_graph; | |||
} | |||
graphStatus ComputeGraph::RemoveSubGraph(const std::shared_ptr<ComputeGraph> &sub_graph) { | |||
if (sub_graph == nullptr) { | |||
GELOGE(GRAPH_FAILED, "The graph ptr should be not null."); | |||
return GRAPH_FAILED; | |||
} | |||
names_to_subgraph_.erase(sub_graph->GetName()); | |||
auto iter = find(sub_graph_.begin(), sub_graph_.end(), sub_graph); | |||
if (iter != sub_graph_.end()) { | |||
(void)sub_graph_.erase(iter); | |||
@@ -462,8 +508,13 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::RemoveSubgraph | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::shared_ptr<ComputeGraph> ComputeGraph::GetSubgraph( | |||
const std::string &name) const { | |||
auto iter = names_to_subgraph_.find(name); | |||
return iter == names_to_subgraph_.end() ? nullptr : iter->second; | |||
std::shared_ptr<ComputeGraph> parent = parent_graph_.lock(); | |||
if (parent == nullptr) { | |||
auto iter = names_to_subgraph_.find(name); | |||
return iter == names_to_subgraph_.end() ? nullptr : iter->second; | |||
} else { | |||
return parent->GetSubgraph(name); | |||
} | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::vector<std::shared_ptr<ComputeGraph>> | |||
@@ -495,7 +546,11 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetParentNode( | |||
/// | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | |||
ComputeGraph::UpdateInputMapping(const std::map<uint32_t, uint32_t> &input_mapping) { | |||
for (auto &input : input_nodes_) { | |||
size_t update_num = 0; | |||
for (auto &input : nodes_) { | |||
if (update_num >= input_mapping.size()) { | |||
break; | |||
} | |||
uint32_t cur_index = 0; | |||
if (!ge::AttrUtils::GetInt(input->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, cur_index)) { | |||
continue; | |||
@@ -508,6 +563,7 @@ ComputeGraph::UpdateInputMapping(const std::map<uint32_t, uint32_t> &input_mappi | |||
GE_LOGE("UpdateInputMapping failed: set attr ATTR_NAME_PARENT_NODE_INDEX failed."); | |||
return GRAPH_FAILED; | |||
} | |||
update_num++; | |||
} | |||
return GRAPH_SUCCESS; | |||
@@ -520,9 +576,9 @@ ComputeGraph::UpdateInputMapping(const std::map<uint32_t, uint32_t> &input_mappi | |||
/// | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | |||
ComputeGraph::UpdateOutputMapping(const std::map<uint32_t, uint32_t> &output_mapping) { | |||
NodePtr net_output = FindNode(kNodeNameNetOutput); | |||
NodePtr net_output = FindNode(NODE_NAME_NET_OUTPUT); | |||
if (net_output == nullptr) { | |||
GE_LOGE("UpdateOutputMapping failed: node %s not exist in graph.", kNodeNameNetOutput); | |||
GE_LOGE("UpdateOutputMapping failed: node %s not exist in graph.", NODE_NAME_NET_OUTPUT); | |||
return GRAPH_FAILED; | |||
} | |||
OpDescPtr op_desc = net_output->GetOpDesc(); | |||
@@ -557,13 +613,13 @@ ComputeGraph::UpdateOutputMapping(const std::map<uint32_t, uint32_t> &output_map | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InsertEventNodes() { | |||
std::vector<NodePtr> node_vec = nodes_; | |||
for (const auto &node : GetAllNodes()) { | |||
for (const auto &node : GetDirectNode()) { | |||
if (node == nullptr || node->GetOpDesc() == nullptr) { | |||
GELOGW("node or OpDescPtr is nullptr."); | |||
continue; | |||
} | |||
GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "The node should be not null."); return GRAPH_FAILED); | |||
if (node->GetOpDesc()->GetType() == kRecvType) { | |||
if (node->GetOpDesc()->GetType() == RECV) { | |||
auto iter = find(node_vec.begin(), node_vec.end(), node); | |||
if (iter == node_vec.end()) { | |||
GELOGW("no node found."); | |||
@@ -574,7 +630,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InsertE | |||
auto dst_iter = find(node_vec.begin(), node_vec.end(), node->GetOutControlNodes().at(0)); | |||
(void)node_vec.insert(dst_iter, node); | |||
} | |||
if (node->GetOpDesc()->GetType() == kSendType) { | |||
if (node->GetOpDesc()->GetType() == SEND) { | |||
auto iter = find(node_vec.begin(), node_vec.end(), node); | |||
if (iter == node_vec.end()) { | |||
GELOGW("no node found."); | |||
@@ -602,7 +658,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InsertE | |||
graphStatus ComputeGraph::DFSTopologicalSorting(std::vector<NodePtr> &node_vec, | |||
std::map<NodePtr, uint32_t> &map_in_edge_num, | |||
std::vector<NodePtr> &stack) { | |||
GELOGI("Runing_Dfs_Sort"); | |||
GELOGI("Runing_Dfs_Sort: %s", name_.c_str()); | |||
// Record the number of non data nodes but no input nodes | |||
GE_CHK_BOOL_EXEC(SortNodes(stack, map_in_edge_num) == GRAPH_SUCCESS, return GRAPH_FAILED, "sort nodes failed"); | |||
@@ -647,7 +703,7 @@ graphStatus ComputeGraph::DFSTopologicalSorting(std::vector<NodePtr> &node_vec, | |||
graphStatus ComputeGraph::BFSTopologicalSorting(std::vector<NodePtr> &node_vec, | |||
std::map<NodePtr, uint32_t> &map_in_edge_num, | |||
std::deque<NodePtr> &stack) { | |||
GELOGI("Runing_Bfs_Sort"); | |||
GELOGI("Runing_Bfs_Sort: %s", name_.c_str()); | |||
std::vector<NodePtr> stack_input; | |||
std::map<string, NodePtr> breadth_node_map; | |||
// Record the number of non data nodes but no input nodes | |||
@@ -708,23 +764,36 @@ graphStatus ComputeGraph::CollectBreadthOutNode(const NodePtr &node, std::map<No | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::TopologicalSorting() { | |||
auto ret = TopologicalSortingSubgraph(); | |||
auto ret = TopologicalSortingGraph(); | |||
if (ret != SUCCESS) { | |||
GELOGE(ret, "Sub graph partition Failed"); | |||
return ret; | |||
} | |||
if (sub_graph_.empty()) { | |||
return SUCCESS; | |||
} | |||
// partition sub graph | |||
for (const auto &sub_graph : GetAllSubgraphs()) { | |||
ret = sub_graph->TopologicalSortingSubgraph(); | |||
for (const auto &sub_graph : sub_graph_) { | |||
ret = sub_graph->TopologicalSortingGraph(); | |||
if (ret != SUCCESS) { | |||
GELOGE(ret, "Sub graph topological sort Failed"); | |||
return ret; | |||
} | |||
} | |||
std::vector<std::shared_ptr<ComputeGraph>> subgraphs; | |||
(void)AllGraphNodes(subgraphs); | |||
if (sub_graph_.size() != subgraphs.size()) { // Graph Partition use subgraph, Keep original | |||
GELOGW("Keep original subgraph for graph size %zu not equal %zu.", sub_graph_.size(), subgraphs.size()); | |||
return SUCCESS; | |||
} | |||
sub_graph_.swap(subgraphs); | |||
return SUCCESS; | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::TopologicalSortingSubgraph() { | |||
graphStatus ComputeGraph::TopologicalSortingGraph() { | |||
std::vector<NodePtr> node_vec; | |||
std::map<NodePtr, uint32_t> map_in_edge_num; | |||
bool use_BFS = false; | |||
@@ -735,7 +804,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::Topolog | |||
use_BFS = true; | |||
} | |||
} else { | |||
GELOGW("Get OPTION_GRAPH_RUN_MODE failed, use BFSTopologicalSorting by default."); | |||
GELOGW("OPTION_GRAPH_RUN_MODE not set, use BFSTopologicalSorting by default."); | |||
} | |||
if (use_BFS) { | |||
@@ -793,8 +862,8 @@ graphStatus ComputeGraph::SortNodes(std::vector<NodePtr> &stack, std::map<NodePt | |||
GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue); | |||
map_in_edge_num[node] = static_cast<uint32_t>(GetInEdgeSize(node)); | |||
if (map_in_edge_num[node] == 0) { | |||
if ((node->GetOpDesc()->GetType() != kDataType) && (node->GetOpDesc()->GetType() != kAippDataType) && | |||
(node->GetOpDesc()->GetType() != kInputType) && (node->GetOpDesc()->GetType() != kAnnDataType)) { | |||
if ((node->GetOpDesc()->GetType() != DATA) && (node->GetOpDesc()->GetType() != AIPPDATA) && | |||
(node->GetOpDesc()->GetType() != INPUT_TYPE) && (node->GetOpDesc()->GetType() != ANN_DATA)) { | |||
// At present, can only judge the isolated point without input and output. | |||
// It is impossible to judge the situation with multiple output nodes. | |||
if (verify_isolated && GetOutEdgeSize(node) == 0) { | |||
@@ -832,6 +901,7 @@ graphStatus ComputeGraph::SortNodes(std::vector<NodePtr> &stack, std::map<NodePt | |||
return GRAPH_SUCCESS; | |||
} | |||
size_t ComputeGraph::GetInEdgeSize(const NodePtr &node) { | |||
size_t in_edge_size = 0; | |||
if (node == nullptr) { | |||
@@ -884,6 +954,7 @@ size_t ComputeGraph::GetOutEdgeSize(const NodePtr &node) { | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::IsValid() const { return is_valid_flag_; } | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::Dump() const { | |||
GELOGI("graph name = %s.", GetName().c_str()); | |||
for (const auto &node : GetAllNodes()) { | |||
@@ -915,6 +986,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::Dump() const { | |||
} | |||
} | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::IsolateNode(const NodePtr &node) { | |||
GE_CHECK_NOTNULL(node); | |||
auto next_nodes = node->GetOutAllNodes(); | |||
@@ -954,6 +1026,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::Isolate | |||
} | |||
} | |||
} | |||
// If there is an input control side | |||
auto in_ctrl_anchor = node->GetInControlAnchor(); | |||
GE_CHECK_NOTNULL(in_ctrl_anchor); | |||
@@ -991,6 +1064,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::Isolate | |||
return RemoveExtraOutEdge(node); | |||
} | |||
graphStatus ComputeGraph::RemoveExtraOutEdge(const NodePtr &node) { | |||
GE_CHECK_NOTNULL(node); | |||
// Remove redundant output edges | |||
@@ -1041,7 +1115,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InferSh | |||
node_ptr->GetName().c_str()); | |||
graphStatus status = node_ptr->InferShapeAndType(); | |||
GE_CHK_BOOL_EXEC_INFO(node_ptr->GetType() == kDataType || GRAPH_PARAM_INVALID != status, break, | |||
GE_CHK_BOOL_EXEC_INFO(node_ptr->GetType() == DATA || GRAPH_PARAM_INVALID != status, break, | |||
"Op %s does not have the IMPLEMT_INFERFUNC definition," | |||
" and subsequent operators no longer perform shape inference.", | |||
node_ptr->GetName().c_str()); | |||
@@ -16,237 +16,41 @@ | |||
#ifndef COMMON_GRAPH_DEBUG_GE_OP_TYPES_H_ | |||
#define COMMON_GRAPH_DEBUG_GE_OP_TYPES_H_ | |||
#include <limits.h> | |||
#include <stdint.h> | |||
#include <algorithm> | |||
#include <map> | |||
#include <memory> | |||
#include <string> | |||
#include <vector> | |||
namespace ge { | |||
#define GE_REGISTER_OPTYPE(var_name, str_name) static const char* var_name __attribute__((unused)) = str_name | |||
#define GE_REGISTER_OPTYPE(var_name, str_name) static const char *var_name __attribute__((unused)) = str_name | |||
GE_REGISTER_OPTYPE(DATA, "Data"); | |||
GE_REGISTER_OPTYPE(AIPPDATA, "AippData"); | |||
GE_REGISTER_OPTYPE(CONVOLUTION, "Convolution"); | |||
GE_REGISTER_OPTYPE(CORRELATION, "Correlation"); | |||
GE_REGISTER_OPTYPE(CORRELATIONV2, "Correlation_V2"); | |||
GE_REGISTER_OPTYPE(DECONVOLUTION, "Deconvolution"); | |||
GE_REGISTER_OPTYPE(POOLING, "Pooling"); | |||
GE_REGISTER_OPTYPE(ELTWISE, "Eltwise"); | |||
GE_REGISTER_OPTYPE(RELU, "ReLU"); | |||
GE_REGISTER_OPTYPE(RELU6, "ReLU6"); | |||
GE_REGISTER_OPTYPE(SIGMOID, "Sigmoid"); | |||
GE_REGISTER_OPTYPE(ABSVAL, "AbsVal"); | |||
GE_REGISTER_OPTYPE(TANH, "TanH"); | |||
GE_REGISTER_OPTYPE(PRELU, "PReLU"); | |||
GE_REGISTER_OPTYPE(BATCHNORM, "BatchNorm"); | |||
GE_REGISTER_OPTYPE(FUSIONBATCHNORM, "FusionBatchNorm"); | |||
GE_REGISTER_OPTYPE(SCALE, "Scale"); | |||
GE_REGISTER_OPTYPE(FULL_CONNECTION, "FullConnection"); | |||
GE_REGISTER_OPTYPE(SOFTMAX, "Softmax"); | |||
GE_REGISTER_OPTYPE(PLUS, "Plus"); | |||
GE_REGISTER_OPTYPE(ACTIVATION, "Activation"); | |||
GE_REGISTER_OPTYPE(FLATTEN, "Flatten"); | |||
GE_REGISTER_OPTYPE(ADD, "Add"); | |||
GE_REGISTER_OPTYPE(SUB, "Sub"); | |||
GE_REGISTER_OPTYPE(MUL, "Mul"); | |||
GE_REGISTER_OPTYPE(MATMUL, "MatMul"); | |||
GE_REGISTER_OPTYPE(RSQRT, "Rsqrt"); | |||
GE_REGISTER_OPTYPE(BIASADD, "BiasAdd"); | |||
GE_REGISTER_OPTYPE(RESHAPE, "Reshape"); | |||
GE_REGISTER_OPTYPE(DEPCONVOLUTION, "ConvolutionDepthwise"); | |||
GE_REGISTER_OPTYPE(DROPOUT, "Dropout"); | |||
GE_REGISTER_OPTYPE(CONCAT, "Concat"); | |||
GE_REGISTER_OPTYPE(ROIPOOLING, "ROIPooling"); | |||
GE_REGISTER_OPTYPE(PROPOSAL, "Proposal"); | |||
GE_REGISTER_OPTYPE(FSRDETECTIONOUTPUT, "FSRDetectionOutput"); | |||
GE_REGISTER_OPTYPE(DETECTIONPOSTPROCESS, "Detectpostprocess"); | |||
GE_REGISTER_OPTYPE(LRN, "LRN"); | |||
GE_REGISTER_OPTYPE(TRANSDATA, "TransData"); | |||
GE_REGISTER_OPTYPE(PERMUTE, "Permute"); | |||
GE_REGISTER_OPTYPE(SSDNORMALIZE, "SSDNormalize"); | |||
GE_REGISTER_OPTYPE(SSDPRIORBOX, "SSDPriorBox"); | |||
GE_REGISTER_OPTYPE(NETOUTPUT, "NetOutput"); | |||
GE_REGISTER_OPTYPE(SSDDETECTIONOUTPUT, "SSDDetectionOutput"); | |||
GE_REGISTER_OPTYPE(CHANNELAXPY, "ChannelAxpy"); | |||
GE_REGISTER_OPTYPE(PSROIPOOLING, "PSROIPooling"); | |||
GE_REGISTER_OPTYPE(POWER, "Power"); | |||
GE_REGISTER_OPTYPE(ROIALIGN, "ROIAlign"); | |||
GE_REGISTER_OPTYPE(PYTHON, "Python"); | |||
GE_REGISTER_OPTYPE(FREESPACEEXTRACT, "FreespaceExtract"); | |||
GE_REGISTER_OPTYPE(SPATIALTF, "SpatialTransform"); | |||
GE_REGISTER_OPTYPE(SHAPE, "Shape"); | |||
GE_REGISTER_OPTYPE(ARGMAX, "ArgMax"); | |||
GE_REGISTER_OPTYPE(GATHERND, "GatherNd"); | |||
GE_REGISTER_OPTYPE(GATHER, "Gather"); | |||
GE_REGISTER_OPTYPE(REALDIV, "RealDiv"); | |||
GE_REGISTER_OPTYPE(PACK, "Pack"); | |||
GE_REGISTER_OPTYPE(SLICE, "Slice"); | |||
GE_REGISTER_OPTYPE(FLOORDIV, "FloorDiv"); | |||
GE_REGISTER_OPTYPE(_WHILE, "_While"); | |||
GE_REGISTER_OPTYPE(WHILE, "While"); | |||
GE_REGISTER_OPTYPE(STATELESSWHILE, "StatelessWhile"); | |||
GE_REGISTER_OPTYPE(SQUEEZE, "Squeeze"); | |||
GE_REGISTER_OPTYPE(STRIDEDSLICE, "StridedSlice"); | |||
GE_REGISTER_OPTYPE(RANGE, "Range"); | |||
GE_REGISTER_OPTYPE(RPNPROPOSALS, "GenerateRpnProposals"); | |||
GE_REGISTER_OPTYPE(DECODEBBOX, "DecodeBBox"); | |||
GE_REGISTER_OPTYPE(PAD, "Pad"); | |||
GE_REGISTER_OPTYPE(TILE, "Tile"); | |||
GE_REGISTER_OPTYPE(SIZE, "Size"); | |||
GE_REGISTER_OPTYPE(CLIPBOXES, "Clipboxes"); | |||
GE_REGISTER_OPTYPE(FASTRCNNPREDICTIONS, "FastrcnnPredictions"); | |||
GE_REGISTER_OPTYPE(SPLIT, "Split"); | |||
GE_REGISTER_OPTYPE(EXPANDDIMS, "ExpandDims"); | |||
GE_REGISTER_OPTYPE(MEAN, "Mean"); | |||
GE_REGISTER_OPTYPE(GREATER, "Greater"); | |||
GE_REGISTER_OPTYPE(SWITCH, "Switch"); | |||
GE_REGISTER_OPTYPE(REFSWITCH, "RefSwitch"); | |||
GE_REGISTER_OPTYPE(MERGE, "Merge"); | |||
GE_REGISTER_OPTYPE(REFMERGE, "RefMerge"); | |||
GE_REGISTER_OPTYPE(ENTER, "Enter"); | |||
GE_REGISTER_OPTYPE(REFENTER, "RefEnter"); | |||
GE_REGISTER_OPTYPE(LOOPCOND, "LoopCond"); | |||
GE_REGISTER_OPTYPE(STREAMMERGE, "StreamMerge"); | |||
GE_REGISTER_OPTYPE(NEXTITERATION, "NextIteration"); | |||
GE_REGISTER_OPTYPE(REFNEXTITERATION, "RefNextIteration"); | |||
GE_REGISTER_OPTYPE(EXIT, "Exit"); | |||
GE_REGISTER_OPTYPE(REFEXIT, "RefExit"); | |||
GE_REGISTER_OPTYPE(CONTROLTRIGGER, "ControlTrigger"); | |||
GE_REGISTER_OPTYPE(TRANSPOSE, "Transpose"); | |||
GE_REGISTER_OPTYPE(CAST, "Cast"); | |||
GE_REGISTER_OPTYPE(REGION, "Region"); | |||
GE_REGISTER_OPTYPE(YOLO, "Yolo"); | |||
GE_REGISTER_OPTYPE(YOLODETECTIONOUTPUT, "YoloDetectionOutput"); | |||
GE_REGISTER_OPTYPE(FILL, "Fill"); | |||
GE_REGISTER_OPTYPE(REVERSE, "Reverse"); | |||
GE_REGISTER_OPTYPE(UNPACK, "Unpack"); | |||
GE_REGISTER_OPTYPE(YOLO2REORG, "Yolo2Reorg"); | |||
GE_REGISTER_OPTYPE(REDUCESUM, "ReduceSum"); | |||
GE_REGISTER_OPTYPE(CONSTANT, "Const"); | |||
GE_REGISTER_OPTYPE(RESIZEBILINEAR, "ResizeBilinear"); | |||
GE_REGISTER_OPTYPE(MAXIMUM, "Maximum"); | |||
GE_REGISTER_OPTYPE(FRAMEWORKOP, "FrameworkOp"); | |||
GE_REGISTER_OPTYPE(ARG, "_Arg"); | |||
GE_REGISTER_OPTYPE(FUSEDBATCHNORMGRAD, "FusedBatchNormGrad"); | |||
GE_REGISTER_OPTYPE(LSTM, "LSTM"); | |||
GE_REGISTER_OPTYPE(HIGHWAY, "HighWay"); | |||
GE_REGISTER_OPTYPE(RNN, "RNN"); | |||
GE_REGISTER_OPTYPE(ATTENTIONDECODER, "AttentionDecoder"); | |||
GE_REGISTER_OPTYPE(LOGICAL_NOT, "LogicalNot"); | |||
GE_REGISTER_OPTYPE(LOGICAL_AND, "LogicalAnd"); | |||
GE_REGISTER_OPTYPE(EQUAL, "Equal"); | |||
GE_REGISTER_OPTYPE(INTERP, "Interp"); | |||
GE_REGISTER_OPTYPE(SHUFFLECHANNEL, "ShuffleChannel"); | |||
GE_REGISTER_OPTYPE(AIPP, "Aipp"); | |||
GE_REGISTER_OPTYPE(CROPANDRESIZE, "CropAndResize"); | |||
GE_REGISTER_OPTYPE(UNUSEDCONST, "UnusedConst"); | |||
GE_REGISTER_OPTYPE(BROADCASTGRADIENTARGS, "BroadcastGradientArgs"); | |||
GE_REGISTER_OPTYPE(BROADCASTARGS, "BroadcastArgs"); | |||
GE_REGISTER_OPTYPE(STOPGRADIENT, "StopGradient"); | |||
GE_REGISTER_OPTYPE(PPREVENTGRADIENT, "PreventGradient"); | |||
GE_REGISTER_OPTYPE(GUARANTEECONST, "GuaranteeConst"); | |||
GE_REGISTER_OPTYPE(SPARSETODENSE, "SparseToDense"); | |||
GE_REGISTER_OPTYPE(NONMAXSUPPRESSION, "NonMaxSuppression"); | |||
GE_REGISTER_OPTYPE(TOPKV2, "TopKV2"); | |||
GE_REGISTER_OPTYPE(INVERTPERMUTATION, "InvertPermutation"); | |||
GE_REGISTER_OPTYPE(MULTINOMIAL, "Multinomial"); | |||
GE_REGISTER_OPTYPE(REVERSESEQUENCE, "ReverseSequence"); | |||
GE_REGISTER_OPTYPE(GETNEXT, "GetNext"); | |||
GE_REGISTER_OPTYPE(INITDATA, "InitData"); | |||
// ANN specific operator | |||
GE_REGISTER_OPTYPE(ANN_MEAN, "AnnMean"); | |||
GE_REGISTER_OPTYPE(ANN_CONVOLUTION, "AnnConvolution"); | |||
GE_REGISTER_OPTYPE(ANN_DEPCONVOLUTION, "AnnDepthConv"); | |||
GE_REGISTER_OPTYPE(DIV, "Div"); | |||
GE_REGISTER_OPTYPE(ANN_FULLCONNECTION, "AnnFullConnection"); | |||
GE_REGISTER_OPTYPE(ANN_NETOUTPUT, "AnnNetOutput"); | |||
GE_REGISTER_OPTYPE(ANN_DATA, "AnnData"); | |||
// Training operator | |||
GE_REGISTER_OPTYPE(CONVGRADFILTER, "Conv2DBackpropFilter"); | |||
GE_REGISTER_OPTYPE(CONV2D, "Conv2D"); | |||
GE_REGISTER_OPTYPE(CONV2DBACKPROPINPUT, "Conv2DBackpropInput"); | |||
GE_REGISTER_OPTYPE(ACTIVATIONGRAD, "ReluGrad"); | |||
GE_REGISTER_OPTYPE(CONSTANTOP, "Constant"); | |||
GE_REGISTER_OPTYPE(AVGPOOLGRAD, "AvgPoolGrad"); | |||
GE_REGISTER_OPTYPE(SQUARE, "Square"); | |||
GE_REGISTER_OPTYPE(PLACEHOLDER, "PlaceHolder"); | |||
GE_REGISTER_OPTYPE(END, "End"); | |||
GE_REGISTER_OPTYPE(VARIABLE, "Variable"); | |||
GE_REGISTER_OPTYPE(VARIABLEV2, "VariableV2"); | |||
/// @ingroup domi_omg | |||
/// @brief INPUT node type | |||
static const char* const kInputType = "Input"; | |||
/// | |||
/// @ingroup domi_omg | |||
/// @brief AIPP tag, tag for aipp conv operator | |||
/// | |||
static const char* const kAippConvFlag = "Aipp_Conv_Flag"; | |||
/// | |||
/// @ingroup domi_omg | |||
/// @brief AIPP tag, tag for aipp data operator | |||
/// | |||
static const char* const kAippDataFlag = "Aipp_Data_Flag"; | |||
/// | |||
/// @ingroup domi_omg | |||
/// @brief AIPP tag, tag for aipp data operator | |||
/// | |||
static const char* const kAippDataType = "AippData"; | |||
/// | |||
/// @ingroup domi_omg | |||
/// @brief DATA node type | |||
/// | |||
static const char* const kDataType = "Data"; | |||
/// | |||
/// @ingroup domi_omg | |||
/// @brief Frame operator type | |||
/// | |||
static const char* const kFrameworkOpType = "FrameworkOp"; | |||
/// | |||
/// @ingroup domi_omg | |||
/// @brief Data node type | |||
/// | |||
static const char* const kAnnDataType = "AnnData"; | |||
static const char* const kAnnNetoutputType = "AnnNetOutput"; | |||
/// | |||
/// @ingroup domi_omg | |||
/// @brief Convolution node type | |||
/// | |||
static const char* const kNodeNameNetOutput = "Node_Output"; | |||
/// | |||
/// @ingroup domi_omg | |||
/// @brief RECV node type | |||
/// | |||
static const char* const kRecvType = "Recv"; | |||
GE_REGISTER_OPTYPE(INPUT_TYPE, "Input"); | |||
/// | |||
/// @ingroup domi_omg | |||
/// @brief SEND node type | |||
/// | |||
static const char* const kSendType = "Send"; | |||
GE_REGISTER_OPTYPE(NODE_NAME_NET_OUTPUT, "Node_Output"); | |||
/// | |||
/// @ingroup domi_omg | |||
/// @brief Convolution node type | |||
/// | |||
static const char* const kOpTypeConvolution = "Convolution"; | |||
/// | |||
/// @ingroup domi_omg | |||
/// @brief Add convolution node name to hard AIPP | |||
/// | |||
static const char* const kAippConvOpNmae = "aipp_conv_op"; | |||
/// | |||
/// @ingroup domi_omg | |||
/// @brief Operator configuration item separator | |||
/// | |||
static const char* const kOpConfDelimiter = ":"; | |||
GE_REGISTER_OPTYPE(RECV, "Recv"); | |||
GE_REGISTER_OPTYPE(SEND, "Send"); | |||
}; // namespace ge | |||
#endif // COMMON_GRAPH_DEBUG_GE_OP_TYPES_H_ |
@@ -15,11 +15,14 @@ | |||
*/ | |||
#include "format_refiner.h" | |||
#include <deque> | |||
#include <iostream> | |||
#include <set> | |||
#include <unordered_map> | |||
#include <unordered_set> | |||
#include "graph/ref_relation.h" | |||
#include "./compute_graph.h" | |||
#include "./ge_error_codes.h" | |||
#include "./graph/ge_tensor.h" | |||
@@ -34,14 +37,41 @@ | |||
#include "utils/tensor_utils.h" | |||
#include "utils/type_utils.h" | |||
using namespace ge; | |||
using namespace std; | |||
namespace ge { | |||
namespace { | |||
static const std::unordered_set<string> kChangeDimNodes = {RESHAPE, PERMUTE, EXPANDDIMS, SQUEEZE}; | |||
static bool net_format_is_nd = true; | |||
static Format g_user_set_format = FORMAT_ND; | |||
static bool is_first_infer = true; | |||
static RefRelations reflection_builder; | |||
} // namespace | |||
graphStatus ReflectionProcess(const std::unordered_set<RefCell, RefCellHash> &reflection, | |||
std::deque<ge::NodePtr> &nodes, ge::Format to_be_set_format) { | |||
for (const auto &cell : reflection) { | |||
auto node = cell.node; | |||
auto in_out_idx = cell.in_out_idx; | |||
GE_CHECK_NOTNULL(node); | |||
GE_CHECK_NOTNULL(node->GetOpDesc()); | |||
if (cell.in_out == ge::NODE_IN) { | |||
auto desc = node->GetOpDesc()->GetInputDesc(static_cast<uint32_t>(in_out_idx)); | |||
desc.SetOriginFormat(to_be_set_format); | |||
desc.SetFormat(to_be_set_format); | |||
(void)node->GetOpDesc()->UpdateInputDesc(static_cast<uint32_t>(in_out_idx), desc); | |||
} else { | |||
auto desc = node->GetOpDesc()->GetOutputDesc(static_cast<uint32_t>(in_out_idx)); | |||
desc.SetOriginFormat(to_be_set_format); | |||
desc.SetFormat(to_be_set_format); | |||
(void)node->GetOpDesc()->UpdateOutputDesc(static_cast<uint32_t>(in_out_idx), desc); | |||
} | |||
nodes.push_back(cell.node); | |||
} | |||
return GRAPH_SUCCESS; | |||
} | |||
graphStatus FormatRefiner::RefreshConstantOutProcess(const OpDescPtr &op_desc) { | |||
GE_CHECK_NOTNULL(op_desc); | |||
if (op_desc->GetType() == CONSTANTOP && is_first_infer == true) { | |||
@@ -66,7 +96,6 @@ graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std | |||
anchor_points.clear(); | |||
// Get all anchor point nodes and switch nodes | |||
for (const auto &node_ptr : graph->GetAllNodes()) { | |||
std::vector<bool> is_node_set_format; | |||
if (node_ptr == nullptr) { | |||
return GRAPH_FAILED; | |||
} | |||
@@ -86,7 +115,7 @@ graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std | |||
for (uint32_t i = 0; i < input_size; i++) { | |||
// Operator pre-set format but not origin format | |||
auto input_format = op_desc->MutableInputDesc(i)->GetFormat(); | |||
// Pre-save data node and default infer fail | |||
// Pre-save data node (only main graph data) and default infer fail | |||
if (node_ptr->GetType() == DATA) { | |||
data_nodes.push_back(node_ptr); | |||
} | |||
@@ -163,6 +192,16 @@ graphStatus FormatRefiner::BackInferProcess(std::deque<ge::NodePtr> &nodes, ge:: | |||
} | |||
// Check format whether have been set | |||
int idx = peer_out_data_anchor->GetIdx(); | |||
// do peer_out_node name and index as key to lookup reflections | |||
ge::RefCell key(peer_out_data_node->GetName(), peer_out_data_node, ge::NODE_OUT, idx); | |||
std::unordered_set<RefCell, RefCellHash> reflection; | |||
auto status = reflection_builder.LookUpRefRelations(key, reflection); | |||
if (status != GRAPH_SUCCESS) { | |||
GELOGE(GRAPH_FAILED, "LookUpRefRelations failed!Node is [%s],the %d out edge", | |||
(peer_out_data_node->GetName()).c_str(), idx); | |||
return GRAPH_FAILED; | |||
} | |||
auto ge_tensor_desc = peer_out_data_node->GetOpDesc()->GetOutputDesc(static_cast<uint32_t>(idx)); | |||
if (ge_tensor_desc.GetOriginFormat() == FORMAT_ND) { | |||
auto dim_num = ge_tensor_desc.GetShape().GetDimNum(); | |||
@@ -181,18 +220,26 @@ graphStatus FormatRefiner::BackInferProcess(std::deque<ge::NodePtr> &nodes, ge:: | |||
continue; | |||
} | |||
ge_tensor_desc.SetOriginFormat(to_be_set_format); | |||
ge_tensor_desc.SetFormat(to_be_set_format); | |||
(void)peer_out_data_node->GetOpDesc()->UpdateOutputDesc(static_cast<uint32_t>(idx), ge_tensor_desc); | |||
if (reflection.empty()) { | |||
ge_tensor_desc.SetOriginFormat(to_be_set_format); | |||
ge_tensor_desc.SetFormat(to_be_set_format); | |||
(void)peer_out_data_node->GetOpDesc()->UpdateOutputDesc(static_cast<uint32_t>(idx), ge_tensor_desc); | |||
// Call operator infer format api (forward) to get out format | |||
GELOGD("call infer format func[Back]!Node is [%s] ", (peer_out_data_node->GetName()).c_str()); | |||
graphStatus status = peer_out_data_node->InferOriginFormat(); | |||
if (status != GRAPH_SUCCESS) { | |||
GELOGE(GRAPH_FAILED, "Node[%s] infer format failed", (peer_out_data_node->GetName()).c_str()); | |||
return GRAPH_FAILED; | |||
// Call operator infer format api (forward) to get out format | |||
GELOGD("call infer format func[Back]!Node is [%s] ", (peer_out_data_node->GetName()).c_str()); | |||
status = peer_out_data_node->InferOriginFormat(); | |||
if (status != GRAPH_SUCCESS) { | |||
GELOGE(GRAPH_FAILED, "Node[%s] infer format failed", (peer_out_data_node->GetName()).c_str()); | |||
return GRAPH_FAILED; | |||
} | |||
nodes.push_back(peer_out_data_node); | |||
} else { | |||
auto status = ReflectionProcess(reflection, nodes, to_be_set_format); | |||
if (status != GRAPH_SUCCESS) { | |||
GELOGE(GRAPH_FAILED, "reflection process failed!"); | |||
return GRAPH_FAILED; | |||
} | |||
} | |||
nodes.push_back(peer_out_data_node); | |||
} | |||
} | |||
return GRAPH_SUCCESS; | |||
@@ -213,17 +260,23 @@ graphStatus FormatRefiner::ForwardInferProcess(std::deque<ge::NodePtr> &nodes, g | |||
continue; | |||
} | |||
for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { | |||
if (peer_in_data_anchor == nullptr) { | |||
GELOGW("Node[%s] some peer_in_anchor is null", (node->GetName()).c_str()); | |||
continue; | |||
} | |||
GE_IF_BOOL_EXEC(peer_in_data_anchor == nullptr, continue); | |||
auto peer_in_data_node = peer_in_data_anchor->GetOwnerNode(); | |||
if (peer_in_data_node == nullptr || peer_in_data_node->GetOpDesc() == nullptr) { | |||
GELOGW("Node[%s] peer_in_data_node or peer_in_data_node desc is null", node->GetName().c_str()); | |||
continue; | |||
} | |||
GE_IF_BOOL_EXEC(peer_in_data_node == nullptr, continue); | |||
GE_IF_BOOL_EXEC(peer_in_data_node->GetOpDesc() == nullptr, continue); | |||
// Check format whether have been set | |||
int idx = peer_in_data_anchor->GetIdx(); | |||
// do peer_out_node name and index as key to lookup reflections | |||
ge::RefCell key(peer_in_data_node->GetName(), peer_in_data_node, ge::NODE_IN, idx); | |||
std::unordered_set<RefCell, RefCellHash> reflection; | |||
auto status = reflection_builder.LookUpRefRelations(key, reflection); | |||
if (status != GRAPH_SUCCESS) { | |||
GELOGE(GRAPH_FAILED, "LookUpRefRelations failed!Node is [%s],the %d input edge", | |||
(peer_in_data_node->GetName()).c_str(), idx); | |||
return GRAPH_FAILED; | |||
} | |||
auto ge_tensor_desc = peer_in_data_node->GetOpDesc()->GetInputDesc(static_cast<uint32_t>(idx)); | |||
if (ge_tensor_desc.GetOriginFormat() == FORMAT_ND) { | |||
auto dim_num = ge_tensor_desc.GetShape().GetDimNum(); | |||
@@ -240,24 +293,33 @@ graphStatus FormatRefiner::ForwardInferProcess(std::deque<ge::NodePtr> &nodes, g | |||
GELOGD("Node[%s] is change dim node. do not infer origin format", (peer_in_data_node->GetName()).c_str()); | |||
continue; | |||
} | |||
ge_tensor_desc.SetOriginFormat(to_be_set_format); | |||
ge_tensor_desc.SetFormat(to_be_set_format); | |||
(void)peer_in_data_node->GetOpDesc()->UpdateInputDesc(idx, ge_tensor_desc); | |||
/// Because netoutput node added before infer format ,so netoutput is end condition | |||
/// must set netoutput format , because saved result depend on format | |||
if (peer_in_data_node_type == NETOUTPUT) { | |||
continue; | |||
} | |||
if (reflection.empty()) { | |||
ge_tensor_desc.SetOriginFormat(to_be_set_format); | |||
ge_tensor_desc.SetFormat(to_be_set_format); | |||
(void)peer_in_data_node->GetOpDesc()->UpdateInputDesc(static_cast<uint32_t>(idx), ge_tensor_desc); | |||
// Call operator infer format api (forward) to get out format | |||
GELOGD("call infer format func[Forward]!Node is [%s] ", (peer_in_data_node->GetName()).c_str()); | |||
graphStatus status = peer_in_data_node->InferOriginFormat(); | |||
if (status != GRAPH_SUCCESS) { | |||
GELOGE(GRAPH_FAILED, "Node[%s] infer format failed", (peer_in_data_node->GetName()).c_str()); | |||
return GRAPH_FAILED; | |||
/// Because netoutput node added before infer format ,so netoutput is end condition | |||
/// must set netoutput format , because saved result depend on format | |||
if (peer_in_data_node_type == NETOUTPUT) { | |||
continue; | |||
} | |||
// Call operator infer format api (forward) to get out format | |||
GELOGD("call infer format func[Back]!Node is [%s] ", (peer_in_data_node->GetName()).c_str()); | |||
status = peer_in_data_node->InferOriginFormat(); | |||
if (status != GRAPH_SUCCESS) { | |||
GELOGE(GRAPH_FAILED, "Node[%s] infer format failed", (peer_in_data_node->GetName()).c_str()); | |||
return GRAPH_FAILED; | |||
} | |||
nodes.push_back(peer_in_data_node); | |||
} else { | |||
auto status = ReflectionProcess(reflection, nodes, to_be_set_format); | |||
if (status != GRAPH_SUCCESS) { | |||
GELOGE(GRAPH_FAILED, "reflection process failed!"); | |||
return GRAPH_FAILED; | |||
} | |||
} | |||
nodes.push_back(peer_in_data_node); | |||
} | |||
} | |||
} | |||
@@ -355,8 +417,15 @@ graphStatus FormatRefiner::InferOrigineFormat(const ge::ComputeGraphPtr &graph) | |||
GELOGE(GRAPH_FAILED, "input graph is null"); | |||
return GRAPH_FAILED; | |||
} | |||
// build reflection relations of boundary | |||
(void)reflection_builder.Clear(); | |||
auto status = reflection_builder.BuildRefRelations(*graph); | |||
if (status != GRAPH_SUCCESS) { | |||
GELOGE(GRAPH_FAILED, "build reflection relations failed for main and subgraph!"); | |||
return GRAPH_FAILED; | |||
} | |||
// User set global net format | |||
graphStatus status = GetAnchorPoints(graph, anchor_points, data_nodes, node_status); | |||
status = GetAnchorPoints(graph, anchor_points, data_nodes, node_status); | |||
if (status != GRAPH_SUCCESS) { | |||
GELOGE(GRAPH_FAILED, "GetAnchorPoints Process Faild!"); | |||
return GRAPH_FAILED; | |||
@@ -18,6 +18,12 @@ | |||
namespace ge { | |||
// Public attribute | |||
const std::string ATTR_NAME_IS_UNKNOWN_SHAPE = "_is_unknown_shape"; | |||
const std::string ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED = "_dynamic_shape_partitioned"; | |||
const std::string ATTR_NAME_UNKNOWN_SHAPE_TYPE = "_unknown_shape_type"; | |||
const std::string ATTR_NAME_NAME = "name"; | |||
const std::string ATTR_NAME_TYPE = "type"; | |||
@@ -42,6 +48,8 @@ const std::string ATTR_NAME_BIAS = "bias"; | |||
const std::string ATTR_NAME_BIAS_TERM = "bias_term"; | |||
const std::string ATTR_NAME_HAS_BIAS_VALUE = "has_bias_value"; | |||
const std::string ATTR_NAME_PAD = "pad"; | |||
const std::string ATTR_NAME_PADS = "pad"; | |||
@@ -83,6 +91,7 @@ const std::string ATTR_NAME_LRN_BETA = "lrn_beta"; | |||
const std::string ATTR_NAME_AXIS = "axis"; | |||
const std::string ATTR_NAME_BROADCAST = "broadcast"; | |||
const std::string ATTR_NAME_OUTPUT = "output"; | |||
const std::string ATTR_NAME_OUTPUT_NUM = "output_num"; | |||
const std::string ATTR_NAME_TIDX = "t_idx"; | |||
@@ -103,6 +112,13 @@ const std::string ATTR_NAME_TSHAPE = "Tshape"; | |||
const std::string ATTR_NAME_NAN_OPT = "nan_opt"; | |||
const std::string ATTR_NAME_AIPP = "aipp"; | |||
const std::string NEW_AIPP_CONV_OP = "new_conv_op_for_aipp"; | |||
const std::string ATTR_NAME_SESSION_GRAPH_ID = "_session_graph_id"; | |||
const std::string ATTR_NAME_MULTISHAPE_BATCHLIST = "multi_shape_batchlist"; | |||
const std::string ATTR_NAME_MULTISHAPE_BATCHLIST_SIZE = "multi_shape_batchlist_size"; | |||
const std::string ATTR_MODEL_BATCH_NUM = "batch_num"; | |||
const std::string ATTR_NAME_INPUT_FORMAT = "input_format"; | |||
const std::string ATTR_NAME_OUTPUT_FORMAT = "output_format"; | |||
@@ -111,6 +127,7 @@ const std::string ATTR_NAME_FRAMEWORK_NODE_DEF = "node_def"; | |||
const std::string ATTR_NAME_FRAMEWORK_OP_DEF = "op_def"; | |||
const std::string ATTR_NAME_FRAMEWORK_FWK_TYPE = "framework_type"; | |||
const std::string ATTR_NAME_FRAMEWORK_FUNC_DEF = "func_def"; | |||
const std::string ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE = "original_type"; | |||
const std::string ATTR_NAME_INPUT_TENSOR_DESC = "input_tensor_desc"; | |||
const std::string ATTR_NAME_OUTPUT_TENSOR_DESC = "output_tensor_desc"; | |||
@@ -122,9 +139,12 @@ const std::string ATTR_NAME_WEIGHTS = "value"; | |||
const std::string ATTR_NAME_WEIGHTS_DATA = "weights_data"; | |||
const std::string ATTR_NAME_BROACAST_REAL_DIM_CNT = "broacast_real_dim_cnt"; | |||
const std::string ATTR_NAME_DIM_ALIGN = "dim_align"; | |||
const std::string ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE = "original_type"; | |||
const std::string ATTR_NAME_SESSION_GRAPH_ID = "_session_graph_id"; | |||
const std::string ATTR_NAME_STREAM_LABEL = "_stream_label"; | |||
const std::string ATTR_NAME_CONTINUOUS_STREAM_LABEL = "_continuous_stream_label"; | |||
const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG = "need_stream_cycle_event"; | |||
const std::string ATTR_NAME_RTSWITCH_RECV_EVENT_ID = "rtswitch_event_id"; | |||
const std::string ATTR_NAME_AUTOMIC_ADD_START = "automic_add_addr_start"; | |||
const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE = "automic_add_mem_size"; | |||
// To be deleted | |||
const std::string ATTR_TO_BE_DELETED = "to_be_deleted"; | |||
@@ -138,15 +158,13 @@ const std::string SSD_MBOX_OCR_FUSION = "permute_flatten_ocr_fusion"; | |||
const std::string SSD_MBOX_FUSION_BOX_TYPE_NUM = "ssd_mbox_fusion_box_type_num"; | |||
const std::string SSD_RESHAPE_SLICE_CONCAT_FUSION = "reshape_slice_concat_fusion"; | |||
const std::string SSD_PRIORBOX_CONCAT = "ssd_mbox_conf_priorbox_concat_flag"; | |||
// Refinedet | |||
const std::string REFINEDET_MBOX_LOC_FUSION = "permute_flatten_fusion"; | |||
const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION = "reshape_slice_concat_fusion"; | |||
const std::string REFINEDET_MBOX_CONF_FUSION = "permute_flatten_reshape_flatten_fusion"; | |||
const std::string REFINEDET_MBOX_FUSION_BOX_TYPE_NUM = "ssd_mbox_fusion_box_type_num"; | |||
const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE = "variance"; | |||
const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM = "variance_num"; | |||
const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION = "reshape_slice_concat_fusion"; | |||
const std::string SSD_PRIORBOX_CONCAT = "ssd_mbox_conf_priorbox_concat_flag"; | |||
// _Arg | |||
const std::string ATTR_NAME_INDEX = "index"; | |||
@@ -236,6 +254,30 @@ const std::string BATCHNORM_ATTR_ESTIMATED_MEAN = "estimated_mean"; | |||
const std::string BATCHNORM_ATTR_ESTIMATED_VARIANCE = "estimated_variance"; | |||
const std::string BATCHNORM_ATTR_SCALE = "scale"; | |||
const std::string BATCHNORM_ATTR_BIAS = "bias"; | |||
const std::string BATCHNORM_ATTR_DATA_FORMAT = "data_format"; | |||
const std::string BATCHNORM_ATTR_IS_TRAINING = "is_training"; | |||
const std::string BATCHNORM_ATTR_IS_TRAINING_FUSION = "is_training_fusion"; | |||
// huberloss | |||
const std::string HUBER_LOSS_ATTR_DELTA = "delta"; | |||
// SSDRealDivTileMul | |||
const std::string SSD_REAL_DIV_TILE_MUL_ATTR_TILE_PARA = "tilepara"; | |||
// SSDSumMulRealDivMean | |||
const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_REDUCTION_INDICES = "reduction_indices"; | |||
const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_AXIS = "axis"; | |||
const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_MEAN_PARA = "mean_para"; | |||
const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_HAS_SUM = "has_sum"; | |||
// ConcatFive2Four | |||
// ConcatFour2Five | |||
const std::string SSD_BOX_TYPE_NUM = "box_type_num"; | |||
const std::string SSD_CLASS_NUM = "class_num"; | |||
const std::string TRANS_FOR_LOSS_MODE = "trans_for_loss_mode"; | |||
const std::string SSD_FEATURE_MAP_SIZE = "feature_map_size"; | |||
const std::string SSD_FEATURE_MAP_HIGH = "feature_map_high"; | |||
const std::string SSD_FEATURE_MAP_WIDTH = "feature_map_width"; | |||
// Scale | |||
const std::string SCALE_ATTR_SCALE = "scale"; | |||
@@ -340,6 +382,7 @@ const std::string SOFTMAX_ATTR_AXIS = "axis"; | |||
// Permute | |||
const std::string PERMUTE_ATTR_ORDER = "order"; | |||
const std::string PERMUTE_ATTR_PERM = "perm"; | |||
// SSD Normalize | |||
const std::string SSDNORMALIZE_ATTR_ACCROSS_SPATIAL = "across_spatial"; | |||
@@ -367,6 +410,10 @@ const std::string SSD_PRIOR_BOX_ATTR_ASPECT_RATIO_NUM = "aspect_ratio_num"; | |||
const std::string SSD_PRIOR_BOX_ATTR_VARIANCE = "variance"; | |||
const std::string SSD_PRIOR_BOX_ATTR_VARIANCE_NUM = "variance_num"; | |||
// RefinedetDetectionOutput | |||
const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM = "variance_num"; | |||
const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE = "variance"; | |||
// PRelu | |||
const std::string PRELU_ATTR_CHANNEL_SHARED = "channel_shared"; | |||
@@ -380,11 +427,16 @@ const std::string POWER_ATTR_NAME_POWER = "power"; | |||
const std::string POWER_ATTR_NAME_SCALE = "scale"; | |||
const std::string POWER_ATTR_NAME_SHIFT = "shift"; | |||
// log | |||
const std::string LOG_ATTR_NAME_SCALE = "scale"; | |||
const std::string LOG_ATTR_NAME_SHIFT = "shift"; | |||
const std::string LOG_ATTR_NAME_BASE = "base"; | |||
// Pack | |||
const std::string PACK_ATTR_NAME_NUM = "N"; | |||
// Unpack | |||
const std::string UNPACK_ATTR_NAME_NUM = "num"; | |||
const std::string DYNAMIC_STITCH_ATTR_NAME_NUM = "DynamicStitchN_"; | |||
// Gathernd | |||
const std::string GATHERND_ATTR_NAME_TINDICES = "Tindices"; | |||
const std::string GATHERND_ATTR_NAME_TPARAMS = "Tparams"; | |||
@@ -394,6 +446,13 @@ const std::string ARGMAX_ATTR_NAME_TOPK = "topk"; | |||
const std::string ARGMAX_ATTR_NAME_REDUCESIZE = "reduce_size"; | |||
const std::string ARGMAX_ATTR_NAME_REDUCESTRIDE = "reduce_stride"; | |||
const std::string ARGMAX_ATTR_NAME_OUTMAX = "outmaxval"; | |||
const std::string ARGMAX_ATTR_NAME_AXIS = "axis"; | |||
const std::string ARGMAX_ATTR_NAME_AXISTYPE = "axis_type"; | |||
const std::string ARGMAX_ATTR_NAME_KEEPDIMS = "keep_dims"; | |||
// upsample | |||
const std::string UPSAMPLE_ATTR_NAME_SCALE_H = "scale_h"; | |||
const std::string UPSAMPLE_ATTR_NAME_SCALE_W = "scale_w"; | |||
// Relu | |||
const std::string ATTR_NAME_NEGATIVE_SLOPE = "negative_slope"; | |||
@@ -531,19 +590,41 @@ const std::string CONV_GRAD_FILTER_OUTPUT_SHAPE = "conv_grad_filter_output_shape | |||
const std::string CONV_GRAD_INPUT_OUTPUT_SHAPE = "conv_grad_input_output_shape"; | |||
// Rnn | |||
const std::string RNN_MODE_ = "rnn_"; | |||
const std::string CNN_RNN = "cnn_rnn"; | |||
const std::string RNN_MODE_STATIC = "rnn_static"; | |||
const std::string MUTI_RNN = "multi_rnn"; | |||
const std::string CNN_RNN = "cnn_rnn"; | |||
const std::string RNN_MODE_ = "rnn_"; | |||
const std::string CELL_MODE = "mode"; | |||
const std::string LSTM_CELL = "lstm_cell"; | |||
const std::string GRU_CELL = "gru_cell"; | |||
const std::string RNN_HT = "ht"; | |||
const std::string RNN_XT_HT = "xt_ht"; | |||
const std::string RNN_BATCH_SIZE = "batch_size"; | |||
const std::string LSTM_CELL_CLIP = "lstm_cell_clip"; | |||
const std::string LSTM_PROJ_CLIP = "lstm_proj_clip"; | |||
const std::string LSTM_ACTIVATE = "lstm_activate"; | |||
const std::string LSTM_OUT_MAP = "lstm_out_map"; | |||
const std::string LSTM_OUT_MODE = "lstm_out_mode"; | |||
const std::string LSTM_STATE_OUT_MODE = "lstm_state_out_mode"; | |||
const std::string LSTM_TIME_MAJOR = "lstm_time_major"; | |||
const std::string LSTM_IS_INPUT_PRE_PROCESS = "lstm_is_input_pre_process"; | |||
// Upsample | |||
const std::string UPSAMPLE_ATTR_NAME_SCALE = "scale"; | |||
// PadV2 | |||
const std::string PADV2_ATTR_NAME_MODE = "mode"; | |||
const std::string PADV2_ATTR_NAME_PADS = "paddings"; | |||
const std::string PADV2_ATTR_NAME_T = "T"; | |||
const std::string PADV2_ATTR_NAME_PAD_FORMAT = "pad_format"; | |||
const std::string PADV2_ATTR_NAME_CONST_VALUE = "const_value"; | |||
// MirrorPad | |||
const std::string MIRRORPAD_ATTR_NAME_MODE = "mode"; | |||
const std::string MIRRORPAD_ATTR_NAME_PADS = "paddings"; | |||
const std::string MIRRORPAD_ATTR_NAME_PAD_FORMAT = "pad_format"; | |||
const std::string MIRRORPAD_ATTR_NAME_CONST_VALUE = "const_value"; | |||
// Filler | |||
const std::string FILLER_TYPE = "filler_type"; | |||
const std::string FILLER_VALUE = "filler_value"; | |||
@@ -554,9 +635,6 @@ const std::string SHUFFLE_CHANNEL_GROUP = "group"; | |||
// TopKV2 | |||
const std::string TOPKV2_ATTR_K = "k"; | |||
const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE = "block_size"; | |||
const std::string L2_NORMALIZE_ATTR_EPS = "eps"; | |||
// Calibaration | |||
const std::string STRIDE_H_INDEX = "STRIDE_H_INDEX"; | |||
const std::string STRIDE_W_INDEX = "STRIDE_W_INDEX"; | |||
@@ -611,10 +689,14 @@ const std::string ATTR_MODEL_STREAM_NUM = "stream_num"; | |||
const std::string ATTR_MODEL_EVENT_NUM = "event_num"; | |||
const std::string ATTR_MODEL_HUGE_STREAM_LIST = "huge_stream_list"; | |||
const std::string ATTR_MODEL_LABEL_NUM = "label_num"; | |||
const std::string ATTR_MODEL_MEMORY_SIZE = "memory_size"; | |||
const std::string ATTR_MODEL_ZERO_COPY_MEMORY_SIZE = "zero_copy_memory_size"; | |||
const std::string ATTR_MODEL_WEIGHT_SIZE = "weight_size"; | |||
const std::string ATTR_MODEL_TASK_GEN_BASE_ADDR = "task_gen_base_addr"; | |||
@@ -660,8 +742,125 @@ const std::string TARGET_TYPE_TINY = "TINY"; | |||
const std::string TARGET_TYPE_LITE = "LITE"; | |||
// l2_normalize | |||
const std::string L2_NORMALIZE_ATTR_AXIS = "axis"; | |||
const std::string L2_NORMALIZE_ATTR_EPS = "eps"; | |||
const std::string POOL_PARAMA_ATTR_WINDOW = "window"; | |||
const std::string POOL_PARAMA_ATTR_CEIL_MODE = "ceil_mode"; | |||
const std::string POOL_PARAMA_ATTR_DATA_MODE = "data_mode"; | |||
const std::string POOL_PARAMA_ATTR_GLOBAL_POOLING = "global_pooling"; | |||
const std::string POOL_PARAMA_ATTR_NAN_OP = "nan_opt"; | |||
const std::string POOL_PARAMA_ATTR_PAD_MOD = "pad_mode"; | |||
// HCOM | |||
const std::string HCOM_ATTR_ROOT_RANK = "root_rank"; | |||
const std::string HCOM_ATTR_RANK_SIZE = "rank_size"; | |||
const std::string HCOM_ATTR_REDUCE_TYPE = "reduction"; | |||
const std::string HCOM_ATTR_GROUP = "group"; | |||
const std::string HCOM_ATTR_SR_TAG = "sr_tag"; | |||
const std::string HCOM_ATTR_SRC_RANK = "src_rank"; | |||
const std::string HCOM_ATTR_DEST_RANK = "dest_rank"; | |||
const std::string HCOM_ATTR_FUSION = "fusion"; | |||
const std::string HCOM_ATTR_SHAPE = "shape"; | |||
const std::string HCOM_ATTR_DATA_TYPE = "dtype"; | |||
// SpaceToDepth/DepthToSpace | |||
const std::string ATTR_NAME_BLOCK_SIZE = "block_size"; | |||
// SparseSoftmaxCrossEntropyWithLogits | |||
const std::string SPARSE_SOFT_MAX_ATTR_TLABLES = "Tlabels"; | |||
// MaxPoolGradWithArgmax | |||
const std::string MAX_POOL_GRAD_OUTPUT_SHAPE = "max_pool_grad_output_shape"; | |||
// AvgPoolGrad | |||
const std::string AVG_POOL_GRAD_OUTPUT_SHAPE = "avg_pool_grad_output_shape"; | |||
// Pad | |||
const std::string ATTR_PAD_FORMAT = "attr_pad_format"; | |||
// Varible | |||
const std::string VAR_ATTR_FORMAT = "_var_format"; | |||
const std::string VAR_ATTR_NAME = "var_name"; | |||
const std::string VAR_ATTR_FRACTALZ_FORMAT = "FZ"; | |||
const std::string VAR_ATTR_4D_FORMAT = "4D"; | |||
const std::string VAR_ATTR_5D_FORMAT = "5D"; | |||
const std::string VAR_ATTR_DATA_TYPE = "data_format"; | |||
const std::string VAR_ATTR_VAR_IN_NAME = "var_in_name"; | |||
const std::string VAR_ATTR_VAR_IN_INDEX = "var_in_index"; | |||
const std::string VAR_ATTR_VAR_OUT_INDEX = "var_out_index"; | |||
const std::string VAR_ATTR_SHAPE = "shape"; | |||
const std::string HALF_VAR_NAME_END = "_fp16"; | |||
const std::string VAR_ATTR_INITED = "var_is_inited"; | |||
const std::string VAR_ATTR_CONTAINER = "container"; | |||
const std::string VAR_ATTR_SHARED_NAME = "shared_name"; | |||
const std::string VAR_ATTR_DTYPE = "dtype"; | |||
const std::string VAR_ATTR_SRC_VAR_NAME = "_src_var_name"; | |||
const std::string VAR_ATTR_VAR_IS_SAVE = "_var_is_save"; | |||
const std::string VAR_ATTR_VAR_IS_RESTORE = "_var_is_restore"; | |||
const std::string VAR_ATTR_VAR_IS_BROADCAST = "_var_is_broadcast"; | |||
const std::string REF_VAR_SRC_VAR_NAME = "ref_var_src_var_name"; | |||
const std::string REF_VAR_PRE_PEER_OUT_INDEX = "ref_var_pre_peer_out_index"; | |||
// Assign | |||
const std::string ASSIGN_VALIDATE_SHAPE = "validate_shape"; | |||
// space2bacth batch2space | |||
const std::string BATCH_SPACE_ATTR_BLOCK = "block"; | |||
const std::string BATCH_SPACE_ATTR_PADDING = "padding"; | |||
// depth_to_space space_to_depth | |||
const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE = "block_size"; | |||
// FakeQuantWithMinMaxVars | |||
const std::string FakeQuantWithMinMaxVars_ATTR_MAX = "max"; | |||
const std::string FakeQuantWithMinMaxVars_ATTR_MIN = "min"; | |||
// mobilenet_ssd_conv_fusion | |||
const std::string SSD_BOXPREDICTOR_BOXES_FUSION = "ssd_boxpredictor_boxes_fusion"; | |||
const std::string SSD_BOXPREDICTOR_SCORES_FUSION = "ssd_boxpredictor_scores_fusion"; | |||
const std::string SSD_BOXPREDICTOR_FUSION_BOX_TYPE_NUM = "ssd_boxpredictor_fusion_box_type_num"; | |||
// lsh project | |||
const std::string LSH_PROJ_TYPE = "lsh_project_type"; | |||
// log time stamp | |||
const std::string LOG_TIME_STAMP_LOGID = "logid"; | |||
const std::string LOG_TIME_STAMP_NOTIFY = "notify"; | |||
// ShapeN | |||
const std::string SHAPEN_ATTR_N = "N"; | |||
const std::string SHAPEN_ATTR_IN_TYPE = "in_type"; | |||
const std::string SHAPEN_ATTR_OUT_TYPE = "dtype"; | |||
// GatherV2 attr def | |||
const std::string GATHERV2_ATTR_NAME_TAXIS = "Taxis"; | |||
const std::string GATHERV2_ATTR_NAME_TINDICES = "Tindices"; | |||
const std::string GATHERV2_ATTR_NAME_TPARAMS = "Tparams"; | |||
// Reshape attr def | |||
const std::string RESHAPE_ATTR_NAME_INPUT_DESC = "input_desc_reshape"; | |||
const std::string RESHAPE_ATTR_NAME_OUTPUT_DESC = "output_desc_reshape"; | |||
// axis attr def | |||
const std::string ATTR_NAME_AXIS_ORG_OP = "axis_org_op"; | |||
const std::string ATTR_NAME_LINK_WITH_SPARE = "link_with_sparse"; | |||
const std::string ATTR_NAME_NET_OUTPUT_FORMAT = "net_output_format"; | |||
const std::string ATTR_NAME_NET_OUTPUT_DATATYPE = "net_output_datatype"; | |||
// For constant folding | |||
const std::string ATTR_NO_NEED_CONSTANT_FOLDING = "no_need_constant_folding"; | |||
const std::string ATTR_NAME_CONTINUOUS_INPUT = "continuous_input"; | |||
const std::string ATTR_NAME_CONTINUOUS_INPUT_ALLOC = "continuous_input_alloc"; | |||
const std::string ATTR_NAME_CONTINUOUS_OUTPUT = "continuous_output"; | |||
const std::string ATTR_NAME_REFERENCE = "reference"; | |||
@@ -694,6 +893,8 @@ const std::string ATTR_NAME_STREAM_SWITCH_COND = "switch_condition"; | |||
const std::string ATTR_NAME_TRUE_BRANCH_STREAM = "true_branch_stream"; | |||
const std::string ATTR_NAME_ACTIVE_STREAM_LIST = "active_stream_list"; | |||
const std::string ATTR_NAME_SWITCHN_PRED_VALUE = "switch_pred_value"; | |||
const std::string ATTR_NAME_ITERATORS_PER_LOOP = "iterations_per_loop"; | |||
const std::string ATTR_NAME_FLOW_CTRL_NODE_FLAG = "is_flow_ctrl_node"; | |||
const std::string ATTR_NAME_SWITCH_BRANCH_NODE_LABEL = "_switch_branch_node_label"; | |||
const std::string ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG = "_switch_true_branch_flag"; | |||
@@ -705,6 +906,7 @@ const std::string ATTR_NAME_NEXT_ITERATION = "_next_iteration_node"; | |||
// Function Op | |||
const std::string ATTR_NAME_PARENT_NODE_INDEX = "_parent_node_index"; | |||
const std::string ATTR_NAME_PARENT_CONST_TYPE = "_parent_const_type"; | |||
// Used for mark the active node is for loop, type:bool | |||
const std::string ATTR_NAME_IS_LOOP_ACTIVE = "is_loop_active"; | |||
@@ -719,6 +921,7 @@ const std::string MODEL_ATTR_SESSION_ID = "session_id"; | |||
// l1 fusion and other fusion in future | |||
const std::string ATTR_NAME_L1_FUSION_GROUP_ID = "_l1_fusion_group_id"; | |||
const std::string ATTR_NAME_FUSION_GROUP_KEY = "_fusion_group_key"; | |||
const std::string ATTR_NAME_L1_FUSION_GROUP_KEY = "_l1_fusion_group_key"; | |||
const std::string ATTR_NAME_FUSION_VIRTUAL_OP = "_fusion_virtual_op"; | |||
const std::string ATTR_NAME_FUSION_GROUP_TYPE = "_fusion_group_type"; | |||
@@ -730,6 +933,9 @@ const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_L1_FUSION = "_output_offset_for_l1 | |||
const std::string ATTR_NAME_SWITCH_FOR_L1_FUSION = "_enable_l1_fusion"; | |||
const std::string ATTR_N_BATCH_SPILT = "_is_n_batch_split"; | |||
const std::string ATTR_NO_TASK_AND_DUMP_NEEDED = "_no_task_and_dump_needed"; | |||
const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_BUFFER_FUSION = "_output_offset_for_buffer_fusion"; | |||
const std::string ATTR_NAME_L2_FUSION_GROUP_ID = "_l2_fusion_group_id"; | |||
const std::string ATTR_NAME_SWITCH_FOR_L2_FUSION = "_enable_l2_fusion"; | |||
// Atomic addr clean attrs | |||
const std::string ATOMIC_ATTR_INPUT_INDEX = "atomic_input_index"; | |||
@@ -748,6 +954,8 @@ const std::string ATTR_NEED_COMPILE = "_node_need_compile"; | |||
const std::string ATTR_INSERT_BY_MBATCH = "mbatch-inserted-node"; | |||
const std::string ATTR_MBATCH_ORIGIN_INPUT_DIMS = "_mbatch_origin_input_dims"; | |||
// For inserted op | |||
const std::string ATTR_INSERTED_BY_GE = "_inserted_by_ge"; | |||
@@ -764,7 +972,22 @@ const std::string ATTR_NAME_DATA_DUMP_ORIGIN_OUTPUT_INDEX = "_datadump_origin_ou | |||
const std::string ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT = "_datadump_origin_format"; | |||
const std::string ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE = "_datadump_origin_data_type"; | |||
// functional ops attr | |||
const std::string ATTR_NAME_WHILE_COND = "cond"; | |||
const std::string ATTR_NAME_WHILE_BODY = "body"; | |||
// used for label switch | |||
const std::string ATTR_NAME_LABEL_SWITCH_INDEX = "_label_switch_index"; | |||
const std::string ATTR_NAME_LABEL_SWITCH_LIST = "_label_switch_list"; | |||
const std::string ATTR_NAME_INPUT_DATATYPE = "input_datatype"; | |||
const std::string ATTR_NAME_OUTPUT_DATATYPE = "output_datatype"; | |||
// used for LX tiling | |||
const std::string ATTR_NAME_OP_L1_SPACE = "_l1_space"; | |||
const std::string ATTR_NAME_FUSION_TYPE_LIST = "_fusion_type_list"; | |||
const std::string ATTR_NAME_VALID_INPUT_SHAPE_LIST_LIST = "_valid_input_shape_list_list"; | |||
const std::string ATTR_NAME_VALID_OUTPUT_SHAPE_LIST_LIST = "_valid_output_shape_list_list"; | |||
const std::string ATTR_NAME_SLICE_INPUT_OFFSET_LIST_LIST = "_input_offset_list_list"; | |||
const std::string ATTR_NAME_SLICE_OUTPUT_OFFSET_LIST_LIST = "_input_offset_list_list"; | |||
} // namespace ge |
@@ -31,19 +31,18 @@ using std::string; | |||
using std::vector; | |||
namespace ge { | |||
GeAttrValue::NamedAttrs::NamedAttrs() { named_attrs_.InitDefault(); } | |||
NamedAttrs::NamedAttrs() { named_attrs_.InitDefault(); } | |||
GeAttrValue::NamedAttrs::NamedAttrs(const ProtoMsgOwner &owner, proto::NamedAttrs *proto_msg) | |||
: named_attrs_(owner, proto_msg) {} | |||
NamedAttrs::NamedAttrs(const ProtoMsgOwner &owner, proto::NamedAttrs *proto_msg) : named_attrs_(owner, proto_msg) {} | |||
void GeAttrValue::NamedAttrs::SetName(const std::string &name) { | |||
void NamedAttrs::SetName(const std::string &name) { | |||
auto proto_msg = named_attrs_.GetProtoMsg(); | |||
if (proto_msg != nullptr) { | |||
proto_msg->set_name(name); | |||
} | |||
} | |||
string GeAttrValue::NamedAttrs::GetName() const { | |||
string NamedAttrs::GetName() const { | |||
auto proto_msg = named_attrs_.GetProtoMsg(); | |||
if (proto_msg != nullptr) { | |||
return proto_msg->name(); | |||
@@ -51,13 +50,13 @@ string GeAttrValue::NamedAttrs::GetName() const { | |||
return string(); | |||
} | |||
GeAttrValue GeAttrValue::NamedAttrs::GetItem(const string &key) const { | |||
GeAttrValue NamedAttrs::GetItem(const string &key) const { | |||
GeAttrValue value; | |||
GetAttr(key, value); | |||
(void)GetAttr(key, value); | |||
return value; | |||
} | |||
ProtoAttrMapHelper GeAttrValue::NamedAttrs::MutableAttrMap() { | |||
ProtoAttrMapHelper NamedAttrs::MutableAttrMap() { | |||
auto proto_msg = named_attrs_.GetProtoMsg(); | |||
if (proto_msg != nullptr) { | |||
return ProtoAttrMapHelper(named_attrs_.GetProtoOwner(), proto_msg->mutable_attr()); | |||
@@ -65,7 +64,7 @@ ProtoAttrMapHelper GeAttrValue::NamedAttrs::MutableAttrMap() { | |||
return ProtoAttrMapHelper(named_attrs_.GetProtoOwner(), nullptr); | |||
} | |||
ConstProtoAttrMapHelper GeAttrValue::NamedAttrs::GetAttrMap() const { | |||
ConstProtoAttrMapHelper NamedAttrs::GetAttrMap() const { | |||
auto proto_msg = named_attrs_.GetProtoMsg(); | |||
if (proto_msg != nullptr) { | |||
return ConstProtoAttrMapHelper(named_attrs_.GetProtoOwner(), &proto_msg->attr()); | |||
@@ -515,7 +514,7 @@ bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector<GeAtt | |||
return true; | |||
} | |||
bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const GeAttrValue::NamedAttrs &value) { | |||
bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const GeAttrValue::NAMED_ATTRS &value) { | |||
if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kFunc)) { | |||
return false; | |||
} | |||
@@ -528,7 +527,7 @@ bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const GeAttrValue: | |||
return true; | |||
} | |||
bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector<GeAttrValue::NamedAttrs> &value) { | |||
bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector<GeAttrValue::NAMED_ATTRS> &value) { | |||
if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, | |||
proto::AttrDef_ListValue_ListValueType_VT_LIST_NAMED_ATTRS)) { | |||
return false; | |||
@@ -739,7 +738,7 @@ bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoM | |||
} | |||
bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, | |||
GeAttrValue::NamedAttrs &value) { | |||
GeAttrValue::NAMED_ATTRS &value) { | |||
if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kFunc)) { | |||
return false; | |||
} | |||
@@ -752,7 +751,7 @@ bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoM | |||
} | |||
bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, | |||
vector<GeAttrValue::NamedAttrs> &value) { | |||
vector<GeAttrValue::NAMED_ATTRS> &value) { | |||
value.clear(); | |||
if (!AttrUtilsHelper::GetValueCheckListType( | |||
proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_NAMED_ATTRS, ListValueItemCheck(na))) { | |||
@@ -760,7 +759,7 @@ bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoM | |||
} | |||
auto &list = proto_attr_val.list(); | |||
for (const auto &item : list.na()) { | |||
value.emplace_back(GeAttrValue::NamedAttrs()); | |||
value.emplace_back(GeAttrValue::NAMED_ATTRS()); | |||
if (value.empty()) { | |||
return false; | |||
} | |||
@@ -967,7 +966,7 @@ ATTR_UTILS_SET_GET_IMP(TensorDesc, GeTensorDesc) | |||
ATTR_UTILS_SET_IMP(Tensor, GeTensorPtr) | |||
ATTR_UTILS_SET_IMP(Tensor, ConstGeTensorPtr) | |||
ATTR_UTILS_SET_IMP(Tensor, GeTensor) | |||
ATTR_UTILS_SET_GET_IMP(NamedAttrs, GeAttrValue::NamedAttrs) | |||
ATTR_UTILS_SET_GET_IMP(NamedAttrs, GeAttrValue::NAMED_ATTRS) | |||
ATTR_UTILS_SET_GET_IMP(Bytes, Buffer) | |||
ATTR_UTILS_SET_GET_IMP(Graph, ComputeGraphPtr) | |||
ATTR_UTILS_SET_GET_IMP(ListListInt, vector<vector<int64_t>>) | |||
@@ -982,7 +981,7 @@ ATTR_UTILS_SET_GET_IMP(ListTensorDesc, vector<GeTensorDesc>) | |||
ATTR_UTILS_SET_IMP(ListTensor, vector<GeTensorPtr>) | |||
ATTR_UTILS_SET_IMP(ListTensor, vector<ConstGeTensorPtr>) | |||
ATTR_UTILS_SET_IMP(ListTensor, vector<GeTensor>) | |||
ATTR_UTILS_SET_GET_IMP(ListNamedAttrs, vector<GeAttrValue::NamedAttrs>) | |||
ATTR_UTILS_SET_GET_IMP(ListNamedAttrs, vector<GeAttrValue::NAMED_ATTRS>) | |||
ATTR_UTILS_SET_GET_IMP(ListBytes, vector<Buffer>) | |||
ATTR_UTILS_SET_GET_IMP(ListGraph, vector<ComputeGraphPtr>) | |||
ATTR_UTILS_SET_GET_IMP(ListDataType, vector<ge::DataType>) | |||
@@ -83,6 +83,12 @@ size_t GeShape::GetDimNum() const { | |||
auto proto_msg = shape_def_.GetProtoMsg(); | |||
if (proto_msg != nullptr) { | |||
if (proto_msg->dim_size() >= 0) { | |||
// check whether contain -2, if true, return -1 | |||
for (auto i : proto_msg->dim()) { | |||
if (i == UNKNOWN_DIM_NUM) { | |||
return 0; | |||
} | |||
} | |||
return proto_msg->dim_size(); | |||
} else { | |||
return 0; | |||
@@ -157,6 +163,10 @@ int64_t GeShape::GetShapeSize() const { | |||
return 0; | |||
} | |||
for (auto i : proto_msg->dim()) { | |||
// if unknown shape, return -1 | |||
if (i == UNKNOWN_DIM || i == UNKNOWN_DIM_NUM) { | |||
return UNKNOWN_DIM; | |||
} | |||
res *= i; | |||
} | |||
} | |||
@@ -209,6 +219,7 @@ const string TENSOR_UTILS_RC = "rc"; | |||
const string TENSOR_UTILS_ORIGIN_SHAPE = "origin_shape"; | |||
const string TENSOR_UTILS_ORIGIN_FORMAT = "origin_format"; | |||
const string TENSOR_UTILS_ORIGIN_DATA_TYPE = "origin_data_type"; | |||
const string TENSOR_UTILS_SHAPE_RANGE = "shape_range"; | |||
GeShape::GeShape(const ProtoMsgOwner &proto_owner, proto::ShapeDef *proto_msg) : shape_def_(proto_owner, proto_msg) {} | |||
@@ -396,6 +407,35 @@ GeShape &GeTensorDesc::MutableShape() { return ShapeReference(); } | |||
void GeTensorDesc::SetShape(GeShape shape) { ShapeReference() = std::move(shape); } | |||
// set shape with -2, it stand for unknown shape | |||
void GeTensorDesc::SetUnknownDimNumShape() { SetShape(GeShape({UNKNOWN_DIM_NUM})); } | |||
// for unknown shape | |||
graphStatus GeTensorDesc::SetShapeRange(const std::vector<std::pair<int64_t, int64_t>> &range) { | |||
std::vector<vector<int64_t>> shape_range; | |||
for (const auto &ele : range) { | |||
shape_range.emplace_back(std::vector<int64_t>({ele.first, ele.second})); | |||
} | |||
auto ret = AttrUtils::SetListListInt(this, TENSOR_UTILS_SHAPE_RANGE, shape_range); | |||
return ret ? GRAPH_SUCCESS : GRAPH_FAILED; | |||
} | |||
graphStatus GeTensorDesc::GetShapeRange(std::vector<std::pair<int64_t, int64_t>> &range) const { | |||
std::vector<vector<int64_t>> shape_range; | |||
(void)AttrUtils::GetListListInt(this, TENSOR_UTILS_SHAPE_RANGE, shape_range); | |||
for (const auto &ele : shape_range) { | |||
// here must be only two elemenet because pair | |||
if (ele.size() != 2) { | |||
GELOGE(GRAPH_FAILED, "shape_range must contain only 2 value but really is %lu", ele.size()); | |||
return GRAPH_FAILED; | |||
} | |||
std::pair<int64_t, int64_t> pair({ele[0], ele[1]}); | |||
range.push_back(pair); | |||
} | |||
return GRAPH_SUCCESS; | |||
} | |||
GeShape GeTensorDesc::GetOriginShape() const { | |||
vector<int64_t> origin_shape; | |||
if (!AttrUtils::GetListInt(this, TENSOR_UTILS_ORIGIN_SHAPE, origin_shape)) { | |||
@@ -16,11 +16,12 @@ | |||
#include "external/graph/graph.h" | |||
#include "debug/ge_util.h" | |||
#include "external/graph/operator.h" | |||
#include "framework/common/debug/ge_log.h" | |||
#include "graph/ge_attr_value.h" | |||
#include "graph/debug/ge_attr_define.h" | |||
#include "graph/debug/ge_op_types.h" | |||
#include "graph/model.h" | |||
#include "graph/utils/graph_utils.h" | |||
#include "graph/utils/op_desc_utils.h" | |||
using std::map; | |||
using std::pair; | |||
@@ -214,6 +215,23 @@ class GraphImpl { | |||
return GRAPH_SUCCESS; | |||
} | |||
graphStatus FindOpByType(const string &type, std::vector<ge::Operator> &ops) const { | |||
for (auto &op : op_list_) { | |||
auto op_type = op.second.GetOpType(); | |||
if (op_type == type) { | |||
ops.push_back(op.second); | |||
continue; | |||
} | |||
if (op_type == ge::FRAMEWORKOP) { | |||
op.second.GetAttr(ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, op_type); | |||
if (op_type == type) { | |||
ops.push_back(op.second); | |||
} | |||
} | |||
} | |||
return GRAPH_SUCCESS; | |||
} | |||
void SetNeedIteration(bool need_iteration) { | |||
if (compute_graph_ == nullptr) { | |||
GELOGE(GRAPH_FAILED, "Set need iteration failed, as compute graph is null."); | |||
@@ -222,6 +240,8 @@ class GraphImpl { | |||
compute_graph_->SetNeedIteration(need_iteration); | |||
} | |||
const std::string &GetName() const { return name_; } | |||
private: | |||
std::string name_; | |||
std::string output_name_; | |||
@@ -255,6 +275,11 @@ graphStatus Graph::FindOpByName(const std::string &name, Operator &op) const { | |||
return impl_->FindOpByName(name, op); | |||
} | |||
graphStatus Graph::FindOpByType(const string &type, std::vector<ge::Operator> &ops) const { | |||
GE_CHECK_NOTNULL(impl_); | |||
return impl_->FindOpByType(type, ops); | |||
} | |||
Graph &Graph::SetInputs(const vector<ge::Operator> &inputs) { | |||
GE_CHK_BOOL_EXEC(impl_ != nullptr, return *this, "SetInputs failed: graph can not be used, impl is nullptr.") | |||
GE_CHK_BOOL_EXEC(inputs.size() > 0, return *this, "SetInputs failed: input operator size can not be 0."); | |||
@@ -331,6 +356,8 @@ graphStatus Graph::LoadFromFile(const string &file_name) { | |||
return GRAPH_SUCCESS; | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::string &Graph::GetName() const { return impl_->GetName(); } | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Graph | |||
GraphUtils::CreateGraphFromComputeGraph(const ge::ComputeGraphPtr compute_graph) { | |||
GE_CHK_BOOL_EXEC_NOLOG(compute_graph != nullptr, return Graph("")); | |||
@@ -343,4 +370,15 @@ GraphUtils::CreateGraphFromComputeGraph(const ge::ComputeGraphPtr compute_graph) | |||
return graph; | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::RecoverGraphOperators(const Graph &graph) { | |||
GE_CHECK_NOTNULL(graph.impl_); | |||
GE_CHECK_NOTNULL(graph.impl_->compute_graph_); | |||
graph.impl_->op_list_.clear(); | |||
for (const auto &node : graph.impl_->compute_graph_->GetDirectNode()) { | |||
graph.impl_->op_list_[node->GetName()] = OpDescUtils::CreateOperatorFromNode(node); | |||
} | |||
return SUCCESS; | |||
} | |||
} // namespace ge |
@@ -16,7 +16,10 @@ | |||
#include "graph/model_serialize.h" | |||
#include <google/protobuf/text_format.h> | |||
#include <queue> | |||
#include <iostream> | |||
#include "debug/ge_attr_define.h" | |||
#include "debug/ge_log.h" | |||
#include "debug/ge_util.h" | |||
@@ -26,6 +29,7 @@ | |||
#include "utils/graph_utils.h" | |||
#include "debug/ge_op_types.h" | |||
using std::map; | |||
using std::string; | |||
namespace ge { | |||
@@ -121,6 +125,11 @@ bool ModelSerializeImp::SerializeOpDesc(const ConstOpDescPtr &op_desc, proto::Op | |||
} | |||
} | |||
} | |||
op_def_proto->set_id(op_desc->GetId()); | |||
for (const std::string &name : op_desc->GetSubgraphInstanceNames()) { | |||
op_def_proto->add_subgraph_name(name); | |||
} | |||
} | |||
return true; | |||
} | |||
@@ -196,6 +205,14 @@ bool ModelSerializeImp::SerializeModel(const Model &model, proto::ModelDef *mode | |||
GELOGE(GRAPH_FAILED, "SerializeGraph fail"); | |||
return false; | |||
} | |||
for (auto subgraph : compute_graph->GetAllSubgraphs()) { | |||
if (!SerializeGraph(subgraph, model_proto->add_graph(), is_dump)) { | |||
GELOGE(GRAPH_FAILED, "Serialize subgraph failed"); | |||
return false; | |||
} | |||
} | |||
return true; | |||
} | |||
@@ -228,6 +245,14 @@ bool ModelSerializeImp::UnserializeOpDesc(OpDescPtr &op_desc, proto::OpDef &op_d | |||
GE_CHK_BOOL_RET_STATUS(temp_value != nullptr, false, "temp_value is nullptr"); | |||
op_desc->outputs_desc_.push_back(temp_value); | |||
} | |||
op_desc->SetId(op_def_proto.id()); | |||
uint32_t graph_index = 0; | |||
for (const std::string &name : op_def_proto.subgraph_name()) { | |||
op_desc->AddSubgraphName(name); | |||
op_desc->SetSubgraphInstanceName(graph_index++, name); | |||
} | |||
return true; | |||
} | |||
@@ -238,7 +263,7 @@ bool ModelSerializeImp::UnserializeNode(ComputeGraphPtr &graph, proto::OpDef &op | |||
GELOGW("UnserializeOpDesc error."); | |||
} | |||
NodePtr node = graph->AddNode(op_desc); | |||
NodePtr node = graph->AddNode(op_desc, op_desc->GetId()); | |||
GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is nullptr."); | |||
// Inputs | |||
@@ -319,6 +344,35 @@ bool ModelSerializeImp::HandleNodeNameRef() { | |||
return true; | |||
} | |||
bool ModelSerializeImp::RebuildOwnership(ComputeGraphPtr &compute_graph, map<string, ComputeGraphPtr> &subgraphs) { | |||
std::queue<ComputeGraphPtr> all_graphs; | |||
all_graphs.emplace(compute_graph); | |||
while (!all_graphs.empty()) { | |||
ComputeGraphPtr graph = all_graphs.front(); | |||
all_graphs.pop(); | |||
for (const NodePtr &node : graph->GetDirectNode()) { | |||
const OpDescPtr op_desc = node->GetOpDesc(); | |||
for (const std::string &name : op_desc->GetSubgraphInstanceNames()) { | |||
auto it = subgraphs.find(name); | |||
if (it == subgraphs.end()) { | |||
GELOGE(GRAPH_FAILED, "Node:%s, Subgraph:%s not found, num:%zu.", op_desc->GetName().c_str(), name.c_str(), | |||
subgraphs.size()); | |||
return false; | |||
} | |||
ComputeGraphPtr &subgraph = it->second; | |||
subgraph->SetParentGraph(graph); | |||
subgraph->SetParentNode(node); | |||
compute_graph->AddSubgraph(subgraph->GetName(), subgraph); | |||
all_graphs.emplace(subgraph); | |||
} | |||
} | |||
} | |||
return true; | |||
} | |||
bool ModelSerializeImp::UnserializeModel(Model &model, proto::ModelDef &model_proto) { | |||
model.name_ = model_proto.name(); | |||
model.version_ = model_proto.version(); | |||
@@ -332,7 +386,31 @@ bool ModelSerializeImp::UnserializeModel(Model &model, proto::ModelDef &model_pr | |||
if (UnserializeGraphWithoutEdge(compute_graph_ptr, graph_proto)) { | |||
model.graph_ = GraphUtils::CreateGraphFromComputeGraph(compute_graph_ptr); | |||
} | |||
// 0 is main graph, following is subgraph. | |||
map<string, ComputeGraphPtr> subgraphs; | |||
for (int idx = 1; idx < graphs_proto.size(); ++idx) { | |||
ComputeGraphPtr subgraph; | |||
ModelSerializeImp impl; | |||
if (!impl.UnserializeGraphWithoutEdge(subgraph, graphs_proto[idx])) { | |||
GELOGE(GRAPH_FAILED, "UnserializeGraphWithoutEdge failed"); | |||
return false; | |||
} | |||
if (!impl.HandleNodeNameRef()) { | |||
GELOGE(GRAPH_FAILED, "HandleNodeNameRef failed"); | |||
return false; | |||
} | |||
subgraphs[subgraph->GetName()] = subgraph; | |||
} | |||
if (!RebuildOwnership(compute_graph_ptr, subgraphs)) { | |||
GELOGE(GRAPH_FAILED, "Rebuild graph ownership failed"); | |||
return false; | |||
} | |||
} | |||
if (!HandleNodeNameRef()) { | |||
GELOGE(GRAPH_FAILED, "HandleNodeNameRef failed"); | |||
return false; | |||
@@ -61,6 +61,8 @@ const std::string ATTR_NAME_WORKSPACE_BYTES = "workspace_bytes"; | |||
const std::string ATTR_NAME_IS_INPUT_CONST = "is_input_const"; | |||
const std::string ATTR_NAME_OP_INFER_DEPENDS = "_op_infer_depends"; | |||
const std::string ATTR_NAME_OPT_INPUT = "_opt_input"; | |||
const std::string ATTR_NAME_INPUT_NAME_IDX_KEY = "_input_name_idx_key"; | |||
@@ -227,6 +229,40 @@ graphStatus OpDesc::AddInputDesc(const string &name, const ge::GeTensorDesc &inp | |||
} | |||
} | |||
graphStatus OpDesc::AddInputDescMiddle(const string &name, const unsigned int num, size_t index) { | |||
auto input_name_idx = GetAllInputName(); | |||
for (unsigned int i = 0; i < num; i++) { | |||
string input_name = name + std::to_string(i); | |||
GE_CHK_BOOL_RET_STATUS((input_name_idx.find(input_name) == input_name_idx.end()), GRAPH_FAILED, | |||
"Add input tensor_desc is existed. name[%s]", input_name.c_str()); | |||
std::shared_ptr<GeTensorDesc> in_desc = ComGraphMakeShared<GeTensorDesc>(GeTensorDesc()); | |||
if (in_desc == nullptr) { | |||
GELOGE(GRAPH_FAILED, "AddInputDescMiddle failed, malloc shared_ptr failed."); | |||
return GRAPH_FAILED; | |||
} | |||
if (index > inputs_desc_.size()) { | |||
GELOGE(GRAPH_FAILED, "AddInputDescMiddle failed, insert index should not more than inputs size."); | |||
return GRAPH_FAILED; | |||
} | |||
(void)inputs_desc_.insert(inputs_desc_.begin() + index + i, in_desc); | |||
// Update index in input_name_idx | |||
for (auto it = input_name_idx.begin(); it != input_name_idx.end(); ++it) { | |||
if (it->second >= (index + i)) { | |||
it->second += 1; | |||
} | |||
} | |||
(void)input_name_idx.insert(make_pair(input_name, i + index)); | |||
} | |||
SetAllInputName(input_name_idx); | |||
return GRAPH_SUCCESS; | |||
} | |||
graphStatus OpDesc::AddInputDescForward(const string &name, const unsigned int num) { | |||
auto input_name_idx = GetAllInputName(); | |||
for (unsigned int i = 0; i < num; i++) { | |||
@@ -239,7 +275,6 @@ graphStatus OpDesc::AddInputDescForward(const string &name, const unsigned int n | |||
GELOGE(GRAPH_FAILED, "AddInputDescForward failed, malloc shared_ptr failed."); | |||
return GRAPH_FAILED; | |||
} | |||
(void)inputs_desc_.insert(inputs_desc_.begin(), in_desc); | |||
// Update index in input_name_idx | |||
@@ -634,6 +669,13 @@ graphStatus OpDesc::AddDynamicInputDesc(const string &name, const unsigned int n | |||
return GRAPH_SUCCESS; | |||
} | |||
graphStatus OpDesc::AddDynamicInputDescByIndex(const string &name, const unsigned int num, size_t index) { | |||
if (AddInputDescMiddle(name, num, index) != GRAPH_SUCCESS) { | |||
return GRAPH_FAILED; | |||
} | |||
return GRAPH_SUCCESS; | |||
} | |||
graphStatus OpDesc::AddDynamicOutputDesc(const string &name, const unsigned int num, bool is_push_back) { | |||
if (is_push_back) { | |||
for (unsigned int i = 0; i < num; i++) { | |||
@@ -1054,6 +1096,19 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<string> OpDesc::GetDstName | |||
return dst_name; | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetOpInferDepends(const vector<string> &depend_names) { | |||
auto ret = AttrUtils::SetListStr(this, ATTR_NAME_OP_INFER_DEPENDS, depend_names); | |||
if (ret != true) { | |||
GELOGE(GRAPH_FAILED, "set op_infer_depends fail."); | |||
} | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<string> OpDesc::GetOpInferDepends() const { | |||
vector<string> depend_names; | |||
(void)AttrUtils::GetListStr(this, ATTR_NAME_OP_INFER_DEPENDS, depend_names); | |||
return depend_names; | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetDstIndex(const vector<int64_t> &dst_index) { | |||
auto proto_msg = op_def_.GetProtoMsg(); | |||
if (proto_msg != nullptr) { | |||
@@ -1199,20 +1254,17 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::vector<std::string> &O | |||
return subgraph_instance_names_; | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::AddSubgraphInstanceName(std::string name) { | |||
subgraph_instance_names_.emplace_back(std::move(name)); | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::RemoveSubgraphInstanceName(const std::string &name) { | |||
for (auto iter = subgraph_instance_names_.begin(); iter != subgraph_instance_names_.end(); ++iter) { | |||
if (*iter == name) { | |||
subgraph_instance_names_.erase(iter); | |||
*iter = ""; | |||
return; | |||
} | |||
} | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::AddSubgraphName(const std::string &name) { | |||
GELOGI("Add subgraph name is %s", name.c_str()); | |||
auto iter = subgraph_names_to_index_.find(name); | |||
if (iter != subgraph_names_to_index_.end()) { | |||
GELOGW("The subgraph name %s exists, index %u", name.c_str(), iter->second); | |||
@@ -1220,6 +1272,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::AddSubgraphNa | |||
} | |||
auto size = subgraph_names_to_index_.size(); | |||
subgraph_names_to_index_[name] = size; | |||
subgraph_instance_names_.resize(size + 1); | |||
return GRAPH_SUCCESS; | |||
} | |||
@@ -1227,4 +1280,34 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::map<std::string, uint3 | |||
const { | |||
return subgraph_names_to_index_; | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::SetSubgraphInstanceName(uint32_t index, | |||
const std::string &name) { | |||
GELOGI("Add sub graph instans name is %s, index is %u", name.c_str(), index); | |||
if (index >= subgraph_instance_names_.size()) { | |||
GE_LOGE("The index %u exceeds the max instance coutn %zu", index, subgraph_instance_names_.size()); | |||
return GRAPH_PARAM_INVALID; | |||
} | |||
subgraph_instance_names_[index] = name; | |||
return GRAPH_SUCCESS; | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::RegisterSubgraphIrName(const string &name, | |||
SubgraphType type) { | |||
subgraph_ir_names_to_type_[name] = type; | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::map<std::string, SubgraphType> &OpDesc::GetSubgraphIrNames() | |||
const { | |||
return subgraph_ir_names_to_type_; | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY SubgraphType | |||
OpDesc::GetSubgraphTypeByIrName(const std::string &name) const { | |||
auto iter = subgraph_ir_names_to_type_.find(name); | |||
if (iter == subgraph_ir_names_to_type_.end()) { | |||
return kSubgraphTypeEnd; | |||
} | |||
return iter->second; | |||
} | |||
} // namespace ge |
@@ -15,6 +15,7 @@ | |||
*/ | |||
#include "external/graph/operator.h" | |||
#include "external/graph/operator_factory.h" | |||
#include <stdint.h> | |||
#include <algorithm> | |||
#include <mutex> | |||
@@ -38,6 +39,11 @@ | |||
#include "utils/tensor_adapter.h" | |||
#include "utils/tensor_utils.h" | |||
#include "utils/type_utils.h" | |||
#include <algorithm> | |||
#include <mutex> | |||
#include <queue> | |||
#include <set> | |||
#include <stdint.h> | |||
using std::enable_shared_from_this; | |||
using std::make_pair; | |||
@@ -343,15 +349,71 @@ class OperatorImpl : public std::enable_shared_from_this<OperatorImpl> { | |||
InferenceContextPtr GetInferenceContext() const { return inference_context_; } | |||
void SubgraphRegister(const std::string &name, bool dynamic) { | |||
op_desc_->RegisterSubgraphIrName(name, dynamic ? kDynamic : kStatic); | |||
} | |||
void SubgraphCountRegister(const std::string &name, uint32_t count) { | |||
if (op_desc_->GetSubgraphTypeByIrName(name) == kStatic) { | |||
op_desc_->AddSubgraphName(name); | |||
} else { | |||
for (uint32_t i = 0; i < count; ++i) { | |||
op_desc_->AddSubgraphName(name + std::to_string(i)); | |||
} | |||
} | |||
subgraph_names_to_builders_[name].resize(count, nullptr); | |||
} | |||
void SetSubgraphBuilder(const std::string &name, uint32_t index, const SubgraphBuilder &builder) { | |||
auto iter = subgraph_names_to_builders_.find(name); | |||
if (iter == subgraph_names_to_builders_.end()) { | |||
GELOGE(PARAM_INVALID, "Failed to set subgraph builder for name %s index %u, invalid name", name.c_str(), index); | |||
return; | |||
} | |||
if (iter->second.size() <= index) { | |||
GELOGE(PARAM_INVALID, "Failed to set subgraph builder for name %s index %u, excceds the max size %zu", | |||
name.c_str(), index, iter->second.size()); | |||
return; | |||
} | |||
iter->second[index] = builder; | |||
} | |||
SubgraphBuilder GetSubgraphBuilder(const std::string &name, uint32_t index) const { | |||
auto iter = subgraph_names_to_builders_.find(name); | |||
if (iter == subgraph_names_to_builders_.end()) { | |||
GELOGE(PARAM_INVALID, "Failed to get subgraph builder for name %s index %u, invalid name", name.c_str(), index); | |||
return nullptr; | |||
} | |||
if (iter->second.size() <= index) { | |||
GELOGE(PARAM_INVALID, "Failed to get subgraph builder for name %s index %u, excceds the max size %zu", | |||
name.c_str(), index, iter->second.size()); | |||
return nullptr; | |||
} | |||
return iter->second[index]; | |||
} | |||
std::vector<std::string> GetSubgraphNames() const { | |||
std::vector<std::string> names; | |||
for (const auto &subgraph_name_to_type : op_desc_->GetSubgraphIrNames()) { | |||
names.emplace_back(subgraph_name_to_type.first); | |||
} | |||
return names; | |||
} | |||
size_t GetSubgraphNamesCount() const { return op_desc_->GetSubgraphIrNames().size(); } | |||
OpDescPtr op_desc_ = nullptr; | |||
private: | |||
ge::ConstNodePtr node_{nullptr}; | |||
ge::InferenceContextPtr inference_context_; | |||
GraphBuilderCallback graph_builder_callback_; | |||
std::map<string, std::vector<OpIO>> output_links_{}; | |||
std::map<string, OpIO> input_link_{}; | |||
std::vector<std::weak_ptr<OperatorImpl>> control_input_link_{}; | |||
std::vector<std::weak_ptr<OperatorImpl>> control_output_link_{}; | |||
std::map<std::string, std::vector<SubgraphBuilder>> subgraph_names_to_builders_; | |||
}; | |||
// Used to manage OperatorImpl instances created by ge api. | |||
@@ -559,7 +621,6 @@ InferenceContextPtr Operator::GetInferenceContext() const { | |||
GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return nullptr, "operator impl is nullptr."); | |||
return operator_impl_->GetInferenceContext(); | |||
} | |||
TensorDesc Operator::GetInputDesc(uint32_t index) const { | |||
GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr."); | |||
return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetInputDesc(index)); | |||
@@ -698,7 +759,7 @@ const std::map<std::string, std::string> Operator::GetAllAttrNamesAndTypes() con | |||
void Operator::InputRegister(const string &name) { | |||
GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); | |||
GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); | |||
operator_impl_->GetOpDescImpl()->AddInputDesc(name, GeTensorDesc()); | |||
(void)operator_impl_->GetOpDescImpl()->AddInputDesc(name, GeTensorDesc()); | |||
} | |||
void Operator::OptionalInputRegister(const string &name) { | |||
@@ -745,6 +806,12 @@ void Operator::DynamicInputRegister(const string &name, const unsigned int num, | |||
(void)operator_impl_->GetOpDescImpl()->AddDynamicInputDesc(name, num, is_push_back); | |||
} | |||
void Operator::DynamicInputRegisterByIndex(const string &name, const unsigned int num, size_t index) { | |||
GE_CHK_BOOL_EXEC(!!operator_impl_, return, "operator impl is nullptr."); | |||
GE_CHK_BOOL_EXEC(nullptr != operator_impl_->GetOpDescImpl(), return, "GetOpDescImpl is nullptr."); | |||
operator_impl_->GetOpDescImpl()->AddDynamicInputDescByIndex(name, num, index); | |||
} | |||
int Operator::GetDynamicInputNum(const string &name) const { | |||
GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return 0, "operator impl is nullptr."); | |||
GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return 0, "GetOpDescImpl is nullptr."); | |||
@@ -896,6 +963,11 @@ OP_ATTR_GET_IMP(string &, Str) | |||
OP_ATTR_SET_IMP(const vector<string> &, ListStr) | |||
OP_ATTR_GET_IMP(vector<string> &, ListStr) | |||
OP_ATTR_SET_IMP(const GeAttrValue::NAMED_ATTRS &, NamedAttrs) | |||
OP_ATTR_GET_IMP(GeAttrValue::NAMED_ATTRS &, NamedAttrs) | |||
OP_ATTR_SET_IMP(const vector<GeAttrValue::NAMED_ATTRS> &, ListNamedAttrs) | |||
OP_ATTR_GET_IMP(vector<GeAttrValue::NAMED_ATTRS> &, ListNamedAttrs) | |||
OP_ATTR_REG_IMP(int64_t, Int) | |||
OP_ATTR_REG_IMP(const vector<int64_t> &, ListInt) | |||
OP_ATTR_REG_IMP(float, Float) | |||
@@ -905,6 +977,8 @@ OP_ATTR_REG_IMP(const vector<string> &, ListStr) | |||
OP_ATTR_REG_IMP(bool, Bool) | |||
OP_ATTR_REG_IMP(const vector<bool> &, ListBool) | |||
OP_ATTR_REG_IMP(const vector<vector<int64_t>> &, ListListInt) | |||
OP_ATTR_REG_IMP(const GeAttrValue::NAMED_ATTRS &, NamedAttrs) | |||
OP_ATTR_REG_IMP(const vector<GeAttrValue::NAMED_ATTRS> &, ListNamedAttrs) | |||
#undef OP_ATTR_SET_IMP | |||
#undef OP_ATTR_GET_IMP | |||
@@ -1114,6 +1188,95 @@ void Operator::AttrRegister(const string &name, const OpBytes &attr_value) { | |||
} | |||
} | |||
void Operator::SubgraphRegister(const std::string &name, bool dynamic) { | |||
if (operator_impl_ == nullptr) { | |||
GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); | |||
return; | |||
} | |||
operator_impl_->SubgraphRegister(name, dynamic ? kDynamic : kStatic); | |||
} | |||
void Operator::SubgraphCountRegister(const std::string &name, uint32_t count) { | |||
if (operator_impl_ == nullptr) { | |||
GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); | |||
return; | |||
} | |||
operator_impl_->SubgraphCountRegister(name, count); | |||
} | |||
void Operator::SetSubgraphBuilder(const std::string &name, uint32_t index, const SubgraphBuilder &builder) { | |||
if (operator_impl_ == nullptr) { | |||
GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); | |||
return; | |||
} | |||
operator_impl_->SetSubgraphBuilder(name, index, builder); | |||
} | |||
std::vector<std::string> Operator::GetSubgraphNames() const { return operator_impl_->GetSubgraphNames(); } | |||
SubgraphBuilder Operator::GetDynamicSubgraphBuilder(const string &name, uint32_t index) const { | |||
if (operator_impl_ == nullptr) { | |||
GELOGE(GRAPH_FAILED, "operator impl is nullptr."); | |||
return nullptr; | |||
} | |||
return operator_impl_->GetSubgraphBuilder(name, index); | |||
} | |||
SubgraphBuilder Operator::GetSubgraphBuilder(const string &name) const { return GetDynamicSubgraphBuilder(name, 0); } | |||
Graph Operator::GetSubgraph(const string &name) const { | |||
if (operator_impl_ == nullptr) { | |||
GE_LOGE("Failed to get subgraph %s, the operator impl is null", name.c_str()); | |||
return Graph(""); | |||
} | |||
auto op_desc = OpDescUtils::GetOpDescFromOperator(*this); | |||
if (op_desc == nullptr) { | |||
GE_LOGE("Failed to get subgraph %s, the op_desc is null", name.c_str()); | |||
return Graph(""); | |||
} | |||
const auto &subgraph_names_to_index = op_desc->GetSubgraphNameIndexes(); | |||
auto iter = subgraph_names_to_index.find(name); | |||
if (iter == subgraph_names_to_index.end()) { | |||
GE_LOGE("Failed to get subgraph %s, the name may be invalid", name.c_str()); | |||
return Graph(""); | |||
} | |||
auto subgraph_instance_name = op_desc->GetSubgraphInstanceName(iter->second); | |||
if (subgraph_instance_name.empty()) { | |||
GE_LOGE("Failed to get subgraph %s index %u, the subgraph may not be added", name.c_str(), iter->second); | |||
return Graph(""); | |||
} | |||
auto node = operator_impl_->GetNode(); | |||
if (node == nullptr) { | |||
GE_LOGE("Failed to get subgraph %s, the node is null", name.c_str()); | |||
return Graph(""); | |||
} | |||
auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); | |||
if (root_graph == nullptr) { | |||
GE_LOGE("Failed to get subgraph %s, can not find the root graph", name.c_str()); | |||
return Graph(""); | |||
} | |||
auto subgraph = root_graph->GetSubgraph(subgraph_instance_name); | |||
if (subgraph == nullptr) { | |||
GE_LOGE("Failed to get subgraph %s index %u, can not find the instance %s from the root graph", name.c_str(), | |||
iter->second, subgraph_instance_name.c_str()); | |||
return Graph(""); | |||
} | |||
return GraphUtils::CreateGraphFromComputeGraph(subgraph); | |||
} | |||
Graph Operator::GetDynamicSubgraph(const string &name, uint32_t index) const { | |||
return GetSubgraph(name + std::to_string(index)); | |||
} | |||
size_t Operator::GetSubgraphNamesCount() const { | |||
if (operator_impl_ == nullptr) { | |||
GE_LOGE("Failed to get subgraph names count, the operator impl is null"); | |||
return 0; | |||
} | |||
return operator_impl_->GetSubgraphNamesCount(); | |||
} | |||
class GraphBuilderImpl { | |||
public: | |||
explicit GraphBuilderImpl(const string &name) : graph_(ComGraphMakeShared<ComputeGraph>(name)) { | |||
@@ -96,7 +96,6 @@ VerifyFunc OperatorFactoryImpl::GetVerifyFunc(const std::string &operator_type) | |||
graphStatus OperatorFactoryImpl::RegisterOperatorCreator(const string &operator_type, OpCreator const &op_creator) { | |||
if (operator_creators_ == nullptr) { | |||
GELOGI("operator_creators_ init"); | |||
operator_creators_.reset(new (std::nothrow) std::map<string, OpCreator>()); | |||
} | |||
auto it = operator_creators_->find(operator_type); | |||
@@ -0,0 +1,422 @@ | |||
/** | |||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include "graph/ref_relation.h" | |||
#include <unordered_set> | |||
#include <unordered_map> | |||
#include "utils/mem_utils.h" | |||
#include "debug/ge_log.h" | |||
#include "debug/ge_op_types.h" | |||
#include "debug/ge_util.h" | |||
#include "debug/ge_attr_define.h" | |||
#include "graph/ge_error_codes.h" | |||
#include "graph/utils/graph_utils.h" | |||
#include "framework/common/debug/ge_log.h" | |||
using namespace std; | |||
using namespace ge; | |||
namespace ge { | |||
namespace { | |||
const char *kRefIndex = "_parent_node_index"; | |||
const string kWhile = "While"; | |||
const string kIf = "If"; | |||
const string kCase = "Case"; | |||
const int kMaxElementNum = 100; | |||
std::unordered_set<string> function_op = {kWhile, kIf, kCase}; | |||
} // namespace | |||
/* Impl */ | |||
class RefRelations::Impl { | |||
public: | |||
graphStatus LookUpRefRelations(const RefCell &key, unordered_set<RefCell, RefCellHash> &result) { | |||
unsigned long number = static_cast<unsigned long>(reinterpret_cast<uintptr_t>(key.node.get())); | |||
std::string lookup_key = | |||
key.node_name + std::to_string(key.in_out) + std::to_string(key.in_out_idx) + std::to_string(number); | |||
auto iter = look_up_table_.find(lookup_key); | |||
if (iter != look_up_table_.end()) { | |||
for (auto &c : iter->second) { | |||
result.insert(c); | |||
} | |||
return GRAPH_SUCCESS; | |||
} | |||
GELOGW("can not find any relations! key value is %s", lookup_key.c_str()); | |||
return GRAPH_SUCCESS; | |||
}; | |||
graphStatus BuildRefRelations(ge::ComputeGraph &root_graph); | |||
graphStatus Clear() { | |||
GELOGD("Start clear boundary reflections between main graph and sub graph!"); | |||
look_up_table_.clear(); | |||
values_.clear(); | |||
return GRAPH_SUCCESS; | |||
}; | |||
private: | |||
graphStatus BuildLookUpTables(); | |||
graphStatus BuildRefRelationsForBranch(const NodePtr &root_node, const vector<vector<NodePtr>> &classed_data_nodes, | |||
const vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes, | |||
vector<vector<RefCell>> &node_refs); | |||
graphStatus BuildRefRelationsForWhile(const NodePtr &root_node, const vector<vector<NodePtr>> &classed_data_nodes, | |||
const vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes, | |||
vector<vector<RefCell>> &node_refs); | |||
graphStatus BuildRelationsWithFuncNodeType(const NodePtr &root_node, | |||
const vector<vector<NodePtr>> &classed_data_nodes, | |||
const vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes, | |||
vector<vector<RefCell>> &node_refs); | |||
void GetDataAndNetoutputOfSubGraph(const ge::ComputeGraph &root_graph, vector<NodePtr> &data_nodes, | |||
vector<NodePtr> &netoutput_nodes, const std::vector<std::string> &sub_graph_names, | |||
const std::string &node_type); | |||
graphStatus GetRootGraph(ge::ComputeGraph &graph, ge::ComputeGraph &root_graph); | |||
graphStatus ProcessSubgraphDataNodes(vector<NodePtr> &data_nodes, vector<vector<NodePtr>> &classed_data_nodes); | |||
graphStatus ProcessSubgraphNetoutput(const vector<NodePtr> &netoutput_nodes, | |||
vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes); | |||
std::unordered_map<string, vector<RefCell>> look_up_table_; | |||
std::vector<vector<vector<RefCell>>> values_; | |||
}; | |||
// Node Level | |||
graphStatus RefRelations::Impl::BuildRefRelationsForBranch( | |||
const NodePtr &root_node, const vector<vector<NodePtr>> &classed_data_nodes, | |||
const vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes, vector<vector<RefCell>> &node_refs) { | |||
GELOGD("Enter BuildRefRelationsForBranch!"); | |||
size_t ref_i = 0; | |||
for (const auto &ref_i_data_nodes : classed_data_nodes) { | |||
vector<RefCell> in_ref_i_all_refs; | |||
RefCell cell_root; | |||
cell_root.node_name = root_node->GetName(); | |||
cell_root.node = root_node; | |||
cell_root.in_out = NODE_IN; | |||
cell_root.in_out_idx = ref_i; | |||
in_ref_i_all_refs.emplace_back(cell_root); | |||
for (const auto &data : ref_i_data_nodes) { | |||
RefCell cell_in; | |||
RefCell cell_out; | |||
cell_in.node_name = data->GetName(); | |||
cell_in.node = data; | |||
cell_in.in_out = NODE_IN; | |||
cell_in.in_out_idx = 0; | |||
cell_out.node_name = data->GetName(); | |||
cell_out.node = data; | |||
cell_out.in_out = NODE_OUT; | |||
cell_out.in_out_idx = 0; | |||
in_ref_i_all_refs.emplace_back(cell_in); | |||
in_ref_i_all_refs.emplace_back(cell_out); | |||
} | |||
node_refs.emplace_back(in_ref_i_all_refs); | |||
ref_i++; | |||
} | |||
size_t ref_o = 0; | |||
for (const auto &ref_o_net_nodes : classed_netoutput_nodes) { | |||
vector<RefCell> out_ref_i_all_refs; | |||
RefCell cell_root; | |||
cell_root.node_name = root_node->GetName(); | |||
cell_root.node = root_node; | |||
cell_root.in_out = NODE_OUT; | |||
cell_root.in_out_idx = ref_o; | |||
out_ref_i_all_refs.emplace_back(cell_root); | |||
for (const auto &ele : ref_o_net_nodes) { | |||
RefCell cell_netoutput_in; | |||
RefCell cell_netoutput_out; | |||
cell_netoutput_in.node_name = (ele.first)->GetName(); | |||
cell_netoutput_in.node = ele.first; | |||
cell_netoutput_in.in_out = NODE_IN; | |||
cell_netoutput_in.in_out_idx = ele.second; | |||
cell_netoutput_out.node_name = (ele.first)->GetName(); | |||
cell_netoutput_out.node = ele.first; | |||
cell_netoutput_out.in_out = NODE_OUT; | |||
cell_netoutput_out.in_out_idx = ele.second; | |||
out_ref_i_all_refs.emplace_back(cell_netoutput_in); | |||
out_ref_i_all_refs.emplace_back(cell_netoutput_out); | |||
} | |||
node_refs.emplace_back(out_ref_i_all_refs); | |||
ref_o++; | |||
} | |||
return GRAPH_SUCCESS; | |||
} | |||
graphStatus RefRelations::Impl::BuildLookUpTables() { | |||
for (size_t i = 0; i < values_.size(); i++) { | |||
vector<vector<RefCell>> &val = values_[i]; | |||
for (const auto &ele : val) { | |||
for (const auto &ref_cell : ele) { | |||
string key = ref_cell.node_name + std::to_string(ref_cell.in_out) + std::to_string(ref_cell.in_out_idx) + | |||
std::to_string(static_cast<unsigned long>(reinterpret_cast<uintptr_t>(ref_cell.node.get()))); | |||
look_up_table_[key] = ele; | |||
} | |||
} | |||
} | |||
return GRAPH_SUCCESS; | |||
} | |||
graphStatus RefRelations::Impl::BuildRefRelationsForWhile( | |||
const NodePtr &root_node, const vector<vector<NodePtr>> &classed_data_nodes, | |||
const vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes, vector<vector<RefCell>> &node_refs) { | |||
GELOGD("Enter BuildRefRelations for while op!"); | |||
// data_nodes has been sorted | |||
// for while, input num must be same as output num | |||
auto input_num = root_node->GetAllInDataAnchorsSize(); | |||
size_t ref_i = 0; | |||
while (ref_i < input_num) { | |||
auto &ref_i_data_nodes = classed_data_nodes[ref_i]; | |||
auto &ref_i_net_nodes = classed_netoutput_nodes[ref_i]; | |||
vector<RefCell> ref_i_all_refs; | |||
RefCell cell_root_i; | |||
RefCell cell_root_o; | |||
cell_root_i.node_name = root_node->GetName(); | |||
cell_root_i.node = root_node; | |||
cell_root_i.in_out = NODE_IN; | |||
cell_root_i.in_out_idx = ref_i; | |||
ref_i_all_refs.emplace_back(cell_root_i); | |||
cell_root_o.node_name = root_node->GetName(); | |||
cell_root_o.node = root_node; | |||
cell_root_o.in_out = NODE_OUT; | |||
cell_root_o.in_out_idx = ref_i; | |||
ref_i_all_refs.emplace_back(cell_root_o); | |||
for (const auto &data : ref_i_data_nodes) { | |||
RefCell cell_in; | |||
RefCell cell_out; | |||
cell_in.node_name = data->GetName(); | |||
cell_in.node = data; | |||
cell_in.in_out = NODE_IN; | |||
cell_in.in_out_idx = 0; | |||
cell_out.node_name = data->GetName(); | |||
cell_out.node = data; | |||
cell_out.in_out = NODE_OUT; | |||
cell_out.in_out_idx = 0; | |||
ref_i_all_refs.emplace_back(cell_in); | |||
ref_i_all_refs.emplace_back(cell_out); | |||
} | |||
for (const auto &ele : ref_i_net_nodes) { | |||
RefCell cell_netoutput_in; | |||
RefCell cell_netoutput_out; | |||
cell_netoutput_in.node_name = (ele.first)->GetName(); | |||
cell_netoutput_in.node = ele.first; | |||
cell_netoutput_in.in_out = NODE_IN; | |||
cell_netoutput_in.in_out_idx = ele.second; | |||
cell_netoutput_out.node_name = (ele.first)->GetName(); | |||
cell_netoutput_out.node = ele.first; | |||
cell_netoutput_out.in_out = NODE_OUT; | |||
cell_netoutput_out.in_out_idx = ele.second; | |||
ref_i_all_refs.emplace_back(cell_netoutput_in); | |||
ref_i_all_refs.emplace_back(cell_netoutput_out); | |||
} | |||
node_refs.emplace_back(ref_i_all_refs); | |||
ref_i++; | |||
} | |||
return GRAPH_SUCCESS; | |||
} | |||
// build ref relations according to diff func op type | |||
graphStatus RefRelations::Impl::BuildRelationsWithFuncNodeType( | |||
const NodePtr &root_node, const vector<vector<NodePtr>> &classed_data_nodes, | |||
const vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes, vector<vector<RefCell>> &node_refs) { | |||
// data_nodes has been sorted | |||
auto node_type = root_node->GetType(); | |||
auto status = GRAPH_SUCCESS; | |||
if (node_type == kIf || node_type == kCase) { | |||
status = BuildRefRelationsForBranch(root_node, classed_data_nodes, classed_netoutput_nodes, node_refs); | |||
} else if (node_type == kWhile) { | |||
status = BuildRefRelationsForWhile(root_node, classed_data_nodes, classed_netoutput_nodes, node_refs); | |||
} else { | |||
GELOGE(GRAPH_PARAM_INVALID, "Node type [%s] is not supported for build ref relations!", node_type.c_str()); | |||
status = GRAPH_PARAM_INVALID; | |||
} | |||
return status; | |||
} | |||
void RefRelations::Impl::GetDataAndNetoutputOfSubGraph(const ge::ComputeGraph &root_graph, vector<NodePtr> &data_nodes, | |||
vector<NodePtr> &netoutput_nodes, | |||
const std::vector<std::string> &sub_graph_names, | |||
const std::string &node_type) { | |||
int sub_graph_idx = 0; | |||
for (const auto &name : sub_graph_names) { | |||
auto sub_graph = root_graph.GetSubgraph(name); | |||
for (const auto &sub_graph_node : sub_graph->GetDirectNode()) { | |||
auto sub_graph_node_type = sub_graph_node->GetType(); | |||
if (sub_graph_node_type == DATA) { | |||
data_nodes.emplace_back(sub_graph_node); | |||
} else if (sub_graph_node_type == NETOUTPUT) { | |||
// if while, the first subgraph must be cond subgraph. | |||
// There is no meaning for refs ,so continue | |||
if (node_type == kWhile && sub_graph_idx == 0) { | |||
continue; | |||
} | |||
netoutput_nodes.emplace_back(sub_graph_node); | |||
} | |||
continue; | |||
} | |||
sub_graph_idx++; | |||
} | |||
} | |||
graphStatus RefRelations::Impl::GetRootGraph(ge::ComputeGraph &graph, ge::ComputeGraph &root_graph) { | |||
auto parent_graph_ptr = graph.GetParentGraph(); | |||
if (parent_graph_ptr == nullptr) { | |||
root_graph = graph; | |||
return GRAPH_SUCCESS; | |||
} | |||
auto root_graph_ptr = GraphUtils::FindRootGraph(parent_graph_ptr); | |||
if (root_graph_ptr == nullptr) { | |||
GE_LOGE("Get null root graph"); | |||
return GRAPH_PARAM_INVALID; | |||
} | |||
root_graph = *root_graph_ptr; | |||
return GRAPH_SUCCESS; | |||
} | |||
graphStatus RefRelations::Impl::ProcessSubgraphDataNodes(vector<NodePtr> &data_nodes, | |||
vector<vector<NodePtr>> &classed_data_nodes) { | |||
int max_ref_idx = 0; | |||
for (const auto &e : data_nodes) { | |||
int i; | |||
bool is_exist = true; | |||
is_exist = AttrUtils::GetInt(e->GetOpDesc(), kRefIndex, i); | |||
if (!is_exist) { | |||
GELOGE(GRAPH_FAILED, "Invalid SubGraph NetOutput node[%s].no attr %s", e->GetName().c_str(), kRefIndex); | |||
return GRAPH_FAILED; | |||
} | |||
max_ref_idx = (i > max_ref_idx) ? i : max_ref_idx; | |||
} | |||
while (!data_nodes.empty()) { | |||
auto data = data_nodes.back(); | |||
data_nodes.pop_back(); | |||
int ref_idx = 0; | |||
(void)AttrUtils::GetInt(data->GetOpDesc(), kRefIndex, ref_idx); | |||
classed_data_nodes[ref_idx].emplace_back(data); | |||
} | |||
return GRAPH_SUCCESS; | |||
} | |||
graphStatus RefRelations::Impl::ProcessSubgraphNetoutput( | |||
const vector<NodePtr> &netoutput_nodes, vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes) { | |||
for (const auto &sub_netoutput_node : netoutput_nodes) { | |||
auto op_desc = sub_netoutput_node->GetOpDesc(); | |||
GE_CHECK_NOTNULL(op_desc); | |||
for (const auto &in_data_anchor : sub_netoutput_node->GetAllInDataAnchors()) { | |||
auto in_desc = op_desc->MutableInputDesc(in_data_anchor->GetIdx()); | |||
if (in_desc == nullptr) { | |||
GELOGE(GRAPH_FAILED, "Invalid NetOutput node [%s] idx [%lu], no tensor on it", | |||
sub_netoutput_node->GetName().c_str(), in_data_anchor->GetIdx()); | |||
return GRAPH_FAILED; | |||
} | |||
int ref_o; | |||
if (AttrUtils::GetInt(in_desc, kRefIndex, ref_o)) { | |||
if (ref_o >= kMaxElementNum) { | |||
return GRAPH_FAILED; | |||
} | |||
classed_netoutput_nodes[ref_o].emplace_back( | |||
std::pair<NodePtr, size_t>({sub_netoutput_node, static_cast<size_t>(in_data_anchor->GetIdx())})); | |||
} | |||
} | |||
} | |||
return GRAPH_SUCCESS; | |||
} | |||
graphStatus RefRelations::Impl::BuildRefRelations(ge::ComputeGraph &graph) { | |||
/* First Step: Get root graph */ | |||
ge::ComputeGraph &root_graph = graph; | |||
auto status = GetRootGraph(graph, root_graph); | |||
if (status != GRAPH_SUCCESS) { | |||
return status; | |||
} | |||
for (const auto &node : graph.GetAllNodes()) { | |||
auto node_type = node->GetType(); | |||
if (function_op.find(node_type) == function_op.end()) { | |||
continue; | |||
} | |||
std::vector<NodePtr> ref_nodes; | |||
auto op_desc = node->GetOpDesc(); | |||
auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); | |||
vector<NodePtr> data_nodes; | |||
vector<NodePtr> netoutput_nodes; | |||
// Get data and netoutput of sub_graph | |||
GetDataAndNetoutputOfSubGraph(root_graph, data_nodes, netoutput_nodes, sub_graph_names, node_type); | |||
vector<vector<NodePtr>> classed_data_nodes(kMaxElementNum); // according to ref_idx | |||
vector<vector<std::pair<NodePtr, size_t>>> classed_netoutput_nodes(kMaxElementNum); // according to ref_idx | |||
status = ProcessSubgraphDataNodes(data_nodes, classed_data_nodes); | |||
if (status != GRAPH_SUCCESS) { | |||
GELOGE(GRAPH_FAILED, "classfy data nodes failed!"); | |||
return status; | |||
} | |||
// for netoutput | |||
// check netoutput | |||
// here main graph output number must be the same as every sub_graph netoutput node | |||
// key: netoutput node_ptr ,<ref_idx, net_in_idx> | |||
status = ProcessSubgraphNetoutput(netoutput_nodes, classed_netoutput_nodes); | |||
if (status != GRAPH_SUCCESS) { | |||
GELOGE(GRAPH_FAILED, "process netoutput failed!"); | |||
return status; | |||
} | |||
vector<vector<RefCell>> node_refs; | |||
status = BuildRelationsWithFuncNodeType(node, classed_data_nodes, classed_netoutput_nodes, node_refs); | |||
if (status != GRAPH_SUCCESS) { | |||
GELOGE(status, "BuildRelationsWithFuncNodeType Failed! Node is [%s]!", node->GetName().c_str()); | |||
return status; | |||
} | |||
if (!node_refs.empty()) { | |||
values_.push_back(node_refs); | |||
} | |||
} | |||
/* Seconde Step: generate map */ | |||
status = BuildLookUpTables(); | |||
if (status != GRAPH_SUCCESS) { | |||
GELOGE(status, "Build look up tables failed!"); | |||
return status; | |||
} | |||
return GRAPH_SUCCESS; | |||
} | |||
/* Ref Relations Interface */ | |||
RefRelations::RefRelations() { | |||
impl_ = MakeShared<Impl>(); | |||
if (impl_ == nullptr) { | |||
GELOGE(GRAPH_FAILED, "MakeShared failed!"); | |||
return; | |||
} | |||
} | |||
graphStatus RefRelations::LookUpRefRelations(const RefCell &key, unordered_set<RefCell, RefCellHash> &result) { | |||
GE_CHECK_NOTNULL(impl_); | |||
return impl_->LookUpRefRelations(key, result); | |||
} | |||
graphStatus RefRelations::BuildRefRelations(ge::ComputeGraph &root_graph) { | |||
GE_CHECK_NOTNULL(impl_); | |||
return impl_->BuildRefRelations(root_graph); | |||
} | |||
graphStatus RefRelations::Clear() { | |||
GE_CHECK_NOTNULL(impl_); | |||
return impl_->Clear(); | |||
} | |||
} // namespace ge |
@@ -21,7 +21,7 @@ | |||
#include <unordered_map> | |||
#include <utility> | |||
#include <vector> | |||
#include "framework/common/types.h" | |||
#include "graph/debug/ge_attr_define.h" | |||
#include "graph/utils/graph_utils.h" | |||
#include "debug/ge_log.h" | |||
@@ -37,7 +37,6 @@ | |||
namespace ge { | |||
namespace { | |||
constexpr const char *kRefIndex = "parent_node_index"; | |||
graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) { | |||
auto op_desc = node->GetOpDesc(); | |||
auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); | |||
@@ -47,6 +46,10 @@ graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) { | |||
auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); | |||
for (const auto &name : sub_graph_names) { | |||
if (name.empty()) { | |||
GELOGW("The node %s contains empty subgraph instance name", node->GetName().c_str()); | |||
continue; | |||
} | |||
auto sub_graph = root_graph->GetSubgraph(name); | |||
if (sub_graph == nullptr) { | |||
GE_LOGE("Can node find the subgrpah %s for node %s", name.c_str(), node->GetName().c_str()); | |||
@@ -63,7 +66,7 @@ graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) { | |||
node->GetName().c_str()); | |||
return GRAPH_FAILED; | |||
} | |||
if (!AttrUtils::GetInt(node_sub->GetOpDesc(), kRefIndex, ref_i)) { | |||
if (!AttrUtils::GetInt(node_sub->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, ref_i)) { | |||
GE_LOGE("Invalid data node on the sub graph %s parent node %s, no ref-index attribute", name.c_str(), | |||
node->GetName().c_str()); | |||
return GRAPH_FAILED; | |||
@@ -76,7 +79,10 @@ graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) { | |||
ref_i, node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(), node->GetAllOutDataAnchorsSize()); | |||
return GRAPH_FAILED; | |||
} | |||
GELOGI("Ref index is %d, input_desc dtype is %d, node name is %s", ref_i, input_desc->GetDataType(), | |||
node->GetName().c_str()); | |||
auto ret = data_opdesc->UpdateInputDesc(0, *input_desc); | |||
if (ret != GRAPH_SUCCESS) { | |||
GE_LOGE("Failed to update input desc of data %s on the sub graph %s parent node %s", | |||
node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str()); | |||
@@ -101,6 +107,10 @@ graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) { | |||
auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); | |||
for (const auto &name : sub_graph_names) { | |||
if (name.empty()) { | |||
GELOGW("The node %s contains empty subgraph instance name", node->GetName().c_str()); | |||
continue; | |||
} | |||
auto sub_graph = root_graph->GetSubgraph(name); | |||
if (sub_graph == nullptr) { | |||
GE_LOGE("Can node find the subgrpah %s for node %s", name.c_str(), node->GetName().c_str()); | |||
@@ -132,11 +142,14 @@ graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) { | |||
node->GetName().c_str(), edge_anchor->GetIdx()); | |||
return GRAPH_FAILED; | |||
} | |||
GELOGI("Netoutput in anchor index is %zu, input tensor dim is %zu", edge_anchor->GetIdx(), | |||
edge_desc->GetShape().GetDimNum()); | |||
int ref_i; | |||
if (!AttrUtils::GetInt(edge_desc, kRefIndex, ref_i)) { | |||
if (!AttrUtils::GetInt(edge_desc, ATTR_NAME_PARENT_NODE_INDEX, ref_i)) { | |||
// if there is no ref index on the TensorDesc, it means the output data will be ignored outer. | |||
continue; | |||
} | |||
GELOGI("Parent node index of edge desc is %d", ref_i); | |||
auto output_desc = op_desc->MutableOutputDesc(static_cast<uint32_t>(ref_i)); | |||
if (output_desc == nullptr) { | |||
GE_LOGE( | |||
@@ -29,6 +29,7 @@ namespace { | |||
/// Extra 1 byte store '\0' | |||
const int EXTRA_STORE_POINTER_FOR_STRING = 8; | |||
const int EXTRA_STORE_POINTER_FOR_STRING_AND_END_SYMBOL = 9; | |||
const int64_t UNKNOWN_DIM_SIZE = -1; | |||
} // namespace | |||
namespace ge { | |||
@@ -65,6 +66,7 @@ class TensorDescImpl { | |||
TensorDescImpl(const Shape &shape, Format format, DataType dt) : shape_(shape), format_(format), data_type_(dt) {} | |||
Shape shape_; | |||
std::vector<std::pair<int64_t, int64_t>> range_; | |||
Format format_ = FORMAT_ND; | |||
Format origin_format_ = FORMAT_ND; | |||
DataType data_type_ = DT_FLOAT; | |||
@@ -94,7 +96,16 @@ class ShapeImpl { | |||
public: | |||
ShapeImpl() = default; | |||
~ShapeImpl() = default; | |||
explicit ShapeImpl(const std::vector<int64_t> &dims) : dims_(dims) {} | |||
explicit ShapeImpl(const std::vector<int64_t> &dims) { | |||
bool is_unknown_dim_num = false; | |||
for (const auto &dim : dims) { | |||
if (dim == UNKNOWN_DIM_NUM) { | |||
is_unknown_dim_num = true; | |||
break; | |||
} | |||
} | |||
dims_ = is_unknown_dim_num ? std::vector<int64_t>({UNKNOWN_DIM_NUM}) : dims; | |||
} | |||
std::vector<int64_t> dims_; | |||
}; | |||
@@ -105,6 +116,11 @@ Shape::Shape(const std::vector<int64_t> &dims) { impl_ = ComGraphMakeShared<Shap | |||
size_t Shape::GetDimNum() const { | |||
if (impl_ != nullptr) { | |||
for (auto i : impl_->dims_) { | |||
if (i == UNKNOWN_DIM_NUM) { | |||
return 0; | |||
} | |||
} | |||
return impl_->dims_.size(); | |||
} | |||
return 0; | |||
@@ -146,6 +162,10 @@ int64_t Shape::GetShapeSize() const { | |||
} | |||
int64_t size = 1; | |||
for (auto i : impl_->dims_) { | |||
if (i == UNKNOWN_DIM_NUM || i == UNKNOWN_DIM) { | |||
return UNKNOWN_DIM_SIZE; | |||
} | |||
if (!Int64MulNotOverflow(size, i)) { | |||
GELOGE(GRAPH_FAILED, "mul overflow: %ld, %ld", size, i); | |||
size = 0; | |||
@@ -217,6 +237,34 @@ void TensorDesc::SetShape(const Shape &shape) { | |||
} | |||
} | |||
// set shape with -2, it stand for unknown shape | |||
graphStatus TensorDesc::SetUnknownDimNumShape() { | |||
if (impl != nullptr) { | |||
impl->shape_ = Shape({UNKNOWN_DIM_NUM}); | |||
return GRAPH_SUCCESS; | |||
} | |||
GELOGE(GRAPH_FAILED, "Set unknown shape failed,because no impl class!"); | |||
return GRAPH_FAILED; | |||
} | |||
// for unknown shape | |||
graphStatus TensorDesc::SetShapeRange(const std::vector<std::pair<int64_t, int64_t>> &range) { | |||
if (impl != nullptr) { | |||
impl->range_ = range; | |||
return GRAPH_SUCCESS; | |||
} | |||
GELOGE(GRAPH_FAILED, "SetShapeRange failed!impl is nullptr!"); | |||
return GRAPH_FAILED; | |||
} | |||
graphStatus TensorDesc::GetShapeRange(std::vector<std::pair<int64_t, int64_t>> &range) const { | |||
if (impl != nullptr) { | |||
range = impl->range_; | |||
return GRAPH_SUCCESS; | |||
} | |||
GELOGE(GRAPH_FAILED, "impl is nullptr!"); | |||
return GRAPH_FAILED; | |||
} | |||
Shape TensorDesc::GetOriginShape() const { | |||
if (impl != nullptr) { | |||
return impl->origin_shape_; | |||
@@ -541,6 +589,17 @@ GeTensorDesc TensorAdapter::TensorDesc2GeTensorDesc(const TensorDesc &tensor_des | |||
tensor_desc.GetDataType()); | |||
ge_tensor_desc.SetOriginShape(GeShape(tensor_desc.GetOriginShape().GetDims())); | |||
ge_tensor_desc.SetOriginFormat(tensor_desc.GetOriginFormat()); | |||
std::vector<std::pair<int64_t, int64_t>> shape_range; | |||
auto status = tensor_desc.GetShapeRange(shape_range); | |||
if (status != GRAPH_SUCCESS) { | |||
GELOGE(GRAPH_FAILED, "Get shape range failed!"); | |||
return ge_tensor_desc; | |||
} | |||
status = ge_tensor_desc.SetShapeRange(shape_range); | |||
if (status != GRAPH_SUCCESS) { | |||
GELOGE(GRAPH_FAILED, "Set shape range failed!"); | |||
return ge_tensor_desc; | |||
} | |||
auto size = tensor_desc.GetSize(); | |||
TensorUtils::SetSize(ge_tensor_desc, size); | |||
@@ -554,6 +613,17 @@ TensorDesc TensorAdapter::GeTensorDesc2TensorDesc(const GeTensorDesc &ge_tensor_ | |||
ge_tensor_desc.GetDataType()); | |||
tensor_desc.SetOriginShape(Shape(ge_tensor_desc.GetOriginShape().GetDims())); | |||
tensor_desc.SetOriginFormat(ge_tensor_desc.GetOriginFormat()); | |||
std::vector<std::pair<int64_t, int64_t>> shape_range; | |||
auto status = ge_tensor_desc.GetShapeRange(shape_range); | |||
if (status != GRAPH_SUCCESS) { | |||
GELOGE(GRAPH_FAILED, "Get shape range failed!"); | |||
return tensor_desc; | |||
} | |||
status = tensor_desc.SetShapeRange(shape_range); | |||
if (status != GRAPH_SUCCESS) { | |||
GELOGE(GRAPH_FAILED, "Set shape range failed!"); | |||
return tensor_desc; | |||
} | |||
int64_t size = 0; | |||
(void)TensorUtils::GetSize(ge_tensor_desc, size); | |||
tensor_desc.SetSize(size); | |||
@@ -28,6 +28,7 @@ | |||
#include <cstring> | |||
#include <fstream> | |||
#include <iomanip> | |||
#include <queue> | |||
#include "./ge_context.h" | |||
#include "debug/ge_util.h" | |||
@@ -390,8 +391,8 @@ GraphUtils::InsertNodeBefore(const OutDataAnchorPtr &src, const std::vector<InDa | |||
return GRAPH_FAILED; | |||
} | |||
if ((RemoveEdge(src, dst) != GRAPH_SUCCESS) || | |||
(AddEdge(insert_node->GetOutDataAnchor(output_index), dst) != GRAPH_SUCCESS)) { | |||
(void)RemoveEdge(src, dst); | |||
if (AddEdge(insert_node->GetOutDataAnchor(output_index), dst) != GRAPH_SUCCESS) { | |||
GELOGE(GRAPH_FAILED, "ReplaceEdge from %s->%s to %s->%s failed.", src_node->GetName().c_str(), | |||
dst_node->GetName().c_str(), insert_node->GetName().c_str(), dst_node->GetName().c_str()); | |||
return GRAPH_FAILED; | |||
@@ -399,7 +400,7 @@ GraphUtils::InsertNodeBefore(const OutDataAnchorPtr &src, const std::vector<InDa | |||
OutControlAnchorPtr new_out_ctrl_anchor = insert_node->GetOutControlAnchor(); | |||
GE_CHECK_NOTNULL(new_out_ctrl_anchor); | |||
for (InControlAnchorPtr peer_in_ctrl_anchor : src_out_ctrl_anchor->GetPeerInControlAnchors()) { | |||
for (const InControlAnchorPtr &peer_in_ctrl_anchor : src_out_ctrl_anchor->GetPeerInControlAnchors()) { | |||
if ((RemoveEdge(src_out_ctrl_anchor, peer_in_ctrl_anchor) != GRAPH_SUCCESS) || | |||
(AddEdge(new_out_ctrl_anchor, peer_in_ctrl_anchor) != GRAPH_SUCCESS)) { | |||
GELOGE(GRAPH_FAILED, "ReplaceEdge from %s->%s to %s->%s failed.", src_node->GetName().c_str(), | |||
@@ -706,7 +707,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGraphToOnn | |||
GELOGE(GRAPH_FAILED, "File name is too longer!"); | |||
return; | |||
} | |||
std::unique_ptr<char> real_path(new (std::nothrow) char[PATH_MAX]{0}); | |||
std::unique_ptr<char[]> real_path(new (std::nothrow) char[PATH_MAX]{0}); | |||
if (real_path == nullptr) { | |||
GELOGE(GRAPH_FAILED, "New real_path failed."); | |||
return; | |||
@@ -1275,6 +1276,423 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraphPtr GraphUtils::FindR | |||
return result; | |||
} | |||
/// | |||
/// Get reference-mapping of all data_anchors in graph | |||
/// @param [in] graph | |||
/// @param [out] symbol_to_anchors | |||
/// @param [out] anchor_to_symbol | |||
/// @return success: GRAPH_SUCESS | |||
/// | |||
graphStatus GraphUtils::GetRefMapping(const ComputeGraphPtr &graph, | |||
std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors, | |||
std::map<std::string, std::string> &anchor_to_symbol) { | |||
GE_CHECK_NOTNULL(graph); | |||
for (auto &node : graph->GetAllNodes()) { | |||
// in_data_anchor | |||
if (HandleInAnchorMapping(node, symbol_to_anchors, anchor_to_symbol) != GRAPH_SUCCESS) { | |||
GE_LOGE("Find ref_mapping for in_data_anchors of node %s failed.", node->GetName().c_str()); | |||
return GRAPH_FAILED; | |||
} | |||
// out_data_anchor | |||
if (HandleOutAnchorMapping(node, symbol_to_anchors, anchor_to_symbol) != GRAPH_SUCCESS) { | |||
GE_LOGE("Find ref_mapping for out_data_anchors of node %s failed.", node->GetName().c_str()); | |||
return GRAPH_FAILED; | |||
} | |||
} | |||
return GRAPH_SUCCESS; | |||
} | |||
/// | |||
/// Get reference-mapping for in_data_anchors of node | |||
/// @param [in] node | |||
/// @param [out] symbol_to_anchors | |||
/// @param [out] anchor_to_symbol | |||
/// @return success: GRAPH_SUCESS | |||
/// | |||
graphStatus GraphUtils::HandleInAnchorMapping(const NodePtr &node, | |||
std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors, | |||
std::map<std::string, std::string> &anchor_to_symbol) { | |||
GE_CHECK_NOTNULL(node); | |||
if (NodeUtils::IsSubgraphOutput(node)) { | |||
return HandleSubgraphOutput(node, symbol_to_anchors, anchor_to_symbol); | |||
} | |||
if (NodeUtils::IsSubgraphInput(node)) { | |||
return HandleSubgraphInput(node, symbol_to_anchors, anchor_to_symbol); | |||
} | |||
std::string type = node->GetType(); | |||
if ((type == MERGE) || (type == STREAMMERGE)) { | |||
return HandleMergeInput(node, symbol_to_anchors, anchor_to_symbol); | |||
} | |||
for (auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||
NodeIndexIO cur_node_info = NodeIndexIO(node, in_data_anchor->GetIdx(), kIn); | |||
OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||
if (peer_out_anchor == nullptr) { | |||
std::string symbol = cur_node_info.ToString(); | |||
GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str()); | |||
symbol_to_anchors[symbol] = {cur_node_info}; | |||
anchor_to_symbol[symbol] = symbol; | |||
} else { | |||
NodeIndexIO exist_node_info = NodeIndexIO(peer_out_anchor->GetOwnerNode(), peer_out_anchor->GetIdx(), kOut); | |||
if (UpdateRefMapping(cur_node_info, exist_node_info, symbol_to_anchors, anchor_to_symbol) != GRAPH_SUCCESS) { | |||
GE_LOGE("Update symbol mapping failed."); | |||
return GRAPH_FAILED; | |||
} | |||
} | |||
} | |||
return GRAPH_SUCCESS; | |||
} | |||
/// | |||
/// Get reference-mapping for out_data_anchors of node | |||
/// @param [in] node | |||
/// @param [out] symbol_to_anchors | |||
/// @param [out] anchor_to_symbol | |||
/// @return success: GRAPH_SUCESS | |||
/// | |||
graphStatus GraphUtils::HandleOutAnchorMapping(const NodePtr &node, | |||
std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors, | |||
std::map<std::string, std::string> &anchor_to_symbol) { | |||
GE_CHECK_NOTNULL(node); | |||
for (auto &out_data_anchor : node->GetAllOutDataAnchors()) { | |||
NodeIndexIO cur_node_info = NodeIndexIO(node, out_data_anchor->GetIdx(), kOut); | |||
if (anchor_to_symbol.find(cur_node_info.ToString()) != anchor_to_symbol.end()) { | |||
continue; | |||
} | |||
int32_t reuse_in_index = -1; | |||
if (IsRefFromInput(out_data_anchor, reuse_in_index)) { | |||
NodeIndexIO exist_node_info = NodeIndexIO(node, reuse_in_index, kIn); | |||
if (UpdateRefMapping(cur_node_info, exist_node_info, symbol_to_anchors, anchor_to_symbol) != GRAPH_SUCCESS) { | |||
GE_LOGE("Update symbol mapping failed."); | |||
return GRAPH_FAILED; | |||
} | |||
} else { | |||
std::string symbol = cur_node_info.ToString(); | |||
GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str()); | |||
symbol_to_anchors.emplace(std::make_pair(symbol, std::vector<NodeIndexIO>{cur_node_info})); | |||
anchor_to_symbol.emplace(std::make_pair(symbol, symbol)); | |||
} | |||
} | |||
return GRAPH_SUCCESS; | |||
} | |||
/// | |||
/// Handle input of subgraph | |||
/// @param [in] node | |||
/// @param [out] symbol_to_anchors | |||
/// @param [out] anchor_to_symbol | |||
/// @return success: GRAPH_SUCESS | |||
/// | |||
graphStatus GraphUtils::HandleSubgraphInput(const NodePtr &node, | |||
std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors, | |||
std::map<std::string, std::string> &anchor_to_symbol) { | |||
GE_CHECK_NOTNULL(node); | |||
GE_CHECK_NOTNULL(node->GetOpDesc()); | |||
// Data in subgraph | |||
uint32_t index = 0; | |||
if (!ge::AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, index)) { | |||
GE_LOGE("Get attr ATTR_NAME_PARENT_NODE_INDEX failed, node:%s.", node->GetName().c_str()); | |||
return GRAPH_FAILED; | |||
} | |||
NodePtr parent_node = node->GetOwnerComputeGraph()->GetParentNode(); | |||
GE_CHECK_NOTNULL(parent_node); | |||
InDataAnchorPtr parent_in_anchor = parent_node->GetInDataAnchor(index); | |||
GE_CHECK_NOTNULL(parent_in_anchor); | |||
OutDataAnchorPtr peer_out_anchor = parent_in_anchor->GetPeerOutAnchor(); | |||
if (peer_out_anchor != nullptr) { | |||
// Data has and only has one input | |||
NodeIndexIO cur_node_info = NodeIndexIO(node, 0, kIn); | |||
NodeIndexIO exist_node_info = NodeIndexIO(peer_out_anchor->GetOwnerNode(), peer_out_anchor->GetIdx(), kOut); | |||
if (UpdateRefMapping(cur_node_info, exist_node_info, symbol_to_anchors, anchor_to_symbol) != GRAPH_SUCCESS) { | |||
GE_LOGE("Update symbol mapping failed."); | |||
return GRAPH_FAILED; | |||
} | |||
} | |||
return GRAPH_SUCCESS; | |||
} | |||
/// | |||
/// Handle input of Merge op | |||
/// @param [in] node | |||
/// @param [out] symbol_to_anchors | |||
/// @param [out] anchor_to_symbol | |||
/// @return success: GRAPH_SUCESS | |||
/// | |||
graphStatus GraphUtils::HandleMergeInput(const NodePtr &node, | |||
std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors, | |||
std::map<std::string, std::string> &anchor_to_symbol) { | |||
GE_CHECK_NOTNULL(node); | |||
std::vector<NodeIndexIO> exist_node_infos; | |||
std::vector<NodeIndexIO> cur_node_infos; | |||
for (auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||
auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||
if (peer_out_anchor == nullptr) { | |||
std::string next_name; | |||
if (AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_NEXT_ITERATION, next_name) && !next_name.empty()) { | |||
ComputeGraphPtr graph = node->GetOwnerComputeGraph(); | |||
GE_CHECK_NOTNULL(graph); | |||
ge::NodePtr next_node = graph->FindNode(next_name); | |||
GE_CHECK_NOTNULL(next_node); | |||
// NextIteration has and only has one output | |||
peer_out_anchor = next_node->GetOutDataAnchor(0); | |||
GE_CHECK_NOTNULL(peer_out_anchor); | |||
cur_node_infos.emplace_back(NodeIndexIO(node, in_data_anchor->GetIdx(), kIn)); | |||
cur_node_infos.emplace_back(NodeIndexIO(next_node, peer_out_anchor->GetIdx(), kOut)); | |||
} | |||
} else { | |||
cur_node_infos.emplace_back(NodeIndexIO(node, in_data_anchor->GetIdx(), kIn)); | |||
exist_node_infos.emplace_back(NodeIndexIO(peer_out_anchor->GetOwnerNode(), peer_out_anchor->GetIdx(), kOut)); | |||
} | |||
} | |||
size_t anchor_nums = 0; | |||
NodeIndexIO max_node_index_io(nullptr, 0, kOut); | |||
for (auto &temp_node_info : exist_node_infos) { | |||
auto iter1 = anchor_to_symbol.find(temp_node_info.ToString()); | |||
if (iter1 != anchor_to_symbol.end()) { | |||
std::string temp_symbol = iter1->second; | |||
auto iter2 = symbol_to_anchors.find(temp_symbol); | |||
if (iter2 != symbol_to_anchors.end()) { | |||
if (iter2->second.size() > anchor_nums) { | |||
max_node_index_io = temp_node_info; | |||
anchor_nums = iter2->second.size(); | |||
} | |||
} | |||
} | |||
} | |||
std::string symbol; | |||
for (auto &temp_node_info : exist_node_infos) { | |||
if ((UnionSymbolMapping(max_node_index_io, temp_node_info, symbol_to_anchors, anchor_to_symbol, symbol) != | |||
GRAPH_SUCCESS) || | |||
symbol.empty()) { | |||
GE_LOGE("Union symbol map anchor1:%s & anchor2:%s.", max_node_index_io.ToString().c_str(), | |||
temp_node_info.ToString().c_str()); | |||
return GRAPH_FAILED; | |||
} | |||
} | |||
auto iter = symbol_to_anchors.find(symbol); | |||
if (iter != symbol_to_anchors.end()) { | |||
for (auto &temp_node_info : cur_node_infos) { | |||
GELOGD("Add anchor %s, symbol %s.", temp_node_info.ToString().c_str(), symbol.c_str()); | |||
iter->second.emplace_back(temp_node_info); | |||
anchor_to_symbol.emplace(std::make_pair(temp_node_info.ToString(), symbol)); | |||
} | |||
} | |||
return GRAPH_SUCCESS; | |||
} | |||
/// | |||
/// Handle output of subgraph | |||
/// @param [in] node | |||
/// @param [out] symbol_to_anchors | |||
/// @param [out] anchor_to_symbol | |||
/// @return success: GRAPH_SUCESS | |||
/// | |||
graphStatus GraphUtils::HandleSubgraphOutput(const NodePtr &node, | |||
std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors, | |||
std::map<std::string, std::string> &anchor_to_symbol) { | |||
GE_CHECK_NOTNULL(node); | |||
ComputeGraphPtr owner_graph = node->GetOwnerComputeGraph(); | |||
GE_CHECK_NOTNULL(owner_graph); | |||
NodePtr parent_node = owner_graph->GetParentNode(); | |||
GE_CHECK_NOTNULL(parent_node); | |||
OpDescPtr op_desc = node->GetOpDesc(); | |||
GE_CHECK_NOTNULL(op_desc); | |||
for (auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||
OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||
GE_CHECK_NOTNULL(peer_out_anchor); | |||
GeTensorDesc in_tensor = op_desc->GetInputDesc(in_data_anchor->GetIdx()); | |||
uint32_t index = 0; | |||
if (!ge::AttrUtils::GetInt(in_tensor, ATTR_NAME_PARENT_NODE_INDEX, index)) { | |||
continue; | |||
} | |||
GE_CHECK_NOTNULL(parent_node->GetOutDataAnchor(index)); | |||
// Union symbol of peer_out_anchor & parent_out_anchor | |||
NodeIndexIO peer_node_info = NodeIndexIO(peer_out_anchor->GetOwnerNode(), peer_out_anchor->GetIdx(), kOut); | |||
NodeIndexIO parent_node_info = NodeIndexIO(parent_node, index, kOut); | |||
std::string symbol; | |||
if ((UnionSymbolMapping(peer_node_info, parent_node_info, symbol_to_anchors, anchor_to_symbol, symbol) != | |||
GRAPH_SUCCESS) || | |||
symbol.empty()) { | |||
GE_LOGE("Union symbol map anchor1:%s, anchor2:%s.", peer_node_info.ToString().c_str(), | |||
parent_node_info.ToString().c_str()); | |||
return GRAPH_FAILED; | |||
} | |||
NodeIndexIO cur_node_info = NodeIndexIO(node, in_data_anchor->GetIdx(), kIn); | |||
GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str()); | |||
symbol_to_anchors[symbol].emplace_back(cur_node_info); | |||
anchor_to_symbol.emplace(std::make_pair(cur_node_info.ToString(), symbol)); | |||
} | |||
return GRAPH_SUCCESS; | |||
} | |||
/// | |||
/// Union ref-mapping | |||
/// @param [in] exist_node_info1 | |||
/// @param [in] exist_node_info2 | |||
/// @param [out] symbol_to_anchors | |||
/// @param [out] anchor_to_symbol | |||
/// @param [out] symbol | |||
/// @return success: GRAPH_SUCESS | |||
/// | |||
graphStatus GraphUtils::UnionSymbolMapping(const NodeIndexIO &exist_node_info1, const NodeIndexIO &exist_node_info2, | |||
std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors, | |||
std::map<std::string, std::string> &anchor_to_symbol, std::string &symbol) { | |||
std::string symbol1 = anchor_to_symbol[exist_node_info1.ToString()]; | |||
std::string symbol2 = anchor_to_symbol[exist_node_info2.ToString()]; | |||
if (symbol1 == symbol2) { | |||
symbol = symbol1; | |||
GELOGI("no need to union."); | |||
return GRAPH_SUCCESS; | |||
} | |||
auto iter1 = symbol_to_anchors.find(symbol1); | |||
auto iter2 = symbol_to_anchors.find(symbol2); | |||
if ((iter1 == symbol_to_anchors.end()) || (iter2 == symbol_to_anchors.end())) { | |||
GE_LOGE("symbol %s or %s not exist.", symbol1.c_str(), symbol2.c_str()); | |||
return GRAPH_FAILED; | |||
} | |||
auto &max_iter = (iter1->second.size() > iter2->second.size() ? iter1 : iter2); | |||
auto &min_iter = (iter1->second.size() > iter2->second.size() ? iter2 : iter1); | |||
symbol = (iter1->second.size() > iter2->second.size() ? symbol1 : symbol2); | |||
std::string min_symbol = (iter1->second.size() > iter2->second.size() ? symbol2 : symbol1); | |||
for (auto &node_index_io : min_iter->second) { | |||
GELOGD("Update anchor %s, symbol %s.", node_index_io.ToString().c_str(), symbol.c_str()); | |||
max_iter->second.emplace_back(node_index_io); | |||
auto iter = anchor_to_symbol.find(node_index_io.ToString()); | |||
if (iter == anchor_to_symbol.end()) { | |||
GE_LOGE("anchor %s not exist.", node_index_io.ToString().c_str()); | |||
return GRAPH_FAILED; | |||
} | |||
if (iter->second != min_symbol) { | |||
GELOGW("not expected symbol of anchor %s, expect %s but %s exactly.", iter->first.c_str(), min_symbol.c_str(), | |||
iter->second.c_str()); | |||
} | |||
iter->second = symbol; | |||
} | |||
GELOGI("Union symbol %s and %s succ.", symbol.c_str(), min_symbol.c_str()); | |||
symbol_to_anchors.erase(min_iter); | |||
return GRAPH_SUCCESS; | |||
} | |||
/// | |||
/// Update symbol mapping with a new reference pair | |||
/// @param [in] cur_node_info | |||
/// @param [in] exist_node_info | |||
/// @param [out] symbol_to_anchors | |||
/// @param [out] anchor_to_symbol | |||
/// @return success: GRAPH_SUCESS | |||
/// | |||
graphStatus GraphUtils::UpdateRefMapping(const NodeIndexIO &cur_node_info, const NodeIndexIO &exist_node_info, | |||
std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors, | |||
std::map<std::string, std::string> &anchor_to_symbol) { | |||
auto iter1 = anchor_to_symbol.find(exist_node_info.ToString()); | |||
if (iter1 == anchor_to_symbol.end()) { | |||
GE_LOGE("data_anchor %s is not visible before data_anchor %s, maybe TopoSorting is missing.", | |||
exist_node_info.ToString().c_str(), cur_node_info.ToString().c_str()); | |||
return GRAPH_FAILED; | |||
} | |||
std::string symbol = iter1->second; | |||
auto iter2 = symbol_to_anchors.find(symbol); | |||
if (iter2 == symbol_to_anchors.end()) { | |||
GE_LOGE("symbol %s not found.", symbol.c_str()); | |||
return GRAPH_FAILED; | |||
} | |||
GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str()); | |||
iter2->second.emplace_back(cur_node_info); | |||
anchor_to_symbol.emplace(std::make_pair(cur_node_info.ToString(), symbol)); | |||
return GRAPH_SUCCESS; | |||
} | |||
/// | |||
/// Check if out_data_anchor is reference of input | |||
/// @param [in] out_data_anchor | |||
/// @param [out] reuse_in_index | |||
/// @return bool | |||
/// | |||
bool GraphUtils::IsRefFromInput(const OutDataAnchorPtr &out_data_anchor, int32_t &reuse_in_index) { | |||
if (out_data_anchor == nullptr) { | |||
GELOGW("out_data_anchor is NULL."); | |||
return false; | |||
} | |||
int32_t output_index = out_data_anchor->GetIdx(); | |||
// pass-through op | |||
NodePtr node = out_data_anchor->GetOwnerNode(); | |||
std::string type = node->GetType(); | |||
const std::set<std::string> pass_through_set = {NETOUTPUT, WHILE, _WHILE, STATELESSWHILE}; | |||
if ((pass_through_set.count(type) > 0) || (NodeUtils::IsSubgraphInput(node))) { | |||
reuse_in_index = output_index; | |||
GELOGI("Pass-Through node name[%s] index[%u].", node->GetName().c_str(), reuse_in_index); | |||
return true; | |||
} | |||
// Merge op 0th output | |||
if ((type == MERGE) && (output_index == 0)) { | |||
reuse_in_index = 0; | |||
GELOGI("Merge name[%s] output_index[0].", node->GetName().c_str()); | |||
return true; | |||
} | |||
// ref op | |||
OpDescPtr op_desc = node->GetOpDesc(); | |||
if (op_desc == nullptr) { | |||
GELOGW("op_desc is NULL."); | |||
return false; | |||
} | |||
bool is_ref = false; | |||
(void)ge::AttrUtils::GetBool(op_desc, ATTR_NAME_REFERENCE, is_ref); | |||
if (is_ref) { | |||
const string &output_name = op_desc->GetOutputNameByIndex(output_index); | |||
for (const auto &input_name : op_desc->GetAllInputNames()) { | |||
if (!input_name.empty() && (output_name == input_name)) { | |||
reuse_in_index = op_desc->GetInputIndexByName(input_name); | |||
GELOGI("Reference name[%s] output[%s][%u] ref to input[%s][%d].", op_desc->GetName().c_str(), | |||
output_name.c_str(), output_index, input_name.c_str(), reuse_in_index); | |||
return true; | |||
} | |||
} | |||
} | |||
// reuse input | |||
auto output_op_desc = op_desc->GetOutputDescPtr(output_index); | |||
bool reuse_input = false; | |||
if (output_op_desc != nullptr) { | |||
if ((TensorUtils::GetReuseInput(*output_op_desc, reuse_input) == GRAPH_SUCCESS) && reuse_input) { | |||
uint32_t reuse_input_index = 0; | |||
if (TensorUtils::GetReuseInputIndex(*output_op_desc, reuse_input_index) == GRAPH_SUCCESS) { | |||
reuse_in_index = static_cast<int32_t>(reuse_input_index); | |||
GELOGI("ReuseInput name[%s] output[%u] reuse input[%d].", op_desc->GetName().c_str(), output_index, | |||
reuse_in_index); | |||
return true; | |||
} | |||
} | |||
} | |||
return false; | |||
} | |||
/// | |||
/// @brief Add node to graph | |||
/// @param [in] op_desc | |||
@@ -1561,13 +1979,14 @@ CompleteGraphBuilder &CompleteGraphBuilder::SetOutputMapping(const std::map<uint | |||
/// | |||
ComputeGraphPtr CompleteGraphBuilder::Build(graphStatus &error_code, std::string &error_msg) { | |||
owner_graph_ = shared_ptr<ComputeGraph>(new (std::nothrow) ComputeGraph(name_)); | |||
if (owner_graph_ == nullptr) { | |||
if ((owner_graph_ == nullptr) || (parent_node_ == nullptr)) { | |||
error_code = GRAPH_FAILED; | |||
error_msg = "graph is NULL."; | |||
error_msg = "graph / parent_node is NULL."; | |||
return nullptr; | |||
} | |||
owner_graph_->SetParentNode(parent_node_); | |||
owner_graph_->SetParentGraph(parent_node_->GetOwnerComputeGraph()); | |||
BuildNodes(error_code, error_msg); | |||
if (error_code != GRAPH_SUCCESS) { | |||
@@ -1584,41 +2003,58 @@ ComputeGraphPtr CompleteGraphBuilder::Build(graphStatus &error_code, std::string | |||
return nullptr; | |||
} | |||
BuildInputs(error_code, error_msg); | |||
AddDataNodes(error_code, error_msg); | |||
if (error_code != GRAPH_SUCCESS) { | |||
return nullptr; | |||
} | |||
BuildOutputs(error_code, error_msg); | |||
AddRetValNodes(error_code, error_msg); | |||
if (error_code != GRAPH_SUCCESS) { | |||
return nullptr; | |||
} | |||
if (AddNetOutputNode(error_code, error_msg) == nullptr) { | |||
// ATTR_NAME_SESSION_GRAPH_ID | |||
std::string graph_id; | |||
if (!AttrUtils::GetStr(parent_node_->GetOwnerComputeGraph(), ATTR_NAME_SESSION_GRAPH_ID, graph_id)) { | |||
error_code = GRAPH_FAILED; | |||
error_msg = "Get attr session_graph_id failed."; | |||
return nullptr; | |||
} | |||
if (!AttrUtils::SetStr(owner_graph_, ATTR_NAME_SESSION_GRAPH_ID, graph_id)) { | |||
error_code = GRAPH_FAILED; | |||
error_msg = "Set attr session_graph_id failed."; | |||
return nullptr; | |||
} | |||
// refresh node name | |||
for (const NodePtr &node : owner_graph_->GetDirectNode()) { | |||
if ((node->GetOpDesc() == nullptr) || (node->GetType() == VARIABLE) || (node->GetType() == VARIABLEV2)) { | |||
continue; | |||
} | |||
node->GetOpDesc()->SetName(owner_graph_->GetName() + "/" + node->GetName()); | |||
} | |||
return owner_graph_; | |||
} | |||
/// | |||
/// @brief Build inputs | |||
/// @brief Add data nodes | |||
/// @param [out] error_code | |||
/// @param [out] error_msg | |||
/// @return void | |||
/// | |||
void CompleteGraphBuilder::BuildInputs(graphStatus &error_code, std::string &error_msg) { | |||
void CompleteGraphBuilder::AddDataNodes(graphStatus &error_code, std::string &error_msg) { | |||
for (auto &input : graph_inputs_) { | |||
NodePtr data_node = AddDateNode(input.first, error_code, error_msg); | |||
NodePtr data_node = AddDataNode(input.first, error_code, error_msg); | |||
if (data_node == nullptr) { | |||
error_code = GRAPH_FAILED; | |||
error_msg = "BuildInputs failed: add node Data:" + std::to_string(input.first) + +" failed."; | |||
error_msg = "AddDataNodes failed: add node Data:" + std::to_string(input.first) + +" failed."; | |||
return; | |||
} | |||
if (owner_graph_->AddInputNode(data_node) == nullptr) { | |||
error_code = GRAPH_FAILED; | |||
error_msg = "BuildInputs failed: add input node Data:" + std::to_string(input.first) + +" failed."; | |||
error_msg = "AddDataNodes failed: add input node Data:" + std::to_string(input.first) + +" failed."; | |||
return; | |||
} | |||
@@ -1627,7 +2063,7 @@ void CompleteGraphBuilder::BuildInputs(graphStatus &error_code, std::string &err | |||
std::vector<uint32_t> anchor_indes = input.second.second; | |||
if (input_names.size() != anchor_indes.size()) { | |||
error_code = GRAPH_FAILED; | |||
error_msg = "BuildInputs failed: num of input_names and indexs not equal."; | |||
error_msg = "AddDataNodes failed: num of input_names and indexs not equal."; | |||
return; | |||
} | |||
if (input_names.empty()) { | |||
@@ -1641,29 +2077,29 @@ void CompleteGraphBuilder::BuildInputs(graphStatus &error_code, std::string &err | |||
auto iter = node_names_.find(input_name); | |||
if (iter == node_names_.end()) { | |||
error_code = GRAPH_FAILED; | |||
error_msg = "BuildInputs failed: node " + input_name + " not exist in graph."; | |||
error_msg = "AddDataNodes failed: node " + input_name + " not exist in graph."; | |||
return; | |||
} | |||
NodePtr in_node = node_names_[input_name]; | |||
if (in_node == nullptr) { | |||
error_code = GRAPH_FAILED; | |||
error_msg = "BuildInputs failed: node " + input_name + " is NULL."; | |||
error_msg = "AddDataNodes failed: node " + input_name + " is NULL."; | |||
return; | |||
} | |||
if (GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), in_node->GetInDataAnchor(ind)) != GRAPH_SUCCESS) { | |||
error_code = GRAPH_FAILED; | |||
error_msg = "BuildInputs failed: add data-edge Data:" + std::to_string(input.first) + ":0->" + input_name + | |||
error_msg = "AddDataNodes failed: add data-edge Data:" + std::to_string(input.first) + ":0->" + input_name + | |||
":" + std::to_string(ind) + " failed."; | |||
return; | |||
} | |||
} | |||
GELOGD("BuildInputs : Add %u input succ.", input.first); | |||
GELOGD("AddDataNodes : Add %u input succ.", input.first); | |||
} | |||
GELOGD("BuildInputs succ."); | |||
GELOGD("AddDataNodes succ."); | |||
} | |||
/// | |||
@@ -1673,13 +2109,13 @@ void CompleteGraphBuilder::BuildInputs(graphStatus &error_code, std::string &err | |||
/// @param [out] error_msg | |||
/// @return void | |||
/// | |||
NodePtr CompleteGraphBuilder::AddDateNode(uint32_t index, graphStatus &error_code, std::string &error_msg) { | |||
NodePtr CompleteGraphBuilder::AddDataNode(uint32_t index, graphStatus &error_code, std::string &error_msg) { | |||
std::string data_name = "Data_" + std::to_string(index); | |||
OpDescBuilder op_desc_builder(data_name, "Data"); | |||
OpDescPtr op_desc = op_desc_builder.AddInput("x").AddOutput("y").Build(); | |||
if (op_desc == nullptr) { | |||
error_code = GRAPH_FAILED; | |||
error_msg = "BuildInputs failed: create op_desc " + data_name + " failed."; | |||
error_msg = "AddDataNode failed: create op_desc " + data_name + " failed."; | |||
return nullptr; | |||
} | |||
@@ -1687,7 +2123,7 @@ NodePtr CompleteGraphBuilder::AddDateNode(uint32_t index, graphStatus &error_cod | |||
if (index_iter != input_mapping_.end()) { | |||
if (!ge::AttrUtils::SetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, index_iter->second)) { | |||
error_code = GRAPH_FAILED; | |||
error_msg = "BuildInputs failed: set attr ATTR_NAME_PARENT_NODE_INDEX for " + data_name + " failed."; | |||
error_msg = "AddDataNode failed: set attr ATTR_NAME_PARENT_NODE_INDEX for " + data_name + " failed."; | |||
return nullptr; | |||
} | |||
} | |||
@@ -1695,189 +2131,83 @@ NodePtr CompleteGraphBuilder::AddDateNode(uint32_t index, graphStatus &error_cod | |||
NodePtr data_node = owner_graph_->AddNode(op_desc); | |||
if (data_node == nullptr) { | |||
error_code = GRAPH_FAILED; | |||
error_msg = "BuildInputs failed: add node " + data_name + " failed."; | |||
error_msg = "AddDataNode failed: add node " + data_name + " failed."; | |||
return nullptr; | |||
} | |||
node_names_[data_name] = data_node; | |||
return data_node; | |||
} | |||
/// | |||
/// @brief Build outputs | |||
/// @brief Add RetVal nodes | |||
/// @param [out] error_code | |||
/// @param [out] error_msg | |||
/// @return void | |||
/// | |||
void CompleteGraphBuilder::BuildOutputs(graphStatus &error_code, std::string &error_msg) { | |||
std::map<std::string, std::vector<int32_t>> out_nodes_map; | |||
std::vector<std::pair<NodePtr, int32_t>> out_nodes_info; | |||
for (auto &pair : graph_outputs_) { | |||
std::string output = pair.first; | |||
int32_t ind = pair.second; | |||
auto out_iter = node_names_.find(output); | |||
void CompleteGraphBuilder::AddRetValNodes(graphStatus &error_code, std::string &error_msg) { | |||
size_t output_num = graph_outputs_.size(); | |||
for (size_t i = 0; i < output_num; i++) { | |||
int32_t index = graph_outputs_[i].second; | |||
auto out_iter = node_names_.find(graph_outputs_[i].first); | |||
if (out_iter == node_names_.end()) { | |||
error_code = GRAPH_FAILED; | |||
error_msg = "BuildOutputs failed: node " + output + " not exist in graph."; | |||
error_msg = "AddRetValNode failed: node " + graph_outputs_[i].first + " not exist in graph."; | |||
return; | |||
} | |||
NodePtr out_node = node_names_[output]; | |||
if (out_node == nullptr) { | |||
NodePtr node = out_iter->second; | |||
if ((node == nullptr) || (node->GetOpDesc() == nullptr)) { | |||
error_code = GRAPH_FAILED; | |||
error_msg = "BuildOutputs failed: node " + output + " is NULL."; | |||
error_msg = "AddRetValNode failed: node is NULL."; | |||
return; | |||
} | |||
OutDataAnchorPtr out_anchor = out_node->GetOutDataAnchor(ind); | |||
if (out_anchor == nullptr) { | |||
std::string name = node->GetName() + "_RetVal"; | |||
OpDescPtr ret_val_desc = shared_ptr<OpDesc>(new (std::nothrow) OpDesc(name, FRAMEWORKOP)); | |||
if (ret_val_desc == nullptr) { | |||
error_code = GRAPH_FAILED; | |||
error_msg = "BuildOutputs failed: anchor " + output + ":" + std::to_string(ind) + " is NULL."; | |||
error_msg = "AddRetValNode " + name + " failed: op_desc is NULL."; | |||
return; | |||
} | |||
auto iter = out_nodes_map.find(output); | |||
if (iter == out_nodes_map.end()) { | |||
std::vector<int32_t> vec = {ind}; | |||
out_nodes_map[output] = vec; | |||
} else { | |||
out_nodes_map[output].emplace_back(ind); | |||
ge::GeTensorDesc tensor = node->GetOpDesc()->GetOutputDesc(index); | |||
if ((ret_val_desc->AddInputDesc(tensor) != GRAPH_SUCCESS) || | |||
(ret_val_desc->AddOutputDesc(tensor) != GRAPH_SUCCESS)) { | |||
error_code = GRAPH_FAILED; | |||
error_msg = "AddRetValNode " + name + " failed: add input_desc / output_desc failed."; | |||
return; | |||
} | |||
out_nodes_info.emplace_back(std::make_pair(out_node, ind)); | |||
GELOGD("BuildOutputs : AddOutputAnchor %s:%u succ.", output.c_str(), ind); | |||
} | |||
owner_graph_->SetGraphOutNodes(out_nodes_map); | |||
owner_graph_->SetGraphOutNodesInfo(out_nodes_info); | |||
GELOGD("BuildOutputs succ."); | |||
} | |||
/// | |||
/// @brief Add NetOutput node | |||
/// @param [out] error_code | |||
/// @param [out] error_msg | |||
/// @return NodePtr | |||
/// | |||
NodePtr CompleteGraphBuilder::AddNetOutputNode(graphStatus &error_code, std::string &error_msg) { | |||
std::string log_msg = "AddNetOutputNode name:" + std::string(kNodeNameNetOutput) + ", type:" + NETOUTPUT; | |||
OpDescPtr net_output_desc = shared_ptr<OpDesc>(new (std::nothrow) OpDesc(kNodeNameNetOutput, NETOUTPUT)); | |||
if (net_output_desc == nullptr) { | |||
error_code = GRAPH_FAILED; | |||
error_msg = log_msg + " failed: op_desc is NULL."; | |||
return nullptr; | |||
} | |||
std::vector<std::pair<NodePtr, int32_t>> out_nodes_info = owner_graph_->GetGraphOutNodesInfo(); | |||
error_code = BuildInOutForNetOutput(out_nodes_info, net_output_desc); | |||
if (error_code != GRAPH_SUCCESS) { | |||
error_msg = log_msg + " failed: add input/output tensor failed."; | |||
return nullptr; | |||
} | |||
NodePtr net_output_node = owner_graph_->AddNode(net_output_desc); | |||
if (net_output_node == nullptr) { | |||
error_code = GRAPH_FAILED; | |||
error_msg = log_msg + " failed: add node failed."; | |||
return nullptr; | |||
} | |||
error_code = AddEdgeForNetOutput(out_nodes_info, net_output_node); | |||
if (error_code != GRAPH_SUCCESS) { | |||
error_msg = log_msg + " failed: link edge failed."; | |||
return nullptr; | |||
} | |||
GELOGD("%s succ.", log_msg.c_str()); | |||
return net_output_node; | |||
} | |||
/// | |||
/// @brief Add input/output tensor for NetOutput node | |||
/// @param [in] out_nodes_info | |||
/// @param [out] net_output_desc | |||
/// @return graphStatus | |||
/// | |||
graphStatus CompleteGraphBuilder::BuildInOutForNetOutput(const std::vector<std::pair<NodePtr, int32_t>> &out_nodes_info, | |||
OpDescPtr &net_output_desc) { | |||
size_t output_num = out_nodes_info.size(); | |||
for (size_t i = 0; i < output_num; i++) { | |||
NodePtr src_node = out_nodes_info[i].first; | |||
uint32_t src_index = out_nodes_info[i].second; | |||
if ((src_node == nullptr) || (src_node->GetOpDesc() == nullptr)) { | |||
GE_LOGE("AddInOutForNetOutputOp failed: src_node is NULL."); | |||
return GRAPH_FAILED; | |||
if (!(ge::AttrUtils::SetStr(ret_val_desc, ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, "_RetVal") && | |||
ge::AttrUtils::SetInt(ret_val_desc, RETVAL_ATTR_NAME_INDEX, i))) { | |||
error_code = GRAPH_FAILED; | |||
error_msg = "AddRetValNode " + name + " failed: set FRAMEWORK_ORIGINAL_TYPE / RETVAL_ATTR_NAME_INDEX failed."; | |||
return; | |||
} | |||
ge::GeTensorDesc in_desc = src_node->GetOpDesc()->GetOutputDesc(src_index); | |||
auto iter = output_mapping_.find(i); | |||
if (iter != output_mapping_.end()) { | |||
if (!ge::AttrUtils::SetInt(in_desc, ATTR_NAME_PARENT_NODE_INDEX, iter->second)) { | |||
GE_LOGE("AddInOutForNetOutputOp failed: set attr ATTR_NAME_PARENT_NODE_INDEX failed."); | |||
return GRAPH_FAILED; | |||
if (!ge::AttrUtils::SetInt(ret_val_desc, ATTR_NAME_PARENT_NODE_INDEX, iter->second)) { | |||
error_code = GRAPH_FAILED; | |||
error_msg = "AddRetValNode " + name + " failed: set attr PARENT_NODE_INDEX failed."; | |||
return; | |||
} | |||
} | |||
if (net_output_desc->AddInputDesc(in_desc) != SUCCESS) { | |||
GE_LOGE("AddInOutForNetOutputOp failed: add input_desc failed."); | |||
return GRAPH_FAILED; | |||
} | |||
ge::GeTensorDesc out_desc = src_node->GetOpDesc()->GetOutputDesc(src_index); | |||
TensorUtils::SetOutputTensor(out_desc, true); | |||
if (net_output_desc->AddOutputDesc(out_desc) != SUCCESS) { | |||
GE_LOGE("AddInOutForNetOutputOp failed: add output_desc failed."); | |||
return GRAPH_FAILED; | |||
} | |||
} | |||
GELOGD("Add input/output tensor for NetOutput node succ."); | |||
return GRAPH_SUCCESS; | |||
} | |||
/// | |||
/// @brief Add edge for NetOutput node | |||
/// @param [in] out_nodes_info | |||
/// @param [out] net_output_node | |||
/// @return graphStatus | |||
/// | |||
graphStatus CompleteGraphBuilder::AddEdgeForNetOutput(const std::vector<std::pair<NodePtr, int32_t>> &out_nodes_info, | |||
const NodePtr &net_output_node) { | |||
if (net_output_node == nullptr) { | |||
GE_LOGE("AddEdgeForNetOutputOp failed: NetOutput is NULL."); | |||
return GRAPH_FAILED; | |||
} | |||
size_t out_num = out_nodes_info.size(); | |||
for (size_t i = 0; i < out_num; i++) { | |||
NodePtr src_node = out_nodes_info[i].first; | |||
uint32_t ind = out_nodes_info[i].second; | |||
if (src_node == nullptr) { | |||
GE_LOGE("AddEdgeForNetOutputOp failed: src_node is NULL."); | |||
return GRAPH_FAILED; | |||
} | |||
if (GraphUtils::AddEdge(src_node->GetOutDataAnchor(ind), net_output_node->GetInDataAnchor(i)) != GRAPH_SUCCESS) { | |||
GE_LOGE("Add data-edge %s:%u->%s:%zu failed.", src_node->GetName().c_str(), ind, | |||
net_output_node->GetName().c_str(), i); | |||
return GRAPH_FAILED; | |||
NodePtr ret_val_node = owner_graph_->AddNode(ret_val_desc); | |||
if (ret_val_node == nullptr) { | |||
error_code = GRAPH_FAILED; | |||
error_msg = "AddRetValNode " + name + " failed: add node failed."; | |||
return; | |||
} | |||
} | |||
std::vector<NodePtr> leaf_nodes; | |||
for (auto &node : owner_graph_->GetDirectNode()) { | |||
if (node->GetOutNodes().empty()) { | |||
leaf_nodes.emplace_back(node); | |||
} | |||
} | |||
for (auto &node : leaf_nodes) { | |||
if (GraphUtils::AddEdge(node->GetOutControlAnchor(), net_output_node->GetInControlAnchor()) != GRAPH_SUCCESS) { | |||
GE_LOGE("Add ctrl-edge %s->%s failed.", node->GetName().c_str(), net_output_node->GetName().c_str()); | |||
return GRAPH_FAILED; | |||
if (GraphUtils::AddEdge(node->GetOutDataAnchor(index), ret_val_node->GetInDataAnchor(0)) != GRAPH_SUCCESS) { | |||
error_code = GRAPH_FAILED; | |||
error_msg = "AddRetValNode " + name + " failed: add data-edge " + node->GetName() + ":" + std::to_string(index) + | |||
"->" + ret_val_node->GetName() + ":0 failed."; | |||
return; | |||
} | |||
} | |||
GELOGD("Add edge for NetOutput node succ."); | |||
return GRAPH_SUCCESS; | |||
GELOGD("AddRetValNodes succ."); | |||
} | |||
/// | |||
@@ -1999,4 +2329,60 @@ void PartialGraphBuilder::BuildExistNodes(graphStatus &error_code, std::string & | |||
GELOGD("Build exist nodes succ."); | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | |||
GraphUtils::TopologicalSortingByName(const ge::ComputeGraphPtr &compute_graph, vector<NodePtr> &node_vec) { | |||
std::vector<NodePtr> stack_input; | |||
std::map<NodePtr, uint32_t> map_in_edge_num; | |||
graphStatus ret = compute_graph->SortNodes(stack_input, map_in_edge_num); | |||
if (ret != GRAPH_SUCCESS) { | |||
GELOGE(GRAPH_FAILED, "Sort nodes failed."); | |||
return GRAPH_FAILED; | |||
} | |||
const size_t non_user_input_index = stack_input.size() - compute_graph->inputs_order_.size() - 1; | |||
std::sort(stack_input.begin(), stack_input.begin() + non_user_input_index, | |||
[](const NodePtr &a, const NodePtr &b) -> bool { return (a->GetName() > b->GetName()); }); | |||
std::queue<NodePtr> stack; | |||
NodePtr cur_node = nullptr; | |||
std::map<string, NodePtr> name_node_map; | |||
vector<string> nodes_name; | |||
while (!stack_input.empty() || !stack.empty()) { | |||
if (!stack.empty()) { | |||
cur_node = stack.front(); | |||
stack.pop(); | |||
} else { | |||
cur_node = stack_input.back(); | |||
stack_input.pop_back(); | |||
} | |||
node_vec.emplace_back(cur_node); | |||
compute_graph->CollectBreadthOutNode(cur_node, map_in_edge_num, name_node_map); | |||
for (const auto &iter : name_node_map) { | |||
nodes_name.emplace_back(iter.first); | |||
} | |||
std::sort(nodes_name.begin(), nodes_name.end()); | |||
for (const auto &iter : nodes_name) { | |||
stack.push(name_node_map[iter]); | |||
} | |||
name_node_map.clear(); | |||
nodes_name.clear(); | |||
} | |||
// If they are not equal, there is a closed loop | |||
if (node_vec.size() != compute_graph->nodes_.size()) { | |||
std::set<Node *> itered_nodes_set; | |||
for (auto &node : node_vec) { | |||
itered_nodes_set.insert(node.get()); | |||
} | |||
GE_LOGE("Failed to do topo sorting total %zu, itered %zu, exist closed loop in graph.", | |||
compute_graph->nodes_.size(), node_vec.size()); | |||
for (auto &node : compute_graph->nodes_) { | |||
if (itered_nodes_set.count(node.get()) == 0) { | |||
GE_LOGE("The node %s does not itered when topological sorting", node->GetName().c_str()); | |||
} | |||
} | |||
return GRAPH_FAILED; | |||
} | |||
return GRAPH_SUCCESS; | |||
} | |||
} // namespace ge |
@@ -21,6 +21,7 @@ | |||
#include "framework/common/debug/ge_log.h" | |||
#include "graph/anchor.h" | |||
#include "graph/debug/ge_attr_define.h" | |||
#include "graph/types.h" | |||
#include "utils/tensor_utils.h" | |||
#include "utils/type_utils.h" | |||
@@ -28,6 +29,26 @@ namespace ge { | |||
std::map<NodePtr, std::vector<uint32_t>> NodeUtils::map_send_info_{}; | |||
std::map<NodePtr, std::vector<uint32_t>> NodeUtils::map_recv_info_{}; | |||
bool OpShapeIsUnknown(const OpDescPtr &desc) { | |||
for (const auto &ptr : desc->GetAllInputsDescPtr()) { | |||
auto ge_shape = ptr->GetShape(); | |||
for (const auto &dim : ge_shape.GetDims()) { | |||
if (dim == UNKNOWN_DIM || dim == UNKNOWN_DIM_NUM) { | |||
return true; | |||
} | |||
} | |||
} | |||
for (const auto &ptr : desc->GetAllOutputsDescPtr()) { | |||
auto ge_shape = ptr->GetShape(); | |||
for (const auto &dim : ge_shape.GetDims()) { | |||
if (dim == UNKNOWN_DIM || dim == UNKNOWN_DIM_NUM) { | |||
return true; | |||
} | |||
} | |||
} | |||
return false; | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AddSendEventId(const NodePtr &node, | |||
const uint32_t &event_id) { | |||
GE_CHECK_NOTNULL(node); | |||
@@ -282,18 +303,23 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::UpdatePeer | |||
GELOGE(GRAPH_FAILED, "peer_anchor opdesc is null"); | |||
continue; | |||
} | |||
auto peer_input_desc = peer_anchor->GetOwnerNode()->GetOpDesc()->GetInputDescPtr(peer_anchor->GetIdx()); | |||
auto peer_input_desc = peer_anchor->GetOwnerNode()->GetOpDesc()->MutableInputDesc(peer_anchor->GetIdx()); | |||
if (peer_input_desc == nullptr) { | |||
GELOGE(GRAPH_FAILED, "peer_input_desc is nullptr"); | |||
continue; | |||
} | |||
output_tensor.SetOriginFormat(peer_input_desc->GetOriginFormat()); | |||
output_tensor.SetFormat(peer_input_desc->GetFormat()); | |||
auto peer_op_desc = peer_anchor->GetOwnerNode()->GetOpDesc(); | |||
GE_IF_BOOL_EXEC(peer_op_desc == nullptr, GELOGE(GRAPH_FAILED, "peer opdesc is null"); continue); | |||
GE_IF_BOOL_EXEC(peer_op_desc->UpdateInputDesc(peer_anchor->GetIdx(), output_tensor) != GRAPH_SUCCESS, | |||
GELOGE(GRAPH_FAILED, "peer opdesc is null"); | |||
continue); | |||
GELOGI("Peer input opdesc name is %s, need to flush: shape size is %zu, datatype is %d, original datatype is %d", | |||
peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), output_tensor.GetShape().GetDimNum(), | |||
output_tensor.GetDataType(), output_tensor.GetOriginDataType()); | |||
peer_input_desc->SetShape(output_tensor.GetShape()); | |||
peer_input_desc->SetOriginShape(output_tensor.GetOriginShape()); | |||
peer_input_desc->SetDataType(output_tensor.GetDataType()); | |||
peer_input_desc->SetOriginDataType(output_tensor.GetOriginDataType()); | |||
ge::TensorUtils::SetRealDimCnt(*peer_input_desc, | |||
static_cast<uint32_t>(output_tensor.GetShape().GetDims().size())); | |||
GELOGI("Peer input opdesc name is %s, shape size is %zu, datatype is %d, original datatype is %d", | |||
peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), peer_input_desc->GetShape().GetDimNum(), | |||
peer_input_desc->GetDataType(), peer_input_desc->GetOriginDataType()); | |||
} | |||
} | |||
return GRAPH_SUCCESS; | |||
@@ -361,6 +387,41 @@ graphStatus NodeUtils::UpdateInputShape(const Node &node, uint32_t index, const | |||
input_desc->SetShape(shape); | |||
return GRAPH_SUCCESS; | |||
} | |||
graphStatus NodeUtils::GetNodeUnknownShapeStatus(const Node &node, bool &is_unknow) { | |||
auto desc = node.GetOpDesc(); | |||
GE_CHECK_NOTNULL(desc); | |||
auto sub_graph_names = desc->GetSubgraphInstanceNames(); | |||
if (sub_graph_names.empty()) { | |||
is_unknow = OpShapeIsUnknown(desc); | |||
return GRAPH_SUCCESS; | |||
} else { | |||
auto owner_graph = node.GetOwnerComputeGraph(); | |||
GE_CHECK_NOTNULL(owner_graph); | |||
auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph()); | |||
if (root_graph == nullptr) { | |||
GE_LOGE("Node %s gets null root graph", node.GetName().c_str()); | |||
return GRAPH_PARAM_INVALID; | |||
} | |||
for (auto &sub_graph_name : sub_graph_names) { | |||
auto sub_graph = root_graph->GetSubgraph(sub_graph_name); | |||
GE_CHECK_NOTNULL(sub_graph); | |||
for (const auto &node_ptr : sub_graph->GetDirectNode()) { | |||
auto status = GetNodeUnknownShapeStatus(*node_ptr, is_unknow); | |||
if (status != GRAPH_SUCCESS) { | |||
GE_LOGE("get node unknown shape status failed!"); | |||
return status; | |||
} | |||
if (is_unknow) { | |||
return GRAPH_SUCCESS; | |||
} | |||
} | |||
} | |||
} | |||
return GRAPH_SUCCESS; | |||
} | |||
std::string NodeUtils::GetNodeType(const Node &node) { | |||
if (node.GetType() != FRAMEWORKOP) { | |||
return node.GetType(); | |||
@@ -381,9 +442,9 @@ ComputeGraphPtr NodeUtils::GetSubgraph(const Node &node, uint32_t index) { | |||
return root_graph->GetSubgraph(op_desc->GetSubgraphInstanceName(index)); | |||
} | |||
graphStatus NodeUtils::AddSubgraph(Node &node, const ComputeGraphPtr &subgraph) { | |||
graphStatus NodeUtils::SetSubgraph(Node &node, uint32_t index, const ComputeGraphPtr &subgraph) { | |||
if (subgraph == nullptr) { | |||
GE_LOGE("Failed to add subgraph to node %s, null subgraph", node.GetName().c_str()); | |||
GE_LOGE("Failed to set subgraph to node %s index %u, null subgraph", node.GetName().c_str(), index); | |||
return GRAPH_PARAM_INVALID; | |||
} | |||
auto op_desc = node.GetOpDesc(); | |||
@@ -395,11 +456,105 @@ graphStatus NodeUtils::AddSubgraph(Node &node, const ComputeGraphPtr &subgraph) | |||
GE_LOGE("Failed to add subgraph to node %s, null root graph", node.GetName().c_str()); | |||
return GRAPH_PARAM_INVALID; | |||
} | |||
op_desc->AddSubgraphInstanceName(subgraph->GetName()); | |||
auto ret = op_desc->SetSubgraphInstanceName(index, subgraph->GetName()); | |||
if (ret != GRAPH_SUCCESS) { | |||
GE_LOGE("Failed to set subgraph to node %s index %u", node.GetName().c_str(), index); | |||
return ret; | |||
} | |||
subgraph->SetParentNode(node.shared_from_this()); | |||
subgraph->SetParentGraph(node.GetOwnerComputeGraph()); | |||
root_graph->AddSubgraph(subgraph); | |||
return root_graph->AddSubgraph(subgraph); | |||
} | |||
return GRAPH_SUCCESS; | |||
/// | |||
/// Check if node is input of subgraph | |||
/// @param [in] node | |||
/// @return bool | |||
/// | |||
bool NodeUtils::IsSubgraphInput(const NodePtr &node) { | |||
if ((node == nullptr) || (node->GetOpDesc() == nullptr) || | |||
(node->GetOwnerComputeGraph()->GetParentNode() == nullptr)) { | |||
return false; | |||
} | |||
return node->GetOpDesc()->HasAttr(ATTR_NAME_PARENT_NODE_INDEX); | |||
} | |||
/// | |||
/// Check if node is output of subgraph | |||
/// @param [in] node | |||
/// @return bool | |||
/// | |||
bool NodeUtils::IsSubgraphOutput(const NodePtr &node) { | |||
if ((node == nullptr) || (node->GetOpDesc() == nullptr) || | |||
(node->GetOwnerComputeGraph()->GetParentNode() == nullptr) || (node->GetType() != NETOUTPUT)) { | |||
return false; | |||
} | |||
for (GeTensorDesc &tensor : node->GetOpDesc()->GetAllInputsDesc()) { | |||
if (AttrUtils::HasAttr(tensor, ATTR_NAME_PARENT_NODE_INDEX)) { | |||
return true; | |||
} | |||
} | |||
return false; | |||
} | |||
/// | |||
/// @brief Get subgraph original input node. | |||
/// @param [in] node | |||
/// @return Node | |||
/// | |||
NodePtr NodeUtils::GetParentInput(const NodePtr &node) { | |||
GE_CHECK_NOTNULL_EXEC(node, return nullptr); | |||
uint32_t parent_index = 0; | |||
if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { | |||
return nullptr; | |||
} | |||
// Subgraph Data Node, check for constant input. | |||
const ComputeGraphPtr &graph = node->GetOwnerComputeGraph(); | |||
GE_CHECK_NOTNULL_EXEC(graph, return nullptr); | |||
const NodePtr &parent_node = graph->GetParentNode(); | |||
GE_CHECK_NOTNULL_EXEC(parent_node, return nullptr); | |||
const InDataAnchorPtr &in_anchor = parent_node->GetInDataAnchor(parent_index); | |||
GE_CHECK_NOTNULL_EXEC(in_anchor, return nullptr); | |||
const OutDataAnchorPtr &peer_out_anchor = in_anchor->GetPeerOutAnchor(); | |||
GE_CHECK_NOTNULL_EXEC(peer_out_anchor, return nullptr); | |||
return peer_out_anchor->GetOwnerNode(); | |||
} | |||
/// | |||
/// @brief Get subgraph input is constant. | |||
/// @param [in] node | |||
/// @param [out] string | |||
/// @return bool | |||
/// | |||
bool NodeUtils::GetConstOpType(const NodePtr &in_node, std::string &op_type) { | |||
GE_CHECK_NOTNULL_EXEC(in_node, return false); | |||
if ((in_node->GetType() == CONSTANT) || (in_node->GetType() == CONSTANTOP)) { | |||
op_type = in_node->GetType(); | |||
return true; | |||
} | |||
if (in_node->GetType() == DATA) { | |||
std::string const_type; | |||
if (!AttrUtils::GetStr(in_node->GetOpDesc(), ATTR_NAME_PARENT_CONST_TYPE, const_type)) { | |||
return false; | |||
} | |||
if ((const_type == CONSTANT) || (const_type == CONSTANTOP)) { | |||
op_type = const_type; | |||
return true; | |||
} | |||
} | |||
return false; | |||
} | |||
} // namespace ge |
@@ -469,7 +469,7 @@ OpDescUtils::SetWeights(ge::Node &node, const vector<ge::GeTensorPtr> &weights) | |||
return GRAPH_PARAM_INVALID; | |||
} | |||
ge::GeAttrValue::NamedAttrs named_attrs; | |||
ge::GeAttrValue::NAMED_ATTRS named_attrs; | |||
(void)ge::AttrUtils::SetListTensor(named_attrs, "key", weights); | |||
vector<ge::GeTensorPtr> copy_weights; | |||
(void)ge::AttrUtils::MutableListTensor(named_attrs, "key", copy_weights); | |||
@@ -578,7 +578,19 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDescUtils::ClearWei | |||
/// @return OpDescBuilder | |||
/// | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddInput(const std::string &name) { | |||
inputs_.emplace_back(name); | |||
inputs_.emplace_back(std::make_pair(name, GeTensorDesc())); | |||
return *this; | |||
} | |||
/// | |||
/// @brief Add input | |||
/// @param [in] name | |||
/// @param [in] tensor | |||
/// @return OpDescBuilder | |||
/// | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddInput(const std::string &name, | |||
const GeTensorDesc &tensor) { | |||
inputs_.emplace_back(std::make_pair(name, tensor)); | |||
return *this; | |||
} | |||
@@ -591,7 +603,22 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::Add | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicInput(const std::string &name, | |||
uint32_t num) { | |||
for (uint32_t i = 0; i < num; i++) { | |||
inputs_.emplace_back(name + std::to_string(i)); | |||
inputs_.emplace_back(std::make_pair(name + std::to_string(i), GeTensorDesc())); | |||
} | |||
return *this; | |||
} | |||
/// | |||
/// @brief Add dynamic input | |||
/// @param [in] name | |||
/// @param [in] num | |||
/// @param [in] tensor | |||
/// @return OpDescBuilder | |||
/// | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicInput( | |||
const std::string &name, uint32_t num, const GeTensorDesc &tensor) { | |||
for (uint32_t i = 0; i < num; i++) { | |||
inputs_.emplace_back(std::make_pair(name + std::to_string(i), tensor)); | |||
} | |||
return *this; | |||
} | |||
@@ -602,7 +629,19 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::Add | |||
/// @return OpDescBuilder | |||
/// | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddOutput(const std::string &name) { | |||
outputs_.emplace_back(name); | |||
outputs_.emplace_back(std::make_pair(name, GeTensorDesc())); | |||
return *this; | |||
} | |||
/// | |||
/// @brief Add output | |||
/// @param [in] name | |||
/// @param [in] tensor | |||
/// @return OpDescBuilder | |||
/// | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddOutput(const std::string &name, | |||
const GeTensorDesc &tensor) { | |||
outputs_.emplace_back(std::make_pair(name, tensor)); | |||
return *this; | |||
} | |||
@@ -615,7 +654,22 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::Add | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicOutput(const std::string &name, | |||
uint32_t num) { | |||
for (uint32_t i = 0; i < num; i++) { | |||
outputs_.emplace_back(name + std::to_string(i)); | |||
outputs_.emplace_back(std::make_pair(name + std::to_string(i), GeTensorDesc())); | |||
} | |||
return *this; | |||
} | |||
/// | |||
/// @brief Add dynamic output | |||
/// @param [in] name | |||
/// @param [in] num | |||
/// @param [in] tensor | |||
/// @return OpDescBuilder | |||
/// | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicOutput( | |||
const std::string &name, uint32_t num, const GeTensorDesc &tensor) { | |||
for (uint32_t i = 0; i < num; i++) { | |||
outputs_.emplace_back(std::make_pair(name + std::to_string(i), tensor)); | |||
} | |||
return *this; | |||
} | |||
@@ -632,14 +686,14 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr OpDescBuilder::Build() | |||
} | |||
for (auto &input : inputs_) { | |||
if (op_desc->AddInputDesc(input, GeTensorDesc()) != GRAPH_SUCCESS) { | |||
if (op_desc->AddInputDesc(input.first, input.second) != GRAPH_SUCCESS) { | |||
GELOGE(GRAPH_FAILED, "Add input_desc failed."); | |||
return nullptr; | |||
} | |||
} | |||
for (auto &output : outputs_) { | |||
if (op_desc->AddOutputDesc(output, GeTensorDesc()) != GRAPH_SUCCESS) { | |||
if (op_desc->AddOutputDesc(output.first, output.second) != GRAPH_SUCCESS) { | |||
GELOGE(GRAPH_FAILED, "Add output_desc failed."); | |||
return nullptr; | |||
} | |||
@@ -647,4 +701,19 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr OpDescBuilder::Build() | |||
return op_desc; | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDescUtils::SetSubgraphInstanceName( | |||
const std::string &subgraph_name, const std::string &subgraph_instance_name, OpDescPtr &op_desc) { | |||
const auto &subgraph_names_to_index = op_desc->GetSubgraphNameIndexes(); | |||
auto iter = subgraph_names_to_index.find(subgraph_name); | |||
if (iter == subgraph_names_to_index.end()) { | |||
GELOGE(GRAPH_PARAM_INVALID, | |||
"Failed to set subgraph instance %s for node %s type %s, the subgraph name %s does not exists", | |||
subgraph_instance_name.c_str(), op_desc->GetName().c_str(), op_desc->GetType().c_str(), | |||
subgraph_name.c_str()); | |||
return GRAPH_PARAM_INVALID; | |||
} | |||
return op_desc->SetSubgraphInstanceName(iter->second, subgraph_instance_name); | |||
} | |||
} // namespace ge |
@@ -282,6 +282,7 @@ static graphStatus CalcTensorElementCnt(const std::vector<int64_t> &dims, Format | |||
case FORMAT_FRACTAL_Z_3D: | |||
case FORMAT_FRACTAL_Z_3D_TRANSPOSE: | |||
case FORMAT_NDC1HWC0: | |||
case FORMAT_FRACTAL_Z_C04: | |||
graph_status = CalcElementCntByDims(dims, element_cnt); | |||
break; | |||
default: | |||
@@ -56,6 +56,7 @@ static const std::map<Format, std::string> kFormatToStringMap = { | |||
{FORMAT_FRACTAL_NZ, "FRACTAL_NZ"}, | |||
{FORMAT_CN, "CN"}, | |||
{FORMAT_NC, "NC"}, | |||
{FORMAT_FRACTAL_ZN_LSTM, "FRACTAL_ZN_LSTM"}, | |||
{FORMAT_RESERVED, "FORMAT_RESERVED"}, | |||
{FORMAT_ALL, "ALL"}}; | |||
@@ -76,7 +77,8 @@ static const std::unordered_set<std::string> kInternalFormat = {"NC1HWC0", | |||
"FRACTAL_NZ", | |||
"NDC1HWC0", | |||
"FORMAT_FRACTAL_Z_3D", | |||
"FORMAT_FRACTAL_Z_3D_TRANSPOSE"}; | |||
"FORMAT_FRACTAL_Z_3D_TRANSPOSE" | |||
"FORMAT_FRACTAL_ZN_LSTM"}; | |||
static const std::map<std::string, Format> kDataFormatMap = { | |||
{"NCHW", FORMAT_NCHW}, {"NHWC", FORMAT_NHWC}, {"NDHWC", FORMAT_NDHWC}, {"NCDHW", FORMAT_NCDHW}, {"ND", FORMAT_ND}}; | |||
@@ -119,6 +121,7 @@ static const std::map<std::string, Format> kStringToFormatMap = { | |||
{"FRACTAL_Z_3D_TRANSPOSE", FORMAT_FRACTAL_Z_3D_TRANSPOSE}, | |||
{"CN", FORMAT_CN}, | |||
{"NC", FORMAT_NC}, | |||
{"FRACTAL_ZN_LSTM", FORMAT_FRACTAL_ZN_LSTM}, | |||
{"FORMAT_RESERVED", FORMAT_RESERVED}, | |||
{"ALL", FORMAT_ALL}}; | |||
@@ -13,15 +13,18 @@ | |||
# limitations under the License. | |||
# ============================================================================ | |||
# libge_compiler.so & libge_train.so | |||
# libge_compiler.so & libge_runner.so | |||
# will later be integrated into libgraph_runner.so, works for both training and inference | |||
# compiling proto files generates some warnings, use no-unused-variable to suppress them | |||
set(CMAKE_CXX_FLAGS "-Wno-unused-variable ${CMAKE_CXX_FLAGS}") | |||
file(GLOB_RECURSE PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
file(GLOB PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
"../proto/fusion_model.proto" | |||
"../proto/optimizer_priority.proto" | |||
) | |||
file(GLOB_RECURSE PROTO_HEADER_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
file(GLOB PROTO_CLIENT_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
"../proto/ge_api.proto" | |||
) | |||
file(GLOB PROTO_HEADER_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
"../proto/om.proto" | |||
"../proto/task.proto" | |||
"../proto/insert_op.proto" | |||
@@ -30,57 +33,46 @@ file(GLOB_RECURSE PROTO_HEADER_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
"../proto/op_mapping_info.proto" | |||
) | |||
ge_protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | |||
ge_protobuf_generate(ge PROTO_CLIENT_SRCS PROTO_CLIENT_HDRS ${PROTO_CLIENT_LIST}) | |||
ge_protobuf_generate(ge PROTO_HEADER_SRCS PROTO_HEADER_HDRS ${PROTO_HEADER_LIST}) | |||
# include directories | |||
include_directories(${CMAKE_CURRENT_LIST_DIR}) | |||
include_directories(${GE_SOURCE_DIR}) | |||
include_directories(${GE_SOURCE_DIR}/src) | |||
include_directories(${GE_SOURCE_DIR}/inc) | |||
include_directories(${GE_SOURCE_DIR}/inc/common/util) | |||
include_directories(${GE_SOURCE_DIR}/inc/external) | |||
include_directories(${GE_SOURCE_DIR}/inc/external/graph) | |||
include_directories(${GE_SOURCE_DIR}/inc/framework) | |||
include_directories(${GE_SOURCE_DIR}/inc/framework/common) | |||
include_directories(${GE_SOURCE_DIR}/inc/runtime) | |||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib) | |||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) | |||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/cce) | |||
include_directories(${CMAKE_BINARY_DIR}) | |||
include_directories(${CMAKE_BINARY_DIR}/proto/ge) | |||
######### libge_train.so ############# | |||
######### libge_runner.so ############# | |||
# need to remove dependencies on pb files later | |||
file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
"client/ge_api.cc" | |||
"common/formats/format_transfers/*.cc" | |||
"common/formats/formats.cc" | |||
"common/formats/utils/formats_trans_utils.cc" | |||
"common/fp16_t.cc" | |||
"common/ge/plugin_manager.cc" | |||
"common/helper/model_cache_helper.cc" | |||
"common/profiling/profiling_manager.cc" | |||
"engine_manager/dnnengine_manager.cc" | |||
"ge_local_engine/engine/host_cpu_engine.cc" | |||
"generator/ge_generator.cc" | |||
"generator/generator_api.cc" | |||
"graph/build/graph_builder.cc" | |||
"graph/build/label_allocator.cc" | |||
"graph/build/logical_stream_allocator.cc" | |||
"graph/build/model_builder.cc" | |||
"graph/build/run_context.cc" | |||
"graph/build/stream_allocator.cc" | |||
"graph/build/stream_graph_optimizer.cc" | |||
"graph/build/task_generator.cc" | |||
"graph/common/bcast.cc" | |||
"graph/common/omg_util.cc" | |||
"graph/common/transop_util.cc" | |||
"graph/build/*.cc" | |||
"graph/common/*.cc" | |||
"graph/execute/graph_execute.cc" | |||
"graph/label/*.cc" | |||
"graph/load/graph_loader.cc" | |||
"graph/load/new_model_manager/cpu_queue_schedule.cc" | |||
"graph/load/new_model_manager/data_dumper.cc" | |||
"graph/load/new_model_manager/data_inputer.cc" | |||
"graph/load/new_model_manager/davinci_model.cc" | |||
"graph/load/new_model_manager/davinci_model_parser.cc" | |||
"graph/load/new_model_manager/model_manager.cc" | |||
"graph/load/new_model_manager/model_output.cc" | |||
"graph/load/new_model_manager/model_utils.cc" | |||
"graph/load/new_model_manager/*.cc" | |||
"graph/load/new_model_manager/task_info/end_graph_task_info.cc" | |||
"graph/load/new_model_manager/task_info/event_record_task_info.cc" | |||
"graph/load/new_model_manager/task_info/event_wait_task_info.cc" | |||
@@ -89,8 +81,10 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
"graph/load/new_model_manager/task_info/hccl_task_info.cc" | |||
"graph/load/new_model_manager/task_info/kernel_ex_task_info.cc" | |||
"graph/load/new_model_manager/task_info/kernel_task_info.cc" | |||
"graph/load/new_model_manager/task_info/label_goto_task_info.cc" | |||
"graph/load/new_model_manager/task_info/label_goto_ex_task_info.cc" | |||
"graph/load/new_model_manager/task_info/label_set_task_info.cc" | |||
"graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc" | |||
"graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc" | |||
"graph/load/new_model_manager/task_info/memcpy_async_task_info.cc" | |||
"graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" | |||
"graph/load/new_model_manager/task_info/stream_active_task_info.cc" | |||
@@ -99,15 +93,9 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
"graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc" | |||
"graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc" | |||
"graph/load/new_model_manager/task_info/task_info.cc" | |||
"graph/load/new_model_manager/tbe_handle_store.cc" | |||
"graph/load/output/output.cc" | |||
"graph/manager/graph_context.cc" | |||
"graph/manager/graph_manager.cc" | |||
"graph/manager/graph_manager_utils.cc" | |||
"graph/manager/graph_mem_allocator.cc" | |||
"graph/manager/graph_var_manager.cc" | |||
"graph/manager/*.cc" | |||
"graph/manager/model_manager/event_manager.cc" | |||
"graph/manager/trans_var_data_utils.cc" | |||
"graph/manager/util/debug.cc" | |||
"graph/manager/util/hcom_util.cc" | |||
"graph/manager/util/rt_context_util.cc" | |||
@@ -115,27 +103,10 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
"graph/optimize/graph_optimize.cc" | |||
"graph/optimize/optimizer/allreduce_fusion_pass.cc" | |||
"graph/optimize/summary_optimize.cc" | |||
"graph/partition/dynamic_shape_partition.cc" | |||
"graph/partition/engine_place.cc" | |||
"graph/partition/graph_partition.cc" | |||
"graph/passes/addn_pass.cc" | |||
"graph/passes/aicpu_constant_folding_pass.cc" | |||
"graph/passes/assert_pass.cc" | |||
"graph/passes/atomic_addr_clean_pass.cc" | |||
"graph/passes/base_pass.cc" | |||
"graph/passes/cast_remove_pass.cc" | |||
"graph/passes/cast_translate_pass.cc" | |||
"graph/passes/common_subexpression_elimination_pass.cc" | |||
"graph/passes/compile_nodes_pass.cc" | |||
"graph/passes/constant_folding_pass.cc" | |||
"graph/passes/constant_fuse_same_pass.cc" | |||
"graph/passes/control_op_attr_pass.cc" | |||
"graph/passes/control_trigger_pass.cc" | |||
"graph/passes/dimension_adjust_pass.cc" | |||
"graph/passes/dimension_compute_pass.cc" | |||
"graph/passes/dropout_pass.cc" | |||
"graph/passes/end_graph_pass.cc" | |||
"graph/passes/enter_pass.cc" | |||
"graph/passes/flow_ctrl_pass.cc" | |||
"graph/passes/*.cc" | |||
"graph/passes/folding_kernel/add_kernel.cc" | |||
"graph/passes/folding_kernel/broadcast_args_kernel.cc" | |||
"graph/passes/folding_kernel/broadcast_gradient_args_kernel.cc" | |||
@@ -171,51 +142,6 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
"graph/passes/folding_kernel/sub_kernel.cc" | |||
"graph/passes/folding_kernel/transdata_kernel.cc" | |||
"graph/passes/folding_kernel/unpack_kernel.cc" | |||
"graph/passes/folding_pass.cc" | |||
"graph/passes/get_original_format_pass.cc" | |||
"graph/passes/guarantee_const_pass.cc" | |||
"graph/passes/hccl_memcpy_pass.cc" | |||
"graph/passes/identify_reference_pass.cc" | |||
"graph/passes/identity_pass.cc" | |||
"graph/passes/infershape_pass.cc" | |||
"graph/passes/isolated_op_remove_pass.cc" | |||
"graph/passes/iterator_op_pass.cc" | |||
"graph/passes/link_gen_mask_nodes_pass.cc" | |||
"graph/passes/merge_pass.cc" | |||
"graph/passes/multi_batch_pass.cc" | |||
"graph/passes/net_output_pass.cc" | |||
"graph/passes/next_iteration_pass.cc" | |||
"graph/passes/no_use_reshape_remove_pass.cc" | |||
"graph/passes/pass_manager.cc" | |||
"graph/passes/pass_utils.cc" | |||
"graph/passes/permute_pass.cc" | |||
"graph/passes/placeholder_with_default_pass.cc" | |||
"graph/passes/prevent_gradient_pass.cc" | |||
"graph/passes/print_op_pass.cc" | |||
"graph/passes/prune_pass.cc" | |||
"graph/passes/reshape_remove_pass.cc" | |||
"graph/passes/resource_pair_add_control_pass.cc" | |||
"graph/passes/resource_pair_remove_control_pass.cc" | |||
"graph/passes/same_transdata_breadth_fusion_pass.cc" | |||
"graph/passes/save_pass.cc" | |||
"graph/passes/shape_operate_op_remove_pass.cc" | |||
"graph/passes/snapshot_pass.cc" | |||
"graph/passes/stop_gradient_pass.cc" | |||
"graph/passes/switch_logic_remove_pass.cc" | |||
"graph/passes/switch_op_pass.cc" | |||
"graph/passes/switch_pass.cc" | |||
"graph/passes/transop_breadth_fusion_pass.cc" | |||
"graph/passes/transop_depth_fusion_pass.cc" | |||
"graph/passes/transop_nearby_allreduce_fusion_pass.cc" | |||
"graph/passes/transop_without_reshape_fusion_pass.cc" | |||
"graph/passes/transpose_transdata_pass.cc" | |||
"graph/passes/unused_const_pass.cc" | |||
"graph/passes/unused_op_remove_pass.cc" | |||
"graph/passes/var_is_initialized_op_pass.cc" | |||
"graph/passes/variable_format_pass.cc" | |||
"graph/passes/variable_op_pass.cc" | |||
"graph/passes/variable_prepare_op_pass.cc" | |||
"graph/passes/variable_ref_delete_op_pass.cc" | |||
"graph/preprocess/graph_preprocess.cc" | |||
"graph/preprocess/insert_op/ge_aipp_op.cc" | |||
"graph/preprocess/insert_op/util_insert_aipp_op.cc" | |||
@@ -231,22 +157,17 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
) | |||
######### libge_train.so ############# | |||
add_library(ge_train SHARED ${TRAIN_SRC_LIST} ${PROTO_SRCS} ${PROTO_HEADER_HDRS}) | |||
target_compile_definitions(ge_train PRIVATE | |||
######### libge_runner.so ############# | |||
add_library(ge_runner SHARED ${TRAIN_SRC_LIST} ${PROTO_SRCS} ${PROTO_CLIENT_SRCS} ${PROTO_HEADER_HDRS}) | |||
target_compile_definitions(ge_runner PRIVATE | |||
PROTOBUF_INLINE_NOT_IN_HEADERS=0 | |||
DAVINCI_SUPPORT_PROFILING | |||
REUSE_MEMORY=1 | |||
DAVINCI_TRAIN | |||
DAVINCI_CLOUD | |||
FMK_SUPPORT_DEBUG | |||
PLATFORM_CLOUD) | |||
target_link_libraries(ge_train | |||
DAVINCI_CLOUD) | |||
target_link_libraries(ge_runner | |||
graph | |||
ge_common | |||
"-Wl,--whole-archive" | |||
ge_memory | |||
"-Wl,--no-whole-archive" | |||
${PROTOBUF_LIBRARY} | |||
${register} | |||
${c_sec} | |||
@@ -267,33 +188,18 @@ file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
"common/formats/utils/formats_trans_utils.cc" | |||
"common/fp16_t.cc" | |||
"common/ge/plugin_manager.cc" | |||
"common/helper/model_cache_helper.cc" | |||
"common/profiling/profiling_manager.cc" | |||
"engine_manager/dnnengine_manager.cc" | |||
"ge_local_engine/engine/host_cpu_engine.cc" | |||
"generator/ge_generator.cc" | |||
"generator/generator_api.cc" | |||
"graph/build/graph_builder.cc" | |||
"graph/build/label_allocator.cc" | |||
"graph/build/logical_stream_allocator.cc" | |||
"graph/build/model_builder.cc" | |||
"graph/build/run_context.cc" | |||
"graph/build/stream_allocator.cc" | |||
"graph/build/stream_graph_optimizer.cc" | |||
"graph/build/task_generator.cc" | |||
"graph/common/bcast.cc" | |||
"graph/common/omg_util.cc" | |||
"graph/common/transop_util.cc" | |||
"graph/build/*.cc" | |||
"graph/common/*.cc" | |||
"graph/execute/graph_execute.cc" | |||
"graph/label/*.cc" | |||
"graph/load/graph_loader.cc" | |||
"graph/load/new_model_manager/cpu_queue_schedule.cc" | |||
"graph/load/new_model_manager/data_dumper.cc" | |||
"graph/load/new_model_manager/data_inputer.cc" | |||
"graph/load/new_model_manager/davinci_model.cc" | |||
"graph/load/new_model_manager/davinci_model_parser.cc" | |||
"graph/load/new_model_manager/model_manager.cc" | |||
"graph/load/new_model_manager/model_output.cc" | |||
"graph/load/new_model_manager/model_utils.cc" | |||
"graph/load/new_model_manager/*.cc" | |||
"graph/load/new_model_manager/task_info/end_graph_task_info.cc" | |||
"graph/load/new_model_manager/task_info/event_record_task_info.cc" | |||
"graph/load/new_model_manager/task_info/event_wait_task_info.cc" | |||
@@ -301,8 +207,10 @@ file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
"graph/load/new_model_manager/task_info/fusion_stop_task_info.cc" | |||
"graph/load/new_model_manager/task_info/kernel_ex_task_info.cc" | |||
"graph/load/new_model_manager/task_info/kernel_task_info.cc" | |||
"graph/load/new_model_manager/task_info/label_goto_task_info.cc" | |||
"graph/load/new_model_manager/task_info/label_goto_ex_task_info.cc" | |||
"graph/load/new_model_manager/task_info/label_set_task_info.cc" | |||
"graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc" | |||
"graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc" | |||
"graph/load/new_model_manager/task_info/memcpy_async_task_info.cc" | |||
"graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" | |||
"graph/load/new_model_manager/task_info/stream_active_task_info.cc" | |||
@@ -311,41 +219,18 @@ file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
"graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc" | |||
"graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc" | |||
"graph/load/new_model_manager/task_info/task_info.cc" | |||
"graph/load/new_model_manager/tbe_handle_store.cc" | |||
"graph/load/output/output.cc" | |||
"graph/manager/graph_context.cc" | |||
"graph/manager/graph_manager.cc" | |||
"graph/manager/graph_manager_utils.cc" | |||
"graph/manager/graph_mem_allocator.cc" | |||
"graph/manager/graph_var_manager.cc" | |||
"graph/manager/*.cc" | |||
"graph/manager/model_manager/event_manager.cc" | |||
"graph/manager/trans_var_data_utils.cc" | |||
"graph/manager/util/debug.cc" | |||
"graph/manager/util/rt_context_util.cc" | |||
"graph/manager/util/variable_accelerate_ctrl.cc" | |||
"graph/optimize/graph_optimize.cc" | |||
"graph/optimize/summary_optimize.cc" | |||
"graph/partition/dynamic_shape_partition.cc" | |||
"graph/partition/engine_place.cc" | |||
"graph/partition/graph_partition.cc" | |||
"graph/passes/addn_pass.cc" | |||
"graph/passes/aicpu_constant_folding_pass.cc" | |||
"graph/passes/assert_pass.cc" | |||
"graph/passes/atomic_addr_clean_pass.cc" | |||
"graph/passes/base_pass.cc" | |||
"graph/passes/cast_remove_pass.cc" | |||
"graph/passes/cast_translate_pass.cc" | |||
"graph/passes/common_subexpression_elimination_pass.cc" | |||
"graph/passes/compile_nodes_pass.cc" | |||
"graph/passes/constant_folding_pass.cc" | |||
"graph/passes/constant_fuse_same_pass.cc" | |||
"graph/passes/control_op_attr_pass.cc" | |||
"graph/passes/control_trigger_pass.cc" | |||
"graph/passes/dimension_adjust_pass.cc" | |||
"graph/passes/dimension_compute_pass.cc" | |||
"graph/passes/dropout_pass.cc" | |||
"graph/passes/end_graph_pass.cc" | |||
"graph/passes/enter_pass.cc" | |||
"graph/passes/flow_ctrl_pass.cc" | |||
"graph/passes/*.cc" | |||
"graph/passes/folding_kernel/add_kernel.cc" | |||
"graph/passes/folding_kernel/broadcast_args_kernel.cc" | |||
"graph/passes/folding_kernel/broadcast_gradient_args_kernel.cc" | |||
@@ -380,87 +265,33 @@ file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
"graph/passes/folding_kernel/strided_slice_kernel.cc" | |||
"graph/passes/folding_kernel/sub_kernel.cc" | |||
"graph/passes/folding_kernel/transdata_kernel.cc" | |||
"graph/passes/folding_kernel/transpose_kernel.cc" | |||
"graph/passes/folding_kernel/unpack_kernel.cc" | |||
"graph/passes/folding_pass.cc" | |||
"graph/passes/get_original_format_pass.cc" | |||
"graph/passes/guarantee_const_pass.cc" | |||
"graph/passes/hccl_memcpy_pass.cc" | |||
"graph/passes/identify_reference_pass.cc" | |||
"graph/passes/identity_pass.cc" | |||
"graph/passes/infershape_pass.cc" | |||
"graph/passes/isolated_op_remove_pass.cc" | |||
"graph/passes/iterator_op_pass.cc" | |||
"graph/passes/link_gen_mask_nodes_pass.cc" | |||
"graph/passes/merge_pass.cc" | |||
"graph/passes/multi_batch_pass.cc" | |||
"graph/passes/net_output_pass.cc" | |||
"graph/passes/next_iteration_pass.cc" | |||
"graph/passes/no_use_reshape_remove_pass.cc" | |||
"graph/passes/pass_manager.cc" | |||
"graph/passes/pass_utils.cc" | |||
"graph/passes/permute_pass.cc" | |||
"graph/passes/placeholder_with_default_pass.cc" | |||
"graph/passes/prevent_gradient_pass.cc" | |||
"graph/passes/print_op_pass.cc" | |||
"graph/passes/prune_pass.cc" | |||
"graph/passes/reshape_remove_pass.cc" | |||
"graph/passes/resource_pair_add_control_pass.cc" | |||
"graph/passes/resource_pair_remove_control_pass.cc" | |||
"graph/passes/same_transdata_breadth_fusion_pass.cc" | |||
"graph/passes/save_pass.cc" | |||
"graph/passes/shape_operate_op_remove_pass.cc" | |||
"graph/passes/snapshot_pass.cc" | |||
"graph/passes/stop_gradient_pass.cc" | |||
"graph/passes/switch_logic_remove_pass.cc" | |||
"graph/passes/switch_op_pass.cc" | |||
"graph/passes/switch_pass.cc" | |||
"graph/passes/transop_breadth_fusion_pass.cc" | |||
"graph/passes/transop_depth_fusion_pass.cc" | |||
"graph/passes/transop_nearby_allreduce_fusion_pass.cc" | |||
"graph/passes/transop_without_reshape_fusion_pass.cc" | |||
"graph/passes/transpose_transdata_pass.cc" | |||
"graph/passes/unused_const_pass.cc" | |||
"graph/passes/unused_op_remove_pass.cc" | |||
"graph/passes/var_is_initialized_op_pass.cc" | |||
"graph/passes/variable_format_pass.cc" | |||
"graph/passes/variable_op_pass.cc" | |||
"graph/passes/variable_prepare_op_pass.cc" | |||
"graph/passes/variable_ref_delete_op_pass.cc" | |||
"graph/preprocess/graph_preprocess.cc" | |||
"graph/preprocess/insert_op/ge_aipp_op.cc" | |||
"graph/preprocess/insert_op/util_insert_aipp_op.cc" | |||
"graph/preprocess/multi_batch_copy_graph.cc" | |||
"init/gelib.cc" | |||
"ir_build/atc_ir_common.cc" | |||
"ir_build/ge_ir_build.cc" | |||
"model/ge_model.cc" | |||
"omm/csa_interact.cc" | |||
"opskernel_manager/ops_kernel_manager.cc" | |||
"session/inner_session.cc" | |||
"session/session_manager.cc" | |||
"single_op/single_op.cc" | |||
"single_op/single_op_manager.cc" | |||
"single_op/single_op_model.cc" | |||
"single_op/stream_resource.cc" | |||
"single_op/task/build_task_utils.cc" | |||
"single_op/task/op_task.cc" | |||
"single_op/task/tbe_task_builder.cc" | |||
########################################## | |||
# "ir_build/ge_ir_build.cc" | |||
# "offline/atc_ir_common.cc" | |||
"single_op/*.cc" | |||
"single_op/task/*.cc" | |||
) | |||
add_library(ge_compiler SHARED ${INFER_SRC_LIST} ${PROTO_SRCS} ${PROTO_HEADER_HDRS}) | |||
target_compile_definitions(ge_compiler PRIVATE | |||
PROTOBUF_INLINE_NOT_IN_HEADERS=0 | |||
DAVINCI_SUPPORT_PROFILING | |||
REUSE_MEMORY=1 | |||
FMK_HOST_INFER | |||
PLATFORM_CLOUD) | |||
FMK_HOST_INFER) | |||
target_link_libraries(ge_compiler | |||
graph | |||
ge_common | |||
"-Wl,--whole-archive" | |||
ge_memory | |||
"-Wl,--no-whole-archive" | |||
${PROTOBUF_LIBRARY} | |||
${register} | |||
${c_sec} | |||
@@ -469,5 +300,6 @@ target_link_libraries(ge_compiler | |||
${msprof} | |||
${runtime} | |||
${resouce} | |||
${error_manager} | |||
rt | |||
dl) |
@@ -13,21 +13,21 @@ | |||
# limitations under the License. | |||
# ============================================================================ | |||
# libge_client.so & libge_client_train.so | |||
# libge_client.so | |||
# add all proto files, generate corresponding .h and .cc files | |||
set(CMAKE_CXX_FLAGS "-Wno-unused-variable ${CMAKE_CXX_FLAGS}") | |||
file(GLOB_RECURSE PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
file(GLOB PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
"../../proto/ge_api.proto" | |||
) | |||
file(GLOB_RECURSE PROTO_HEADER_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
file(GLOB PROTO_HEADER_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
"../../proto/ge_ir.proto" | |||
"../../proto/task.proto" | |||
"../../proto/om.proto" | |||
"../../proto/insert_op.proto" | |||
) | |||
file(GLOB_RECURSE SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
"ge_api.cc" | |||
) | |||
@@ -49,30 +49,9 @@ include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/cce) | |||
include_directories(${CMAKE_BINARY_DIR}) | |||
include_directories(${CMAKE_BINARY_DIR}/proto/ge) | |||
######### libge_client_train.so ############# | |||
add_library(ge_client_train SHARED ${SRC_LIST} ${PROTO_SRCS} ${PROTO_HEADER_HDRS}) | |||
target_compile_definitions(ge_client_train PRIVATE | |||
Werror | |||
PROTOBUF_INLINE_NOT_IN_HEADERS=0 | |||
REUSE_MEMORY=1 | |||
PLATFORM_CLOUD | |||
DAVINCI_CLOUD) | |||
target_link_libraries(ge_client_train | |||
graph | |||
ge_train | |||
ge_common | |||
${PROTOBUF_LIBRARY} | |||
${register} | |||
${c_sec} | |||
${slog} | |||
${mmpa} | |||
${runtime} | |||
rt | |||
dl) | |||
############ libge_client.so ################ | |||
add_library(ge_client SHARED ${SRC_LIST} ${PROTO_SRCS} ${PROTO_HEADER_HDRS}) | |||
target_compile_definitions(ge_client_train PRIVATE | |||
target_compile_definitions(ge_client PRIVATE | |||
Werror | |||
PROTOBUF_INLINE_NOT_IN_HEADERS=0 | |||
REUSE_MEMORY=1 | |||
@@ -32,17 +32,18 @@ | |||
using domi::GetContext; | |||
using domi::OpRegistry; | |||
using domi::RealPath; | |||
using domi::StringUtils; | |||
using std::map; | |||
using std::string; | |||
using std::vector; | |||
namespace ge { | |||
static const int32_t kMaxStrLen = 128; | |||
namespace { | |||
const int32_t kMaxStrLen = 128; | |||
} | |||
static bool kGeInitialized = false; | |||
static std::mutex kGeReleaseMutex; // GEFinalize and ~Session use | |||
namespace ge { | |||
void GetOpsProtoPath(std::string &opsproto_path) { | |||
GELOGI("Enter get ops proto path schedule"); | |||
const char *path_env = std::getenv("ASCEND_OPP_PATH"); | |||
@@ -394,8 +395,8 @@ Status Session::RegisterCallBackFunc(const std::string &key, const pCallBackFunc | |||
return ge::GELib::GetInstance()->SessionManagerObj().RegisterCallBackFunc(sessionId_, key, callback); | |||
} | |||
Status Session::RunGraphAsync(uint32_t graph_id, const std::vector<TensorInfo> &inputs, | |||
std::vector<TensorInfo> &outputs, std::function<void(Status)> callback) { | |||
Status Session::RunGraphAsync(uint32_t graph_id, const std::vector<InputTensorInfo> &inputs, | |||
RunAsyncCallback callback) { | |||
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||
if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { | |||
GELOGE(GE_CLI_GE_NOT_INITIALIZED, "SessionConstructor failed"); | |||
@@ -405,8 +406,7 @@ Status Session::RunGraphAsync(uint32_t graph_id, const std::vector<TensorInfo> & | |||
GELOGW( | |||
"The callback function will not be checked. Please ensure that the implementation of the function is trusted."); | |||
Status ret = | |||
ge::GELib::GetInstance()->SessionManagerObj().RunGraphAsync(sessionId_, graph_id, inputs, outputs, callback); | |||
Status ret = ge::GELib::GetInstance()->SessionManagerObj().RunGraphAsync(sessionId_, graph_id, inputs, callback); | |||
if (ret != SUCCESS) { | |||
GELOGE(ret, "SessionManager RunGraphAsync failed"); | |||
return FAILED; | |||
@@ -28,7 +28,6 @@ file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
"debug/memory_dumper.cc" | |||
"fmk_error_codes.cc" | |||
"formats/format_transfers/datatype_transfer.cc" | |||
"formats/format_transfers/format_transfer.cc" | |||
"formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc" | |||
"formats/format_transfers/format_transfer_dhwcn_fracz3D.cc" | |||
"formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc" | |||
@@ -41,6 +40,7 @@ file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
"formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc" | |||
"formats/format_transfers/format_transfer_nc1hwc0_nchw.cc" | |||
"formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc" | |||
"formats/format_transfers/format_transfer_nchw_fz_c04.cc" | |||
"formats/format_transfers/format_transfer_nchw_nc1hwc0.cc" | |||
"formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc" | |||
"formats/format_transfers/format_transfer_transpose.cc" | |||
@@ -54,6 +54,7 @@ file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
"helper/om_file_helper.cc" | |||
"math/fp16_math.cc" | |||
"model_parser/base.cc" | |||
"model_saver.cc" | |||
"op/attr_value_util.cc" | |||
"op/ge_op_utils.cc" | |||
"properties_manager.cc" | |||
@@ -61,9 +62,6 @@ file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
"thread_pool.cc" | |||
"types.cc" | |||
"util.cc" | |||
"model_saver.cc" | |||
############################### | |||
"op/attr_define.cc" | |||
) | |||
ge_protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | |||
@@ -73,6 +71,7 @@ include_directories(${CMAKE_CURRENT_LIST_DIR}) | |||
include_directories(${CMAKE_CURRENT_LIST_DIR}/op) | |||
include_directories(${GE_SOURCE_DIR}/src/ge) | |||
include_directories(${GE_SOURCE_DIR}/inc) | |||
include_directories(${GE_SOURCE_DIR}/inc/common/util) | |||
include_directories(${GE_SOURCE_DIR}/inc/external) | |||
include_directories(${GE_SOURCE_DIR}/inc/external/graph) | |||
include_directories(${GE_SOURCE_DIR}/inc/framework) | |||
@@ -96,5 +95,6 @@ target_link_libraries(ge_common | |||
${slog} | |||
${mmpa} | |||
${resource} | |||
${error_manager} | |||
rt | |||
dl) |
@@ -17,7 +17,6 @@ | |||
#include "common/auth/file_saver.h" | |||
#include <fcntl.h> | |||
#include <securec.h> | |||
#include <unistd.h> | |||
#include <cstdlib> | |||
@@ -29,10 +28,6 @@ | |||
#include "framework/common/debug/log.h" | |||
#include "framework/common/util.h" | |||
using domi::CreateDirectory; | |||
using domi::ModelEncryptType; | |||
using ge::ModelBufferData; | |||
namespace { | |||
const int kFileOpSuccess = 0; | |||
} // namespace | |||
@@ -270,4 +265,4 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status FileSaver::SaveToFile(co | |||
} | |||
return ret; | |||
} | |||
} // namespace ge | |||
} // namespace ge |
@@ -26,30 +26,26 @@ | |||
#include "graph/buffer.h" | |||
#include "mmpa/mmpa_api.h" | |||
using domi::ModelFileHeader; | |||
using domi::ModelPartition; | |||
using domi::ModelPartitionTable; | |||
struct PROC_PARAM { | |||
uint8_t *model_name; | |||
/* ISV Ek buffer */ | |||
// ISV Ek buffer | |||
uint8_t *model_key; | |||
uint32_t model_key_len; | |||
/* ISV root certificate buffer */ | |||
// ISV root certificate buffer | |||
uint8_t *root_cert; | |||
uint32_t root_cert_len; | |||
/* ISV private key buffer */ | |||
// ISV private key buffer | |||
uint8_t *pri_key; | |||
uint32_t pri_key_len; | |||
/* Raw AI Module Image buffer */ | |||
// Raw AI Module Image buffer | |||
uint8_t *ai_image; | |||
uint32_t ai_image_len; | |||
/* ISV HW key buffer */ | |||
// ISV HW key buffer | |||
uint8_t *hw_key; | |||
uint32_t hw_key_len; | |||
}; | |||
@@ -66,11 +62,11 @@ using std::string; | |||
class FileSaver { | |||
public: | |||
/** | |||
* @ingroup domi_common | |||
* @brief save model, no encryption | |||
* @return Status result | |||
*/ | |||
/// | |||
/// @ingroup domi_common | |||
/// @brief save model, no encryption | |||
/// @return Status result | |||
/// | |||
static Status SaveToFile(const string &file_path, const ge::ModelData &model, | |||
const ModelFileHeader *model_file_header = nullptr); | |||
@@ -84,26 +80,26 @@ class FileSaver { | |||
static Status SaveToFile(const string &file_path, const void *data, int len); | |||
protected: | |||
/** | |||
* @ingroup domi_common | |||
* @brief Check validity of the file path | |||
* @return Status result | |||
*/ | |||
/// | |||
/// @ingroup domi_common | |||
/// @brief Check validity of the file path | |||
/// @return Status result | |||
/// | |||
static Status CheckPath(const string &file_path); | |||
static Status WriteData(const void *data, uint32_t size, int32_t fd); | |||
static Status OpenFile(int32_t &fd, const std::string &file_path); | |||
/** | |||
* @ingroup domi_common | |||
* @brief save model to file | |||
* @param [in] file_path file output path | |||
* @param [in] file_header file header info | |||
* @param [in] data model data | |||
* @param [in] len model length | |||
* @return Status result | |||
*/ | |||
/// | |||
/// @ingroup domi_common | |||
/// @brief save model to file | |||
/// @param [in] file_path file output path | |||
/// @param [in] file_header file header info | |||
/// @param [in] data model data | |||
/// @param [in] len model length | |||
/// @return Status result | |||
/// | |||
static Status SaveWithFileHeader(const string &file_path, const ModelFileHeader &file_header, const void *data, | |||
int len); | |||
@@ -16,6 +16,7 @@ | |||
#include "framework/omg/omg_inner_types.h" | |||
using ge::OmgContext; | |||
namespace domi { | |||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OmgContext &GetContext() { | |||
static OmgContext context; | |||
@@ -155,7 +155,7 @@ string Pb2Json::TypeBytes2String(string &field_name, string &type_bytes) { | |||
void Pb2Json::RepeatedMessage2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, | |||
const ProtobufReflection *reflection, const set<string> &black_fields, Json &json, | |||
bool enum2str) { | |||
if (nullptr == field || nullptr == reflection) { | |||
if ((field == nullptr) || (reflection == nullptr)) { | |||
Message2Json(message, black_fields, json, enum2str); | |||
return; | |||
} | |||
@@ -28,7 +28,9 @@ | |||
using std::string; | |||
static const int kInvalidFd = (-1); | |||
namespace { | |||
const int kInvalidFd = (-1); | |||
} // namespace | |||
namespace ge { | |||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY MemoryDumper::MemoryDumper() : fd_(kInvalidFd) {} | |||
@@ -16,7 +16,7 @@ | |||
#include "common/formats/format_transfers/datatype_transfer.h" | |||
#include <stdint.h> | |||
#include <cstdint> | |||
#include <map> | |||
#include <utility> | |||
@@ -27,8 +27,6 @@ | |||
#include "graph/utils/type_utils.h" | |||
#include "securec.h" | |||
using ge::fp16_t; | |||
namespace ge { | |||
namespace formats { | |||
@@ -134,10 +132,6 @@ Status DataTypeTransfer::TransDataType(const CastArgs &args, TransResult &result | |||
} | |||
auto trans_mode = iter->second; | |||
if (args.src_data_size == 0) { | |||
GELOGE(PARAM_INVALID, "Invalid src data size %zu", args.src_data_size); | |||
return PARAM_INVALID; | |||
} | |||
int size = GetSizeByDataType(args.dst_data_type); | |||
if (size <= 0) { | |||
GELOGE(PARAM_INVALID, "Failed to calc size from data type %s", | |||
@@ -149,6 +143,12 @@ Status DataTypeTransfer::TransDataType(const CastArgs &args, TransResult &result | |||
return PARAM_INVALID; | |||
} | |||
size_t total_size = static_cast<size_t>(args.src_data_size * size); | |||
result.length = total_size; | |||
if (total_size == 0) { | |||
GELOGI("In TransDataType, total_size is zero, has no data."); | |||
return SUCCESS; | |||
} | |||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); | |||
if (dst == nullptr) { | |||
GELOGE(OUT_OF_MEMORY, "Failed to alloc the memory for dst buf %zu, data size %zu", total_size, args.src_data_size); | |||
@@ -162,7 +162,6 @@ Status DataTypeTransfer::TransDataType(const CastArgs &args, TransResult &result | |||
return INTERNAL_ERROR; | |||
} | |||
result.data = dst; | |||
result.length = total_size; | |||
return SUCCESS; | |||
} | |||
@@ -21,7 +21,7 @@ | |||
#include <memory> | |||
#include <vector> | |||
#include "common/formats/format_transfers/format_transfer.h" | |||
#include "register/register_format_transfer.h" | |||
#include "external/graph/types.h" | |||
#include "framework/common/ge_inner_error_codes.h" | |||
@@ -1,69 +0,0 @@ | |||
/** | |||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include "common/formats/format_transfers/format_transfer.h" | |||
#include <map> | |||
#include <utility> | |||
#include "framework/common/debug/ge_log.h" | |||
#include "graph/utils/type_utils.h" | |||
namespace ge { | |||
namespace formats { | |||
namespace { | |||
struct FormatTransferRegistry { | |||
Status RegisterBuilder(Format src, Format dst, FormatTransferBuilder builder) { | |||
src_dst_builder[src][dst] = std::move(builder); | |||
return SUCCESS; | |||
} | |||
std::map<Format, std::map<Format, FormatTransferBuilder>> src_dst_builder; | |||
}; | |||
FormatTransferRegistry &GetFormatTransferRegistry() { | |||
static FormatTransferRegistry registry; | |||
return registry; | |||
} | |||
} // namespace | |||
std::shared_ptr<FormatTransfer> BuildFormatTransfer(const TransArgs &args) { | |||
auto registry = GetFormatTransferRegistry(); | |||
auto dst_builder = registry.src_dst_builder.find(args.src_format); | |||
if (dst_builder == registry.src_dst_builder.end()) { | |||
return nullptr; | |||
} | |||
auto builder_iter = dst_builder->second.find(args.dst_format); | |||
if (builder_iter == dst_builder->second.end()) { | |||
return nullptr; | |||
} | |||
return builder_iter->second(); | |||
} | |||
bool FormatTransferExists(const TransArgs &args) { | |||
auto registry = GetFormatTransferRegistry(); | |||
auto dst_builder = registry.src_dst_builder.find(args.src_format); | |||
if (dst_builder == registry.src_dst_builder.end()) { | |||
return false; | |||
} | |||
return dst_builder->second.count(args.dst_format) > 0; | |||
} | |||
FormatTransferRegister::FormatTransferRegister(FormatTransferBuilder builder, Format src, Format dst) { | |||
(void)GetFormatTransferRegistry().RegisterBuilder(src, dst, std::move(builder)); | |||
// RegisterBuilder() always return success, no need to check value | |||
} | |||
} // namespace formats | |||
} // namespace ge |
@@ -27,7 +27,9 @@ | |||
namespace ge { | |||
namespace formats { | |||
namespace { | |||
bool CheckDataTypeSupported(const DataType &data_type) { return (data_type == DT_FLOAT || data_type == DT_FLOAT16); } | |||
bool CheckDataTypeSupported(const DataType &data_type) { | |||
return (data_type == DT_FLOAT || data_type == DT_FLOAT16 || data_type == DT_INT8); | |||
} | |||
Status CheckArgsForC1hwncoc0ToHwcn(const TransArgs &args) { | |||
auto src_shape = args.src_shape; | |||
@@ -51,10 +53,11 @@ Status CheckArgsForC1hwncoc0ToHwcn(const TransArgs &args) { | |||
GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); | |||
return PARAM_INVALID; | |||
} | |||
if (src_shape.at(kC1hwncoc0C1) != (dst_shape.at(kHwcnC) - 1) / kCubeSize + 1 || | |||
auto cube_size = GetCubeSizeByDataType(args.src_data_type); | |||
if (src_shape.at(kC1hwncoc0C1) != (dst_shape.at(kHwcnC) - 1) / cube_size + 1 || | |||
src_shape.at(kC1hwncoc0H) != dst_shape.at(kHwcnH) || src_shape.at(kC1hwncoc0W) != dst_shape.at(kHwcnW) || | |||
src_shape.at(kC1hwncoc0N) != dst_shape.at(kHwcnN) || src_shape.at(kC1hwncoc0Co) != kCubeSize || | |||
src_shape.at(kC1hwncoc0C0) != kCubeSize) { | |||
src_shape.at(kC1hwncoc0N) != dst_shape.at(kHwcnN) || src_shape.at(kC1hwncoc0Co) != cube_size || | |||
src_shape.at(kC1hwncoc0C0) != cube_size) { | |||
GELOGE(PARAM_INVALID, "Failed to check relationship between src and dst shape, src shape %s, dst shape %s", | |||
ShapeToString(src_shape).c_str(), ShapeToString(dst_shape).c_str()); | |||
return PARAM_INVALID; | |||
@@ -78,6 +81,7 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, int size | |||
auto c0 = args.src_shape.at(kC1hwncoc0C0); | |||
auto co = args.src_shape.at(kC1hwncoc0Co); | |||
auto c = args.dst_shape.at(kHwcnC); | |||
auto cube_size = GetCubeSizeByDataType(args.src_data_type); | |||
int64_t cn = c * n; | |||
int64_t wcn = w * cn; | |||
int64_t coc0 = co * c0; | |||
@@ -93,8 +97,8 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, int size | |||
int64_t c_head_addr = w_head_addr + c_idx * n; | |||
for (int64_t n_idx = 0; n_idx < n; n_idx++) { | |||
int64_t dst_idx = c_head_addr + n_idx; | |||
int64_t c1_idx = c_idx / kCubeSize; | |||
int64_t c0_idx = c_idx % kCubeSize; | |||
int64_t c1_idx = c_idx / cube_size; | |||
int64_t c0_idx = c_idx % cube_size; | |||
int64_t co_idx = c0_idx; | |||
int64_t src_idx = c1_idx * hwncoc0 + h_idx * wncoc0 + w_idx * ncoc0 + n_idx * coc0 + co_idx * c0 + c0_idx; | |||
auto src_offset = src_idx * size; | |||
@@ -130,6 +134,11 @@ Status FormatTransferC1hwncoc0Hwcn::TransFormat(const TransArgs &args, TransResu | |||
int size = GetSizeByDataType(args.src_data_type); | |||
int64_t total_size = GetItemNumByShape(args.dst_shape) * size; | |||
if (total_size <= 0) { | |||
int64_t src_size = GetItemNumByShape(args.src_shape); | |||
if (total_size == 0 && src_size == 0) { | |||
result.length = static_cast<size_t>(total_size); | |||
return SUCCESS; | |||
} | |||
GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | |||
ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | |||
return PARAM_INVALID; | |||
@@ -19,7 +19,7 @@ | |||
#include <vector> | |||
#include "common/formats/format_transfers/format_transfer.h" | |||
#include "register/register_format_transfer.h" | |||
namespace ge { | |||
namespace formats { | |||
@@ -88,6 +88,11 @@ Status TransFormatDhwckToFz3D(const TransArgs &args, TransResult &result) { | |||
dst_size *= dim; | |||
} | |||
dst_size *= data_size; | |||
if (dst_size == 0) { | |||
result.length = static_cast<size_t>(dst_size); | |||
return SUCCESS; | |||
} | |||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | |||
if (dst == nullptr) { | |||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||
@@ -18,7 +18,7 @@ | |||
#define GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_DHWCN_FRACTAL_Z_3D_H_ | |||
#include <vector> | |||
#include "common/formats/format_transfers/format_transfer.h" | |||
#include "register/register_format_transfer.h" | |||
namespace ge { | |||
namespace formats { | |||
@@ -89,6 +89,11 @@ Status TransFormatDhwncToFz3DTranspose(const TransArgs &args, TransResult &resul | |||
dst_size *= dim; | |||
} | |||
dst_size *= data_size; | |||
if (dst_size == 0) { | |||
result.length = static_cast<size_t>(dst_size); | |||
return SUCCESS; | |||
} | |||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | |||
if (dst == nullptr) { | |||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||
@@ -18,7 +18,7 @@ | |||
#define GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_DHWNC_FRACTAL_Z_3D_TRANSPOSE_H_ | |||
#include <vector> | |||
#include "common/formats/format_transfers/format_transfer.h" | |||
#include "register/register_format_transfer.h" | |||
namespace ge { | |||
namespace formats { | |||
@@ -116,6 +116,11 @@ Status CheckShapeRelation(const TransArgs &args, ShapeVector &hw_shape) { | |||
Status TransFormatFromNdToFracNz(const TransArgs &args, TransResult &result, const ShapeVector &hw_shape) { | |||
int size = GetSizeByDataType(args.src_data_type); | |||
int64_t dst_size = GetItemNumByShape(args.dst_shape) * size; | |||
if (dst_size == 0) { | |||
result.length = static_cast<size_t>(dst_size); | |||
return SUCCESS; | |||
} | |||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size](), std::default_delete<uint8_t[]>()); | |||
if (dst == nullptr) { | |||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||
@@ -184,6 +189,11 @@ Status TransFormatFromNdToFracNz(const TransArgs &args, TransResult &result, con | |||
Status TransFormatFromFracNzToNd(const TransArgs &args, TransResult &result, const ShapeVector &dst_hw_shape) { | |||
int size = GetSizeByDataType(args.src_data_type); | |||
int64_t dst_size = GetItemNumByShape(args.dst_shape) * size; | |||
if (dst_size == 0) { | |||
result.length = static_cast<size_t>(dst_size); | |||
return SUCCESS; | |||
} | |||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | |||
if (dst == nullptr) { | |||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||
@@ -19,7 +19,7 @@ | |||
#include <vector> | |||
#include "common/formats/format_transfers/format_transfer.h" | |||
#include "register/register_format_transfer.h" | |||
namespace ge { | |||
namespace formats { | |||
@@ -119,6 +119,11 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) { | |||
int64_t total_ele_cnt = hf_cnt * vf_cnt * fractal_ele_cnt; | |||
int size = GetSizeByDataType(args.src_data_type); | |||
int64_t dst_size = total_ele_cnt * size; | |||
if (dst_size == 0) { | |||
result.length = static_cast<size_t>(dst_size); | |||
return SUCCESS; | |||
} | |||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | |||
if (dst == nullptr) { | |||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||
@@ -194,6 +199,11 @@ Status TransFormatHwcnToFz(const TransArgs &args, TransResult &result) { | |||
dst_size *= dim; | |||
} | |||
dst_size *= data_size; | |||
if (dst_size == 0) { | |||
result.length = static_cast<size_t>(dst_size); | |||
return SUCCESS; | |||
} | |||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | |||
if (dst == nullptr) { | |||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||
@@ -259,6 +269,11 @@ Status TransFormatNhwcToFz(const TransArgs &args, TransResult &result) { | |||
dst_size *= dim; | |||
} | |||
dst_size *= data_size; | |||
if (dst_size == 0) { | |||
result.length = static_cast<size_t>(dst_size); | |||
return SUCCESS; | |||
} | |||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | |||
if (dst == nullptr) { | |||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||
@@ -19,7 +19,7 @@ | |||
#include <vector> | |||
#include "common/formats/format_transfers/format_transfer.h" | |||
#include "register/register_format_transfer.h" | |||
namespace ge { | |||
namespace formats { | |||
@@ -117,6 +117,11 @@ Status CheckShapeRelation(const TransArgs &args, ShapeVector &hw_shape) { | |||
Status TransFormatFromNdToFracZz(const TransArgs &args, TransResult &result, const ShapeVector &hw_shape) { | |||
int size = GetSizeByDataType(args.src_data_type); | |||
int64_t dst_size = GetItemNumByShape(args.dst_shape) * size; | |||
if (dst_size == 0) { | |||
result.length = static_cast<size_t>(dst_size); | |||
return SUCCESS; | |||
} | |||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size](), std::default_delete<uint8_t[]>()); | |||
if (dst == nullptr) { | |||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||
@@ -189,6 +194,11 @@ Status TransFormatFromNdToFracZz(const TransArgs &args, TransResult &result, con | |||
Status TransFormatFromFracZzToNd(const TransArgs &args, TransResult &result, const ShapeVector &dst_hw_shape) { | |||
int size = GetSizeByDataType(args.src_data_type); | |||
int64_t dst_size = GetItemNumByShape(args.dst_shape) * size; | |||
if (dst_size == 0) { | |||
result.length = static_cast<size_t>(dst_size); | |||
return SUCCESS; | |||
} | |||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size](), std::default_delete<uint8_t[]>()); | |||
if (dst == nullptr) { | |||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||
@@ -19,7 +19,7 @@ | |||
#include <vector> | |||
#include "common/formats/format_transfers/format_transfer.h" | |||
#include "register/register_format_transfer.h" | |||
namespace ge { | |||
namespace formats { | |||
@@ -133,6 +133,12 @@ Status FormatTransferFracZHwcn::TransFormat(const TransArgs &args, TransResult & | |||
int size = GetSizeByDataType(args.src_data_type); | |||
auto total_size = GetItemNumByShape(args.dst_shape) * size; | |||
if (total_size <= 0) { | |||
int64_t src_size = GetItemNumByShape(args.src_shape); | |||
if (total_size == 0 && src_size == 0) { | |||
result.length = static_cast<size_t>(total_size); | |||
return SUCCESS; | |||
} | |||
GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | |||
ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | |||
return PARAM_INVALID; | |||
@@ -19,7 +19,7 @@ | |||
#include <vector> | |||
#include "common/formats/format_transfers/format_transfer.h" | |||
#include "register/register_format_transfer.h" | |||
namespace ge { | |||
namespace formats { | |||
@@ -133,6 +133,12 @@ Status FormatTransferFracZNchw::TransFormat(const TransArgs &args, TransResult & | |||
int size = GetSizeByDataType(args.src_data_type); | |||
auto total_size = GetItemNumByShape(args.dst_shape) * size; | |||
if (total_size <= 0) { | |||
int64_t src_size = GetItemNumByShape(args.src_shape); | |||
if (total_size == 0 && src_size == 0) { | |||
result.length = static_cast<size_t>(total_size); | |||
return SUCCESS; | |||
} | |||
GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | |||
ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | |||
return PARAM_INVALID; | |||
@@ -140,6 +146,7 @@ Status FormatTransferFracZNchw::TransFormat(const TransArgs &args, TransResult & | |||
GELOGD("Begin to trans format from FracZ to NCHW, src shape %s, data type %s, dst shape %s, memory size %ld", | |||
ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||
ShapeToString(args.dst_shape).c_str(), total_size); | |||
if (GetDstDataAfterTrans(args, result, size, total_size) != SUCCESS) { | |||
GELOGE(INTERNAL_ERROR, "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld", | |||
ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||
@@ -19,7 +19,7 @@ | |||
#include <vector> | |||
#include "common/formats/format_transfers/format_transfer.h" | |||
#include "register/register_format_transfer.h" | |||
namespace ge { | |||
namespace formats { | |||