Browse Source

Merge branch 'master' of gitee.com:mindspore/graphengine into baby3

tags/v1.3.0
gengchao4@huawei.com 3 years ago
parent
commit
57f5c7f2f4
14 changed files with 185 additions and 92 deletions
  1. +6
    -0
      build.sh
  2. +2
    -1
      ge/CMakeLists.txt
  3. +40
    -1
      ge/executor/CMakeLists.txt
  4. +1
    -5
      ge/graph/load/model_manager/model_manager.cc
  5. +10
    -0
      ge/graph/manager/graph_manager.cc
  6. +4
    -1
      ge/graph/manager/graph_manager_utils.h
  7. +10
    -0
      ge/graph/passes/link_gen_mask_nodes_pass.cc
  8. +2
    -4
      ge/hybrid/executor/hybrid_model_async_executor.cc
  9. +1
    -1
      ge/ir_build/ge_ir_build.cc
  10. +6
    -4
      ge/model/ge_root_model.h
  11. +99
    -71
      ge/session/omg.cc
  12. +1
    -1
      metadef
  13. +1
    -1
      parser
  14. +2
    -2
      tests/ut/ge/graph/passes/link_gen_mask_nodes_pass_unittest.cc

+ 6
- 0
build.sh View File

@@ -250,6 +250,7 @@ generate_package()
NNENGINE_PATH="plugin/nnengine/ge_config"
OPSKERNEL_PATH="plugin/opskernel"

ACL_LIB=("libgraph.so")
ATC_LIB=("libc_sec.so" "libge_common.so" "libge_compiler.so" "libgraph.so" "libregister.so" "liberror_manager.so")
FWK_LIB=("libge_common.so" "libge_runner.so" "libgraph.so" "libregister.so" "liberror_manager.so")
PLUGIN_OPSKERNEL=("libge_local_engine.so" "libge_local_opskernel_builder.so" "libhost_cpu_engine.so" "libhost_cpu_opskernel_builder.so" "optimizer_priority.pbtxt")
@@ -303,6 +304,11 @@ generate_package()
find ${OUTPUT_PATH}/${GRAPHENGINE_LIB_PATH} -maxdepth 1 -name "$lib" -exec cp -f {} ${OUTPUT_PATH}/${FWK_PATH} \;
done

for lib in "${ACL_LIB[@]}";
do
find ${OUTPUT_PATH}/${GRAPHENGINE_LIB_PATH} -maxdepth 1 -name "$lib" -exec cp -f {} ${OUTPUT_PATH}/${ACL_PATH} \;
done

for lib in "${ATC_LIB[@]}";
do
find ${OUTPUT_PATH}/${GRAPHENGINE_LIB_PATH} -maxdepth 1 -name "$lib" -exec cp -f {} ${OUTPUT_PATH}/${ATC_PATH} \;


+ 2
- 1
ge/CMakeLists.txt View File

@@ -11,6 +11,8 @@ elseif (ENABLE_D)
endif ()

set(PROTO_LIST
"${METADEF_DIR}/proto/dump_task.proto"
"${METADEF_DIR}/proto/op_mapping_info.proto"
"${METADEF_DIR}/proto/fusion_model.proto"
"${GE_CODE_DIR}/ge/proto/optimizer_priority.proto"
)
@@ -25,7 +27,6 @@ set(PROTO_HEADER_LIST
"${METADEF_DIR}/proto/insert_op.proto"
"${METADEF_DIR}/proto/ge_ir.proto"
"${METADEF_DIR}/proto/fwk_adapter.proto"
"${METADEF_DIR}/proto/op_mapping_info.proto"
)

protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST})


+ 40
- 1
ge/executor/CMakeLists.txt View File

@@ -3,12 +3,46 @@ set(PROTO_LIST
"${METADEF_DIR}/proto/ge_ir.proto"
"${METADEF_DIR}/proto/insert_op.proto"
"${METADEF_DIR}/proto/task.proto"
"${METADEF_DIR}/proto/op_mapping_info.proto"
)

set(DUMP_PROTO_LIST
"${METADEF_DIR}/proto/dump_task.proto"
"${METADEF_DIR}/proto/op_mapping_info.proto"
)

protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST})
protobuf_generate(ge_static PROTO_STATIC_SRCS PROTO_STATIC_HDRS ${PROTO_LIST})
protobuf_generate(ge_dump DUMP_PROTO_SRCS DUMP_PROTO_HDRS ${DUMP_PROTO_LIST})

############ libge_proto_dump.a ############
add_library(ge_proto_dump STATIC
${DUMP_PROTO_SRCS}
)

target_compile_definitions(ge_proto_dump PRIVATE
PROTOBUF_INLINE_NOT_IN_HEADERS=0
google=ascend_private
$<IF:$<STREQUAL:${TARGET_SYSTEM_NAME},Windows>,OS_TYPE=WIN,OS_TYPE=0>
$<$<STREQUAL:${TARGET_SYSTEM_NAME},Windows>:SECUREC_USING_STD_SECURE_LIB=0 NOMINMAX>
)

target_compile_options(ge_proto_dump PRIVATE
-O2
-fno-common
-fvisibility=hidden
$<$<AND:$<STREQUAL:${TARGET_SYSTEM_NAME},Windows>,$<STREQUAL:${CMAKE_CONFIGURATION_TYPES},Debug>>:/MTd>
$<$<AND:$<STREQUAL:${TARGET_SYSTEM_NAME},Windows>,$<STREQUAL:${CMAKE_CONFIGURATION_TYPES},Release>>:/MT>
)

target_link_libraries(ge_proto_dump PRIVATE
$<BUILD_INTERFACE:intf_pub>
ascend_protobuf_static
)

target_link_options(ge_proto_dump PRIVATE
-Wl,-Bsymbolic
)
##################################################################

set(SRC_LIST
"ge_executor.cc"
@@ -194,6 +228,7 @@ target_include_directories(ge_executor SYSTEM PRIVATE
${METADEF_DIR}/inc/external/graph
${METADEF_DIR}/inc/graph
${CMAKE_BINARY_DIR}
${CMAKE_BINARY_DIR}/proto/ge_dump
${CMAKE_BINARY_DIR}/proto/ge_static
#### yellow zone ####
${GE_CODE_DIR}/../inc
@@ -206,6 +241,7 @@ target_link_libraries(ge_executor PRIVATE
$<BUILD_INTERFACE:intf_pub>
json
ascend_protobuf_static
ge_proto_dump
c_sec
$<$<NOT:$<STREQUAL:${TARGET_SYSTEM_NAME},Android>>:-lrt>
-ldl
@@ -240,6 +276,7 @@ target_include_directories(ge_executor_shared PRIVATE
${METADEF_DIR}/inc/external/graph
${METADEF_DIR}/inc/graph
${CMAKE_BINARY_DIR}
${CMAKE_BINARY_DIR}/proto/ge_dump
${CMAKE_BINARY_DIR}/proto/ge
#### yellow zone ####
${GE_CODE_DIR}/../inc
@@ -250,12 +287,14 @@ target_include_directories(ge_executor_shared PRIVATE

target_link_options(ge_executor_shared PRIVATE
-Wl,-Bsymbolic
-Wl,--exclude-libs,ALL
)

target_link_libraries(ge_executor_shared PRIVATE
$<BUILD_INTERFACE:intf_pub>
msprofiler
static_mmpa
ge_proto_dump
-Wl,--no-as-needed
ge_common
runtime


+ 1
- 5
ge/graph/load/model_manager/model_manager.cc View File

@@ -341,11 +341,7 @@ Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptr<ge::Ge

mmTimespec timespec = mmGetTickCount();
std::shared_ptr<DavinciModel> davinci_model = MakeShared<DavinciModel>(0, listener);
if (davinci_model == nullptr) {
REPORT_CALL_ERROR("E19999", "New DavinciModel fail, model_id:%u", model_id);
GELOGE(FAILED, "davinci_model is nullptr");
return FAILED;
}
GE_CHECK_NOTNULL(davinci_model);
davinci_model->SetProfileTime(MODEL_LOAD_START, (timespec.tv_sec * kTimeSpecNano +
timespec.tv_nsec)); // 1000 ^ 3 converts second to nanosecond
davinci_model->SetId(model_id);


+ 10
- 0
ge/graph/manager/graph_manager.cc View File

@@ -125,6 +125,7 @@ const uint32_t kInitGraphCount = 1;
const uint32_t kNotAdded = 0;
const uint32_t kStartAdd = 1;
const uint32_t kDoneAdded = 2;
const uint32_t kNeverLoaded = 0;

bool IsTailingOptimization() {
string is_tailing_optimization_option;
@@ -2750,6 +2751,15 @@ void GraphManager::ReleaseMemory(const GeModelPtr &ge_model, GraphNodePtr &graph
GELOGI("CheckAndReleaseMemory UnloadGraph[%u], model[%u] success.", graph_id, model_id);
}
graph_node->SetLoadFlag(false);
// Allow model to be loaded agagin without adding graph again
graph_node->SetLoadCount(graph_node->GetLoadRecord());
graph_node->SetLoadRecord(kNeverLoaded);
GeRootModelPtr ge_root_model = graph_node->GetGeRootModel();
if (ge_root_model == nullptr) {
GELOGW("ge_root_model is null, graph_id:%u", graph_id);
return;
}
ge_root_model->ClearAllModelId();
rt_ret = rtDeviceReset(GetContext().DeviceId());
if (rt_ret != RT_ERROR_NONE) {
REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u, when GraphManager %s",


+ 4
- 1
ge/graph/manager/graph_manager_utils.h View File

@@ -178,9 +178,12 @@ class GraphNode {
void SetSemSize(uint32_t size) { sem_.SetMaxSize(size); }

uint32_t GetLoadCount() const { return load_count_; }
void SetLoadCount(uint32_t count) { load_count_ = count; }
uint32_t GetLoadRecord() const { return load_record_; }
void SetLoadRecord(uint32_t record) { load_record_ = record; }
void IncreaseLoadRecord() { ++load_record_; }
void IncreaseLoadCount();
void DecreaseLoadCount() { --load_count_; }
void IncreaseLoadRecord() { ++load_record_; }

// run graph asynchronous listener
std::shared_ptr<RunAsyncListener> graph_run_async_listener_;


+ 10
- 0
ge/graph/passes/link_gen_mask_nodes_pass.cc View File

@@ -107,6 +107,16 @@ void LinkGenMaskNodesPass::GetAllGenMaskNodes(ComputeGraphPtr graph, vector<Node
auto in_data_nodes = node->GetInDataNodes();
if (in_data_nodes.size() > kGenMaskInputIndex) {
NodePtr &gen_mask = in_data_nodes.at(kGenMaskInputIndex);
for (auto &in_data_node : in_data_nodes) {
// node gen_mask is located at different place in the fused node
if (in_data_node->GetName().find(DROPOUTGENMASK) != in_data_node->GetName().npos) {
gen_mask = in_data_node;
GELOGD("The fused node type [%s], paired with the input node name [%s].",
node->GetType().c_str(), gen_mask->GetName().c_str());
break;
}
}

if ((gen_mask->GetOpDesc() == nullptr) || (gen_mask->GetOpDesc()->HasAttr(ATTR_NAME_STREAM_LABEL))) {
continue;
}


+ 2
- 4
ge/hybrid/executor/hybrid_model_async_executor.cc View File

@@ -150,10 +150,8 @@ Status HybridModelAsyncExecutor::RunInternal() {
Status ret = data_inputer_->Pop(data_wrapper);
// Model indeedly start running
SetRunningFlag(true);
if (data_wrapper == nullptr || ret != SUCCESS) {
GELOGI("data_wrapper is null!, ret = %u", ret);
continue;
}
GE_IF_BOOL_EXEC(data_wrapper == nullptr || ret != SUCCESS, GELOGI("data_wrapper is null!, ret = %u", ret);
continue);

GELOGI("Getting the input data, model_id:%u", model_id_);
GE_IF_BOOL_EXEC(!run_flag_, break);


+ 1
- 1
ge/ir_build/ge_ir_build.cc View File

@@ -574,7 +574,7 @@ graphStatus Impl::InitDomiOmgContext(const string &input_shape, const string &in
}

if (!ParseInputShape(input_shape, omg_context_.input_dims, omg_context_.user_input_dims, is_dynamic_input)) {
GELOGE(GRAPH_PARAM_INVALID, "[Parse][InputShape:input_shape] Failed, shape: %s", input_shape.c_str());
GELOGE(GRAPH_PARAM_INVALID, "[Parse][InputShape:InputShape] Failed, shape: %s", input_shape.c_str());
return GRAPH_PARAM_INVALID;
}
return GRAPH_SUCCESS;


+ 6
- 4
ge/model/ge_root_model.h View File

@@ -40,12 +40,14 @@ class GeRootModel {
}
uint32_t GetModelId() const { return model_id_; }

std::vector<uint32_t> GetAllModelId() const { return model_ids_; }

void SetModelName(const std::string &model_name) { model_name_ = model_name; }

const std::string &GetModelName() const { return model_name_; }

std::vector<uint32_t> GetAllModelId() const { return model_ids_; }

void ClearAllModelId() { model_ids_.clear(); }

Status CheckIsUnknownShape(bool &is_dynamic_shape);

void SetRootGraph(ComputeGraphPtr graph) { root_graph_ = graph; }


+ 99
- 71
ge/session/omg.cc View File

@@ -86,7 +86,8 @@ static bool CheckInputTrueOrFalse(const std::string &s, const std::string &atc_p
return true;
} else {
ErrorManager::GetInstance().ATCReportErrMessage("E10005", {"parameter", "value"}, {atc_param, s});
GELOGE(PARAM_INVALID, "Input parameter[--%s]'s value[%s] must be true or false.", atc_param.c_str(), s.c_str());
GELOGE(PARAM_INVALID, "[Check][Param]Input parameter[--%s]'s value[%s] must be true or false.",
atc_param.c_str(), s.c_str());
return false;
}
}
@@ -110,9 +111,8 @@ static Status CheckInputShapeNode(const ComputeGraphPtr &graph, bool is_dynamic_
GE_CHECK_NOTNULL(tensor_desc);
for (auto dim : tensor_desc->GetShape().GetDims()) {
if (dim < 0) {
GELOGE(PARAM_INVALID,
"Input op [%s] shape %ld is negative, maybe you should set input_shape to specify its shape",
node->GetName().c_str(), dim);
GELOGE(PARAM_INVALID, "[Check][Param]Input op [%s] shape %ld is negative, "
"maybe you should set input_shape to specify its shape", node->GetName().c_str(), dim);
const string reason = "maybe you should set input_shape to specify its shape";
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
{node->GetName(), to_string(dim), reason});
@@ -127,12 +127,14 @@ static Status CheckInputShapeNode(const ComputeGraphPtr &graph, bool is_dynamic_
ge::NodePtr node = graph->FindNode(node_name);
if (node == nullptr) {
ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, {"input_shape", node_name});
GELOGE(PARAM_INVALID, "Input parameter[--input_shape]'s opname[%s] is not exist in model", node_name.c_str());
GELOGE(PARAM_INVALID, "[Check][Param]Input parameter[--input_shape]'s opname[%s] is not exist in model",
node_name.c_str());
return PARAM_INVALID;
}
if (node->GetType() != DATA) {
ErrorManager::GetInstance().ATCReportErrMessage("E10017", {"parameter", "opname"}, {"input_shape", node_name});
GELOGE(PARAM_INVALID, "Input parameter[--input_shape]'s opname[%s] is not a input opname", node_name.c_str());
GELOGE(PARAM_INVALID, "[Check][Param]Input parameter[--input_shape]'s opname[%s] is not a input opname",
node_name.c_str());
return PARAM_INVALID;
}
}
@@ -160,8 +162,8 @@ static Status CheckInputFp16Nodes(const ComputeGraphPtr &graph, const string &in
for (auto &s : adjust_fp16_format_vec) {
StringUtils::Trim(s);
if (!CheckInputTrueOrFalse(s, "is_input_adjust_hw_layout")) {
GELOGE(PARAM_INVALID, "Invalid Param, is_input_adjust_hw_layout only support true/false: but is [%s]",
is_input_adjust_hw_layout.c_str());
GELOGE(PARAM_INVALID, "[Check][Param]Invalid Param, is_input_adjust_hw_layout only support true/false:"
"but is [%s]", is_input_adjust_hw_layout.c_str());
return PARAM_INVALID;
}
}
@@ -176,7 +178,7 @@ static Status CheckInputFp16Nodes(const ComputeGraphPtr &graph, const string &in
if (node == nullptr) {
ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"},
{"input_fp16_nodes", input_fp16_nodes_vec[i]});
GELOGE(PARAM_INVALID, "Input parameter[--input_fp16_nodes]'s opname[%s] is not exist in model",
GELOGE(PARAM_INVALID, "[Check][Param]Input parameter[--input_fp16_nodes]'s opname[%s] is not exist in model",
input_fp16_nodes_vec[i].c_str());
return PARAM_INVALID;
}
@@ -185,7 +187,7 @@ static Status CheckInputFp16Nodes(const ComputeGraphPtr &graph, const string &in
if (op_desc->GetType() != DATA) {
ErrorManager::GetInstance().ATCReportErrMessage("E10017", {"parameter", "opname"},
{"input_fp16_nodes", input_fp16_nodes_vec[i]});
GELOGE(PARAM_INVALID, "Input parameter[--input_fp16_nodes]'s opname[%s] is not a input opname",
GELOGE(PARAM_INVALID, "[Check][Param]Input parameter[--input_fp16_nodes]'s opname[%s] is not a input opname",
input_fp16_nodes_vec[i].c_str());
return PARAM_INVALID;
}
@@ -205,8 +207,8 @@ static Status ParseOutputFp16NodesFormat(const string &is_output_fp16) {
for (auto &is_fp16 : node_format_vec) {
StringUtils::Trim(is_fp16);
if (!CheckInputTrueOrFalse(is_fp16, "is_output_adjust_hw_layout")) {
GELOGE(PARAM_INVALID, "Invalid Param, is_output_adjust_hw_layout only support true/false: but is [%s]",
is_output_fp16.c_str());
GELOGE(PARAM_INVALID, "[Check][Param]Invalid Param, is_output_adjust_hw_layout "
"only support true/false: but is [%s]", is_output_fp16.c_str());
return PARAM_INVALID;
}
if (is_fp16 == "false") {
@@ -263,7 +265,8 @@ void FindParserSo(const string &path, vector<string> &file_list, string &caffe_p

Status SetOutFormatAndDataTypeAttr(ge::OpDescPtr op_desc, const ge::Format format, const ge::DataType data_type) {
if (op_desc == nullptr) {
GELOGE(domi::FAILED, "Input op desc invalid.");
REPORT_INNER_ERROR("E19999", "param op_desc is nullptr, check invalid.");
GELOGE(domi::FAILED, "[Check][Param]Input op desc invalid.");
return domi::FAILED;
}
(void)ge::AttrUtils::SetInt(op_desc, ATTR_NAME_NET_OUTPUT_FORMAT, format);
@@ -274,7 +277,7 @@ Status SetOutFormatAndDataTypeAttr(ge::OpDescPtr op_desc, const ge::Format forma
bool CheckDigitStr(std::string &str) {
for (char c : str) {
if (!isdigit(c)) {
GELOGE(domi::FAILED, "value[%s] is not positive integer", str.c_str());
GELOGE(domi::FAILED, "[Check][Param]value[%s] is not positive integer", str.c_str());
return false;
}
}
@@ -284,18 +287,18 @@ bool CheckDigitStr(std::string &str) {
Status StringToInt(std::string &str, int32_t &value) {
try {
if (!CheckDigitStr(str)) {
GELOGE(PARAM_INVALID, "Invalid of digit string: %s ", str.c_str());
GELOGE(PARAM_INVALID, "[Check][Param]Invalid of digit string: %s ", str.c_str());
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
{"--output_type", str, "is not positive integer"});
return PARAM_INVALID;
}
value = stoi(str);
} catch (std::invalid_argument &) {
GELOGE(PARAM_INVALID, "Invalid of digit string: %s, catch invalid_argument.", str.c_str());
GELOGE(PARAM_INVALID, "[Check][Param]Invalid of digit string: %s, catch invalid_argument.", str.c_str());
ErrorManager::GetInstance().ATCReportErrMessage("E10014", {"parameter", "value"}, {"--output_type", str});
return PARAM_INVALID;
} catch (std::out_of_range &) {
GELOGE(PARAM_INVALID, "Invalid of digit string: %s, catch out_of_range.", str.c_str());
GELOGE(PARAM_INVALID, "[Check][Param]Invalid of digit string: %s, catch out_of_range.", str.c_str());
ErrorManager::GetInstance().ATCReportErrMessage("E10013", {"parameter", "value"}, {"--output_type", str});
return PARAM_INVALID;
}
@@ -314,7 +317,8 @@ Status VerifyOutputTypeAndOutNodes(std::vector<std::string> &out_type_vec) {
if (out_nodes_info.find(out_type_vec[i]) == out_nodes_info.end()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
{"--output_type", out_type_vec[i], kOutputTypeError});
GELOGE(domi::FAILED, "Invalid value for --output_type[%s], %s.", out_type_vec[i].c_str(), kOutputTypeError);
GELOGE(domi::FAILED, "[Check][Param]Invalid value for --output_type[%s], %s.",
out_type_vec[i].c_str(), kOutputTypeError);
return domi::FAILED;
}
}
@@ -326,7 +330,8 @@ Status CheckOutPutDataTypeSupport(const std::string &output_type) {
if (it == output_type_str_to_datatype.end()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
{"--output_type", output_type, kOutputTypeSupport});
GELOGE(PARAM_INVALID, "Invalid value for --output_type[%s], %s.", output_type.c_str(), kOutputTypeSupport);
GELOGE(PARAM_INVALID, "[Check][Param]Invalid value for --output_type[%s], %s.",
output_type.c_str(), kOutputTypeSupport);
return domi::FAILED;
}
return domi::SUCCESS;
@@ -344,7 +349,7 @@ Status ParseOutputType(const std::string &output_type, std::map<std::string, vec
if (node_index_type_v.size() != 3) { // The size must be 3.
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
{"--output_type", node, kOutputTypeSample});
GELOGE(PARAM_INVALID, "Invalid value for --output_type[%s], %s.", node.c_str(), kOutputTypeSample);
GELOGE(PARAM_INVALID, "[Parse][Param]Invalid value for --output_type[%s], %s.", node.c_str(), kOutputTypeSample);
return domi::FAILED;
}
ge::DataType tmp_dt;
@@ -352,7 +357,8 @@ Status ParseOutputType(const std::string &output_type, std::map<std::string, vec
std::string index_str = StringUtils::Trim(node_index_type_v[kIndexStrIndex]);
int32_t index;
if (StringToInt(index_str, index) != SUCCESS) {
GELOGE(PARAM_INVALID, "This str must be digit string, while the actual input is %s.", index_str.c_str());
GELOGE(PARAM_INVALID, "[Convert][Type]This str must be digit string, while the actual input is %s.",
index_str.c_str());
return domi::FAILED;
}
std::string dt_value = StringUtils::Trim(node_index_type_v[kDTValueIndex]);
@@ -360,7 +366,8 @@ Status ParseOutputType(const std::string &output_type, std::map<std::string, vec
if (it == output_type_str_to_datatype.end()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
{"--output_type", dt_value, kOutputTypeSupport});
GELOGE(ge::PARAM_INVALID, "Invalid value for --output_type[%s], %s.", dt_value.c_str(), kOutputTypeSupport);
GELOGE(ge::PARAM_INVALID, "[Parse][Param]Invalid value for --output_type[%s], %s.",
dt_value.c_str(), kOutputTypeSupport);
return domi::FAILED;
} else {
tmp_dt = it->second;
@@ -383,7 +390,7 @@ Status CheckOutNode(ge::OpDescPtr op_desc, int32_t index) {
int32_t out_size = op_desc->GetOutputsSize();
if (index < 0 || index >= out_size) {
GELOGE(domi::FAILED,
"out_node [%s] output index:%d must be smaller "
"[Check][Param]out_node [%s] output index:%d must be smaller "
"than node output size:%d and can not be negative!",
op_desc->GetName().c_str(), index, out_size);
std::string fail_reason = "output index:" + to_string(index) + " must be smaller than output size:" +
@@ -403,7 +410,7 @@ Status GetDefaultOutInfo(ge::ComputeGraphPtr &compute_graph,
if (out_node == nullptr) {
ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"},
{"out_nodes", default_out_nodes[i].first});
GELOGE(domi::FAILED, "Can not find src node (%s) in graph.", default_out_nodes[i].first.c_str());
GELOGE(domi::FAILED, "[Check][Param]Can not find src node (%s) in graph.", default_out_nodes[i].first.c_str());
return domi::FAILED;
}
output_nodes_info.push_back(std::make_pair(out_node, default_out_nodes[i].second));
@@ -432,7 +439,7 @@ Status SetOutputNodeInfo(ge::Graph &graph, const std::string &output_type, const
std::map<std::string, vector<std::string>> output_node_dt_map;
if (!output_type.empty()) {
if (ParseOutputType(output_type, output_node_dt_map) != SUCCESS) {
GELOGE(domi::FAILED, "Parse output_type failed.");
GELOGE(domi::FAILED, "[Parse][output_type] failed.");
return domi::FAILED;
}
}
@@ -443,13 +450,13 @@ Status SetOutputNodeInfo(ge::Graph &graph, const std::string &output_type, const
if (out_node == nullptr) {
ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"},
{"out_nodes", user_out_nodes[i].first});
GELOGE(domi::FAILED, "Can not find src node (%s) in graph.", user_out_nodes[i].first.c_str());
GELOGE(domi::FAILED, "[Check][Param]Can not find src node (%s) in graph.", user_out_nodes[i].first.c_str());
return domi::FAILED;
}
auto op_desc = out_node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
if (CheckOutNode(op_desc, user_out_nodes[i].second) != SUCCESS) {
GELOGE(domi::FAILED, "Check out node (%s) fail.", user_out_nodes[i].first.c_str());
GELOGE(domi::FAILED, "[Check][OutNode] (%s) fail.", user_out_nodes[i].first.c_str());
return domi::FAILED;
}

@@ -475,7 +482,7 @@ Status SetOutputNodeInfo(ge::Graph &graph, const std::string &output_type, const
// default output node (leaf)
if (user_out_nodes.empty()) {
if (GetDefaultOutInfo(compute_graph, output_nodes_info) != SUCCESS) {
GELOGE(domi::FAILED, "Get default output info failed.");
GELOGE(domi::FAILED, "[Get][DefaultOutInfo] failed.");
return domi::FAILED;
}
}
@@ -513,7 +520,8 @@ void GetOutputNodesNameAndIndex(std::vector<std::pair<ge::NodePtr, int32_t>> &ou
Status GetOutputLeaf(NodePtr node, std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info) {
ge::OpDescPtr tmpDescPtr = node->GetOpDesc();
if (tmpDescPtr == nullptr) {
GELOGE(domi::FAILED, "Get outnode op desc fail.");
REPORT_INNER_ERROR("E19999", "param node has no opdesc.");
GELOGE(domi::FAILED, "[Check][Param]Get outnode op desc fail.");
return domi::FAILED;
}
size_t size = tmpDescPtr->GetOutputsSize();
@@ -527,7 +535,8 @@ Status GetOutputLeaf(NodePtr node, std::vector<std::pair<ge::NodePtr, int32_t>>
for (auto in_anchor : in_anchors) {
auto out_anchor = in_anchor->GetPeerOutAnchor();
if (out_anchor == nullptr) {
GELOGE(domi::FAILED, "Get leaf node op desc fail.");
REPORT_INNER_ERROR("E19999", "GetPeerOutAnchor return nullptr, node:%s.", node->GetName().c_str());
GELOGE(domi::FAILED, "[Invoke][GetPeerOutAnchor]Get leaf node op desc fail.");
return domi::FAILED;
}
auto out_node = out_anchor->GetOwnerNode();
@@ -557,8 +566,10 @@ Status InitDomiOmgContext(const string &input_shape, const string &input_format,
if (iter != ge::input_format_str_to_geformat.end()) {
domi::GetContext().format = iter->second;
} else {
GELOGE(PARAM_INVALID, "Input format %s not support , expect ND/NCHW/NHWC/CHWN/NC1HWC0/NHWC1C0.",
input_format.c_str());
REPORT_INNER_ERROR("E19999", "param input_format:%s is not support, "
"expect ND/NCHW/NHWC/CHWN/NC1HWC0/NHWC1C0.", input_format.c_str());
GELOGE(PARAM_INVALID, "[Check][Param]Input format %s not support, "
"expect ND/NCHW/NHWC/CHWN/NC1HWC0/NHWC1C0.", input_format.c_str());
return PARAM_INVALID;
}
}
@@ -572,9 +583,9 @@ Status InitDomiOmgContext(const string &input_shape, const string &input_format,
map<string, vector<int64_t>> &shape_map = domi::GetContext().input_dims;

if (!ge::ParseInputShape(input_shape, domi::GetContext().input_dims, domi::GetContext().user_input_dims,
is_dynamic_input) ||
shape_map.empty()) {
GELOGE(PARAM_INVALID, "Failed to parse input shape: %s", input_shape.c_str());
is_dynamic_input) || shape_map.empty()) {
REPORT_CALL_ERROR("E19999", "ParseInputShape failed for %s", input_shape.c_str());
GELOGE(PARAM_INVALID, "[Parse][InputShape] %s failed.", input_shape.c_str());
return PARAM_INVALID;
}

@@ -601,7 +612,7 @@ Status ParseOutNodes(const string &out_nodes) {
"E10001", {"parameter", "value", "reason"},
{"--out_nodes", node, "the correct format is \"node_name1:0;node_name1:1;node_name2:0\""});
GELOGE(PARAM_INVALID,
"The input format of --out_nodes is invalid, the correct format is "
"[Parse][Param]The input format of --out_nodes is invalid, the correct format is "
"\"node_name1:0;node_name1:1;node_name2:0\", while the actual input is %s.",
node.c_str());
return PARAM_INVALID;
@@ -609,15 +620,16 @@ Status ParseOutNodes(const string &out_nodes) {
if (!domi::GetContext().user_out_nodes_top_vec.empty()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
{"--out_nodes", out_nodes, "is not all index or top_name"});
GELOGE(PARAM_INVALID,
"This out_nodes str must be all index or top_name, while the actual input is %s", out_nodes.c_str());
GELOGE(PARAM_INVALID, "[Parse][Param]This out_nodes str must be all index or top_name, "
"while the actual input is %s", out_nodes.c_str());
return PARAM_INVALID;
}
// stoi: The method may throw an exception: invalid_argument/out_of_range
if (!CheckDigitStr(key_value_v[1])) {
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
{"--out_nodes", out_nodes, "is not positive integer"});
GELOGE(PARAM_INVALID, "This str must be digit string, while the actual input is %s", out_nodes.c_str());
GELOGE(PARAM_INVALID, "[Parse][Param]This str must be digit string, while the actual input is %s",
out_nodes.c_str());
return PARAM_INVALID;
}

@@ -635,11 +647,11 @@ Status ParseOutNodes(const string &out_nodes) {
}
}
} catch (std::invalid_argument &) {
GELOGE(PARAM_INVALID, "Invalid of out_nodes: %s ", out_nodes.c_str());
GELOGE(PARAM_INVALID, "[Parse][Param]Invalid of out_nodes: %s ", out_nodes.c_str());
ErrorManager::GetInstance().ATCReportErrMessage("E10014", {"parameter", "value"}, {"--out_nodes", out_nodes});
return PARAM_INVALID;
} catch (std::out_of_range &) {
GELOGE(PARAM_INVALID, "Invalid of out_nodes: %s ", out_nodes.c_str());
GELOGE(PARAM_INVALID, "[Parse][Param]Invalid of out_nodes: %s ", out_nodes.c_str());
ErrorManager::GetInstance().ATCReportErrMessage("E10013", {"parameter", "value"}, {"--out_nodes", out_nodes});
return PARAM_INVALID;
}
@@ -657,7 +669,8 @@ static Status CheckOpNameMap(const ComputeGraphPtr &graph, const std::string &op
for (const NodePtr &node : graph->GetAllNodes()) {
auto op_desc = node->GetOpDesc();
if (op_desc == nullptr) {
GELOGE(PARAM_INVALID, "Invalid parameter for opDesc.");
REPORT_INNER_ERROR("E19999", "param graph's node has no opdesc.");
GELOGE(PARAM_INVALID, "[Check][Param]Invalid parameter for opDesc.");
return PARAM_INVALID;
}
graphNodeTypes[op_desc->GetType()] = "";
@@ -666,7 +679,7 @@ static Status CheckOpNameMap(const ComputeGraphPtr &graph, const std::string &op
if (propertiesMap.empty()) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E10003", {"parameter", "value", "reason"}, {"op_name_map", op_conf, "the file content is empty"});
GELOGE(PARAM_INVALID, "op_name_map file content is empty, please check file!");
GELOGE(PARAM_INVALID, "[Check][Param]op_name_map file content is empty, please check file!");
return PARAM_INVALID;
}
for (auto iter = propertiesMap.begin(); iter != propertiesMap.end(); iter++) {
@@ -674,7 +687,8 @@ static Status CheckOpNameMap(const ComputeGraphPtr &graph, const std::string &op
ErrorManager::GetInstance().ATCReportErrMessage(
"E10003", {"parameter", "value", "reason"},
{"op_name_map", op_conf, "type[" + iter->second + "] is not found in model"});
GELOGE(PARAM_INVALID, "Invalid parameter for op_name_map."); return PARAM_INVALID;);
GELOGE(PARAM_INVALID, "[Find][NodeType]Invalid parameter for op_name_map.");
return PARAM_INVALID;);
}
return SUCCESS;
}
@@ -711,15 +725,16 @@ FMK_FUNC_HOST_VISIBILITY Status ParseGraph(ge::Graph &graph, const std::map<stri
std::string input_format;
ParseAtcParms(atc_params, "input_format", input_format);
GE_RETURN_WITH_LOG_IF_ERROR(InitDomiOmgContext(input_shape, input_format, "", is_dynamic_input),
"ATC Generate call InitDomiOmgContext ret fail");
"[Call][InitDomiOmgContext] ret fail");

std::string is_output_adjust_hw_layout;
ParseAtcParms(atc_params, "is_output_adjust_hw_layout", is_output_adjust_hw_layout);
GE_RETURN_WITH_LOG_IF_ERROR(ParseOutputFp16NodesFormat(is_output_adjust_hw_layout), "Parse is_output_fp16 failed");
GE_RETURN_WITH_LOG_IF_ERROR(ParseOutputFp16NodesFormat(is_output_adjust_hw_layout),
"[Call][ParseOutputFp16NodesFormat]Parse is_output_fp16 failed");

std::string out_nodes;
ParseAtcParms(atc_params, "out_nodes", out_nodes);
GE_RETURN_WITH_LOG_IF_ERROR(ParseOutNodes(out_nodes), "ATC Generate parse out nodes fail");
GE_RETURN_WITH_LOG_IF_ERROR(ParseOutNodes(out_nodes), "[Parse][OutNodes] fail");

std::string output_type;
ParseAtcParms(atc_params, "output_type", output_type);
@@ -732,15 +747,19 @@ FMK_FUNC_HOST_VISIBILITY Status ParseGraph(ge::Graph &graph, const std::map<stri
GE_IF_BOOL_EXEC(!PropertiesManager::Instance().Init(op_conf),
ErrorManager::GetInstance().ATCReportErrMessage("E10003", {"parameter", "value", "reason"},
{"op_name_map", op_conf, "file content error"});
GELOGE(FAILED, "op_name_map init failed!"); return FAILED);
GELOGE(FAILED, "[Invoke][Init]op_name_map init failed!");
return FAILED);
// Return map and put it into ATC global variable
domi::GetContext().op_conf_map = PropertiesManager::Instance().GetPropertyMap();
}

// parse network model
auto model_parser = ModelParserFactory::Instance()->CreateModelParser(type);
GE_CHK_BOOL_RET_STATUS(model_parser != nullptr, FAILED, "ATC create model parser ret fail, type:%d.", type);

if (model_parser == nullptr) {
REPORT_INNER_ERROR("E19999", "CreateModelParser failed, type:%d", type);
GELOGE(FAILED, "[Create][ModelParser] ret fail, type:%d.", type);
return FAILED;
}
UpdateParserCtxWithOmgCtx();
Status ret = model_parser->Parse(model_file, graph);
UpdateOmgCtxWithParserCtx();
@@ -749,7 +768,8 @@ FMK_FUNC_HOST_VISIBILITY Status ParseGraph(ge::Graph &graph, const std::map<stri
if (PreChecker::Instance().HasError() || run_mode == ONLY_PRE_CHECK) {
std::string check_report;
ParseAtcParms(atc_params, "check_report", check_report);
GE_RETURN_WITH_LOG_IF_ERROR(PreChecker::Instance().Save(check_report), "Generate pre-checking report failed.");
GE_RETURN_WITH_LOG_IF_ERROR(PreChecker::Instance().Save(check_report),
"[Invoke][Save]Generate pre-checking report failed.");
GEEVENT("The pre-checking report has been saved to %s.", check_report.c_str());
}

@@ -768,7 +788,7 @@ FMK_FUNC_HOST_VISIBILITY Status ParseGraph(ge::Graph &graph, const std::map<stri
// Verify the contents of the op_name_map
if (op_conf != nullptr && *op_conf != '\0') {
GE_RETURN_WITH_LOG_IF_ERROR(CheckOpNameMap(compute_graph, op_conf),
"op_name_map parameter is not fit with input net!");
"[Invoke][CheckOpNameMap]op_name_map parameter is not fit with input net!");
}

// Print parse network structure
@@ -783,17 +803,18 @@ FMK_FUNC_HOST_VISIBILITY Status ParseGraph(ge::Graph &graph, const std::map<stri
if (PreChecker::Instance().HasError() || run_mode == ONLY_PRE_CHECK) {
std::string check_report;
ParseAtcParms(atc_params, "check_report", check_report);
GE_RETURN_WITH_LOG_IF_ERROR(PreChecker::Instance().Save(check_report), "Generate pre-checking report failed.");
GE_RETURN_WITH_LOG_IF_ERROR(PreChecker::Instance().Save(check_report),
"[Invoke][Save]Generate pre-checking report failed.");
GEEVENT("The pre-checking report has been saved to %s.", check_report.c_str());
}
// Prevent data residue in multiple calls
PreChecker::Instance().Clear();

GE_CHK_BOOL_RET_STATUS(ret == SUCCESS, ret, "ATC weights parse ret fail.");
GE_CHK_BOOL_RET_STATUS(ret == SUCCESS, ret, "[Check][State]ATC weights parse ret fail.");

// parser input shape range and update op shape range
GE_RETURN_WITH_LOG_IF_ERROR(UpdateDynamicInputShapeRange(compute_graph, input_shape_range),
"Update input shape range failed");
"[Update][DynamicInputShapeRange] failed");

GELOGI("ATC parser success.");

@@ -915,7 +936,8 @@ FMK_FUNC_HOST_VISIBILITY Status ConvertOm(const char *model_file, const char *js
// Load model from file
Status ret = ModelParserBase::LoadFromFile(model_file, "", priority, model);
if (ret != SUCCESS) {
GELOGE(ret, "LoadFromFile failed.");
REPORT_CALL_ERROR("E19999", "LoadFromFile failed, file:%s", model_file);
GELOGE(ret, "[Invoke][LoadFromFile] failed.");
return ret;
}

@@ -929,7 +951,7 @@ FMK_FUNC_HOST_VISIBILITY Status ConvertOm(const char *model_file, const char *js
ge::graphStatus status = omFileLoadHelper.Init(model_data, model_len);
if (status != ge::GRAPH_SUCCESS) {
ErrorManager::GetInstance().ATCReportErrMessage("E19021", {"reason"}, {"Om file init failed"});
GELOGE(ge::FAILED, "Om file init failed.");
GELOGE(ge::FAILED, "[Invoke][Init]Om file init failed.");
if (model.model_data != nullptr) {
delete[] reinterpret_cast<char *>(model.model_data);
model.model_data = nullptr;
@@ -941,7 +963,7 @@ FMK_FUNC_HOST_VISIBILITY Status ConvertOm(const char *model_file, const char *js
status = omFileLoadHelper.GetModelPartition(MODEL_DEF, ir_part);
if (status != ge::GRAPH_SUCCESS) {
ErrorManager::GetInstance().ATCReportErrMessage("E19021", {"reason"}, {"Get model part failed"});
GELOGE(ge::FAILED, "Get model part failed.");
GELOGE(ge::FAILED, "[Get][ModelPartition] failed.");
if (model.model_data != nullptr) {
delete[] reinterpret_cast<char *>(model.model_data);
model.model_data = nullptr;
@@ -967,13 +989,13 @@ FMK_FUNC_HOST_VISIBILITY Status ConvertOm(const char *model_file, const char *js
} else {
ret = INTERNAL_ERROR;
ErrorManager::GetInstance().ATCReportErrMessage("E19021", {"reason"}, {"ReadProtoFromArray failed"});
GELOGE(ret, "ReadProtoFromArray failed.");
GELOGE(ret, "[Read][Proto]From Array failed.");
}
} else {
ErrorManager::GetInstance().ATCReportErrMessage("E10003",
{"parameter", "value", "reason"}, {"om", model_file, "invalid om file"});
GELOGE(ACL_ERROR_GE_PARAM_INVALID,
"ParseModelContent failed because of invalid om file. Please check --om param.");
"[Parse][ModelContent] failed because of invalid om file. Please check --om param.");
}

if (model.model_data != nullptr) {
@@ -984,7 +1006,7 @@ FMK_FUNC_HOST_VISIBILITY Status ConvertOm(const char *model_file, const char *js
} catch (const std::exception &e) {
ErrorManager::GetInstance().ATCReportErrMessage("E19021", {"reason"},
{"Convert om model to json failed, exception message[" + std::string(e.what()) + "]"});
GELOGE(FAILED, "Convert om model to json failed, exception message : %s.", e.what());
GELOGE(FAILED, "[Save][Model]Convert om model to json failed, exception message : %s.", e.what());
return FAILED;
}
}
@@ -1003,7 +1025,8 @@ FMK_FUNC_HOST_VISIBILITY Status ConvertPbtxtToJson(const char *model_file, const
};
if (ret != SUCCESS) {
free_model_data(&model.model_data);
GELOGE(ret, "LoadFromFile failed.");
REPORT_CALL_ERROR("E19999", "LoadFromFile failed.");
GELOGE(ret, "[Invoke][LoadFromFile] failed.");
return ret;
}

@@ -1015,7 +1038,7 @@ FMK_FUNC_HOST_VISIBILITY Status ConvertPbtxtToJson(const char *model_file, const
if (!flag) {
free_model_data(&model.model_data);
ErrorManager::GetInstance().ATCReportErrMessage("E19021", {"reason"}, {"ParseFromString failed"});
GELOGE(FAILED, "ParseFromString failed.");
GELOGE(FAILED, "[Invoke][ParseFromString] failed.");
return FAILED;
}
GetGroupName(model_def);
@@ -1024,7 +1047,8 @@ FMK_FUNC_HOST_VISIBILITY Status ConvertPbtxtToJson(const char *model_file, const
ret = ModelSaver::SaveJsonToFile(json_file, j);
if (ret != SUCCESS) {
free_model_data(&model.model_data);
GELOGE(ret, "Save json to file fail.");
REPORT_CALL_ERROR("E19999", "SaveJsonToFile failed.");
GELOGE(ret, "[Save][Json] to file fail.");
return ret;
}
free_model_data(&model.model_data);
@@ -1033,12 +1057,12 @@ FMK_FUNC_HOST_VISIBILITY Status ConvertPbtxtToJson(const char *model_file, const
free_model_data(&model.model_data);
ErrorManager::GetInstance().ATCReportErrMessage("E19021", {"reason"}, {"ParseFromString failed, exception message["
+ std::string(e.what()) + "]"});
GELOGE(FAILED, "ParseFromString failed. exception message : %s", e.what());
GELOGE(FAILED, "[Invoke][ParseFromString] failed. exception message : %s", e.what());
return FAILED;
} catch (const std::exception &e) {
ErrorManager::GetInstance().ATCReportErrMessage("E19021", {"reason"},
{"Convert pbtxt to json failed, exception message[" + std::string(e.what()) + "]"});
GELOGE(FAILED, "Convert pbtxt to json failed, exception message : %s.", e.what());
GELOGE(FAILED, "[Save][pbtxt]Convert pbtxt to json failed, exception message : %s.", e.what());
return FAILED;
}
}
@@ -1047,16 +1071,19 @@ FMK_FUNC_HOST_VISIBILITY Status ConvertFwkModelToJson(const domi::FrameworkType
const char *json_file) {
if (framework == domi::CAFFE || framework == domi::TENSORFLOW || framework == domi::ONNX) {
auto model_parser = ModelParserFactory::Instance()->CreateModelParser(framework);
GE_CHK_BOOL_RET_STATUS(model_parser != nullptr, FAILED, "ATC create model parser ret fail, framework:%d.",
framework);
if (model_parser == nullptr) {
REPORT_INNER_ERROR("E19999", "CreateModelParser failed, framework:%d.", framework);
GELOGE(FAILED, "[Create][ModelParser] ret fail, framework:%d.", framework);
return FAILED;
}
return model_parser->ToJson(model_file, json_file);
}

ErrorManager::GetInstance().ATCReportErrMessage(
"E10001", {"parameter", "value", "reason"},
{"--framework", std::to_string(framework), "only support 0(Caffe) 3(TensorFlow) 5(Onnx)"});
GELOGE(PARAM_INVALID, "Input parameter[--framework] is mandatory and it's value must be: 0(Caffe) 3(TensorFlow) "
"or 5(Onnx).");
GELOGE(PARAM_INVALID, "[Check][Param]Input parameter[--framework] is mandatory "
"and it's value must be: 0(Caffe) 3(TensorFlow) or 5(Onnx).");
return PARAM_INVALID;
}

@@ -1072,7 +1099,8 @@ FMK_FUNC_HOST_VISIBILITY Status DumpInfershapeJson(const ge::Graph &graph, const
if (buffer.GetData() != nullptr) {
std::string str(reinterpret_cast<const char *>(buffer.GetData()), buffer.GetSize());
if (!ge_proto.ParseFromString(str)) {
GELOGE(GRAPH_FAILED, "parse from string failed.");
REPORT_CALL_ERROR("E19999", "ParseFromString failed.");
GELOGE(GRAPH_FAILED, "[Invoke][ParseFromString] failed.");
return FAILED;
}



+ 1
- 1
metadef

@@ -1 +1 @@
Subproject commit 72984c62e434da09f8ea7d7a4fd4a365e0b6f42c
Subproject commit 4cf2633a8f2290dee165ab11f8d6b8a07cba1412

+ 1
- 1
parser

@@ -1 +1 @@
Subproject commit 88c66a63b382bb54b338859e44b6321da979b22e
Subproject commit a41249dc9b50e4c4988eb62a662b7df29ac24ee7

+ 2
- 2
tests/ut/ge/graph/passes/link_gen_mask_nodes_pass_unittest.cc View File

@@ -43,7 +43,7 @@ ut::GraphBuilder Graph1Builder() {
ut::GraphBuilder builder = ut::GraphBuilder("g1");
auto const1 = builder.AddNode("const1", "Const", 0, 1);
auto const2 = builder.AddNode("const2", "Const", 0, 1);
auto gen_mask1 = builder.AddNode("gen_mask1", "DropOutGenMask", 2, 1);
auto gen_mask1 = builder.AddNode("gen_mask1_DropOutGenMask", "DropOutGenMask", 2, 1);
auto gen_mask2 = builder.AddNode("gen_mask2", "DropOutGenMaskV3", 2, 1);
auto gen_mask3 = builder.AddNode("gen_mask3", "DropOutGenMaskV3D", 2, 1);
auto do_mask1 = builder.AddNode("do_mask1", "DropOutDoMask", 3, 1);
@@ -106,6 +106,6 @@ TEST_F(UtestLinkGenMaskNodesPass, link_gen_mask_nodes_pass_success) {
auto out_ctrl_nodes = gen_mask2->GetOutControlNodes();
EXPECT_EQ(out_ctrl_nodes.size(), 1);
auto out_ctrl_node = out_ctrl_nodes.at(0);
EXPECT_EQ(out_ctrl_node->GetName(), "gen_mask1");
EXPECT_EQ(out_ctrl_node->GetName(), "gen_mask1_DropOutGenMask");
}
} // namespace ge

Loading…
Cancel
Save