@@ -3,8 +3,17 @@
 /output
 /prebuilts
 /cov
+/deps
+.autotools
+.project
+.cproject
+.settings/
+/tests/frm/
 *.ir
 *.out
+*.DS_Store
+.DS_Store
+server_config.sh
 # Dynamic libraries
 # *.so
@@ -69,10 +69,10 @@ Status CheckOptionsValid(const std::map<string, string> &options) {
 auto job_id_iter = options.find(OPTION_EXEC_JOB_ID);
 if (job_id_iter != options.end()) {
 if (job_id_iter->second.length() > kMaxStrLen) {
-GELOGE(PARAM_INVALID,"[Check][JobId]Failed,"
+GELOGE(PARAM_INVALID, "[Check][JobId]Failed,"
 "the job_id [%s] string length: %zu > max string length: %d",
 job_id_iter->second.c_str(), job_id_iter->second.length(), kMaxStrLen);
-REPORT_INPUT_ERROR("E10051", std::vector<std::string>({"id","length"}),
+REPORT_INPUT_ERROR("E10051", std::vector<std::string>({"id", "length"}),
 std::vector<std::string>({job_id_iter->second,
 std::to_string(kMaxStrLen)}));
 return FAILED;
@@ -244,7 +244,7 @@ std::string GEGetWarningMsg() {
 // Initialize session,which calls innerSession
 Session::Session(const std::map<string, string> &options) {
 ErrorManager::GetInstance().SetStage(error_message::kInitialize, error_message::kOther);
-GELOGT(TRACE_INIT, "Session Constructor start");
+GELOGT(TRACE_INIT, "Start to construct session.");
 ErrorManager::GetInstance().GenWorkStreamIdDefault();
 // check init status
@@ -332,7 +332,7 @@ Session::Session(const std::map<AscendString, AscendString> &options) {
 // session destructor
 Session::~Session() {
 ErrorManager::GetInstance().SetStage(error_message::kFinalize, error_message::kFinalize);
-GELOGT(TRACE_INIT, "Session Destructor start");
+GELOGT(TRACE_INIT, "Start to destruct session.");
 // 0.check init status
 if (!g_ge_initialized) {
 GELOGW("GE is not yet initialized or is finalized.");
@@ -602,16 +602,16 @@ Status Session::RunGraph(uint32_t graph_id, const std::vector<Tensor> &inputs, s
 Status Session::RunGraphWithStreamAsync(uint32_t graph_id, void *stream, const std::vector<Tensor> &inputs,
 std::vector<Tensor> &outputs) {
 ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther);
-GELOGT(TRACE_INIT, "Session run graph with stream async start");
+GELOGT(TRACE_INIT, "Start to run graph with stream async.");
 ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id);
 std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance();
 if (instance_ptr == nullptr) {
 GELOGE(GE_CLI_GE_NOT_INITIALIZED,
-"[Run][Graph]Run graph with stream asyn failed, the GELib instance is nullptr,"
+"[Run][Graph]Run graph with stream async failed, the GELib instance is nullptr,"
 "session id = %lu, graph id = %u, stream = %p.", sessionId_, graph_id, stream);
 REPORT_INNER_ERROR("E19999",
-"Run graph with stream asyn failed, the GELib instance is nullptr"
+"Run graph with stream async failed, the GELib instance is nullptr"
 "session id = %lu, graph id = %u, stream = %p.", sessionId_, graph_id, stream);
 return FAILED;
 }
@@ -66,21 +66,21 @@ void DumpOp::SetDynamicModelInfo(const string &dynamic_model_name, const string
 static void SetOpMappingLoopAddr(uintptr_t step_id, uintptr_t loop_per_iter, uintptr_t loop_cond,
 toolkit::aicpu::dump::OpMappingInfo &op_mapping_info) {
 if (step_id != 0) {
-GELOGI("step_id exists.");
+GELOGI("Exists step_id.");
 op_mapping_info.set_step_id_addr(static_cast<uint64_t>(step_id));
 } else {
 GELOGI("step_id is null.");
 }
 if (loop_per_iter != 0) {
-GELOGI("loop_per_iter exists.");
+GELOGI("Exists loop_per_iter.");
 op_mapping_info.set_iterations_per_loop_addr(static_cast<uint64_t>(loop_per_iter));
 } else {
 GELOGI("loop_per_iter is null.");
 }
 if (loop_cond != 0) {
-GELOGI("loop_cond exists.");
+GELOGI("Exists loop_cond.");
 op_mapping_info.set_loop_cond_addr(static_cast<uint64_t>(loop_cond));
 } else {
 GELOGI("loop_cond is null.");
@@ -253,7 +253,7 @@ Status DumpOp::LaunchDumpOp() {
 }
 if (device_id < 0) {
 GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "[Check][DeviceId]Failed, device_id %d", device_id);
-REPORT_INNER_ERROR("E19999","Check device_id %d failed", device_id);
+REPORT_INNER_ERROR("E19999", "Check device_id %d failed", device_id);
 return ACL_ERROR_GE_INTERNAL_ERROR;
 }
 toolkit::aicpu::dump::OpMappingInfo op_mapping_info;
@@ -72,8 +72,7 @@ Status CheckArgsForFracZToNchw(const TransArgs &args) {
 if (src_shape.at(kFracZHWC1) != dst_shape.at(kNchwH) * dst_shape.at(kNchwW) * c1 ||
 src_shape.at(kFracZC0) != c0 || src_shape.at(kFracZNi) != kNiSize || src_shape.at(kFracZN0) != n0) {
 GELOGE(ACL_ERROR_GE_SHAPE_INVALID,
-"[Check][Shape]Failed to check relationship between src and dst shape, "
-"src shape %s, dst shape %s",
+"[Check][Shape]Failed to check relationship between src and dst shape, src shape %s, dst shape %s",
 ShapeToString(src_shape).c_str(), ShapeToString(dst_shape).c_str());
 REPORT_INNER_ERROR("E19999", "Failed to check relationship between src and dst shape, "
 "src shape %s, dst shape %s",
@@ -138,9 +137,9 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in
 "[Operate][Memory]Failed to copy data from FracZ offset %ld to "
 "NCHW[%ld, %ld, %ld, %ld] offset %ld, err-code %d",
 src_offset, n_idx, c_idx, h_idx, w_idx, dst_offset, ret);
-REPORT_CALL_ERROR("E19999","Failed to copy data from FracZ offset %ld to "
+REPORT_CALL_ERROR("E19999", "Failed to copy data from FracZ offset %ld to "
 "NCHW[%ld, %ld, %ld, %ld] offset %ld, err-code %d",
-src_offset, n_idx, c_idx, h_idx, w_idx, dst_offset, ret );
+src_offset, n_idx, c_idx, h_idx, w_idx, dst_offset, ret);
 return ACL_ERROR_GE_MEMORY_OPERATE_FAILED;
 }
 }
@@ -44,23 +44,20 @@ Status CheckArgsForFracZToNhwc(const TransArgs &args) {
 GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "[Check][DataType]Failed, "
 "shape from FORMAT_FRACTAL_Z to NCHW, invalid data type %s",
 TypeUtils::DataTypeToSerialString(args.src_data_type).c_str());
-REPORT_INNER_ERROR("E19999", "Failed to trans shape from FORMAT_FRACTAL_Z to NCHW, "
-"invalid data type %s",
+REPORT_INNER_ERROR("E19999", "Failed to trans shape from FORMAT_FRACTAL_Z to NCHW, invalid data type %s",
 TypeUtils::DataTypeToSerialString(args.src_data_type).c_str());
 return ACL_ERROR_GE_DATATYPE_INVALID;
 }
 if (!CheckShapeValid(src_shape, kFracZDimsNum)) {
 GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Check][Shape]Value is invalid, src shape %s",
 ShapeToString(src_shape).c_str());
-REPORT_CALL_ERROR("E19999", "Src shape %s check invalid",
-ShapeToString(src_shape).c_str());
+REPORT_CALL_ERROR("E19999", "Src shape %s check invalid", ShapeToString(src_shape).c_str());
 return ACL_ERROR_GE_SHAPE_INVALID;
 }
 if (!CheckShapeValid(dst_shape, kNhwcDimsNum)) {
 GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Check][Shape]Value is invalid, dst shape %s",
 ShapeToString(dst_shape).c_str());
-REPORT_CALL_ERROR("E19999", "Dst shape %s check invalid",
-ShapeToString(dst_shape).c_str());
+REPORT_CALL_ERROR("E19999", "Dst shape %s check invalid", ShapeToString(dst_shape).c_str());
 return ACL_ERROR_GE_SHAPE_INVALID;
 }
 int64_t c0 = GetCubeSizeByDataType(args.src_data_type);
@@ -138,7 +135,7 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, int size
 "[Operate][Memory]Failed to copy data from FracZ offset %ld to "
 "NCHW[%ld, %ld, %ld, %ld] offset %ld, err-code %d",
 src_offset, n_idx, c_idx, h_idx, w_idx, dst_offset, ret);
-REPORT_CALL_ERROR("E19999","Failed to copy data from FracZ offset %ld to "
+REPORT_CALL_ERROR("E19999", "Failed to copy data from FracZ offset %ld to "
 "NCHW[%ld, %ld, %ld, %ld] offset %ld, err-code %d",
 src_offset, n_idx, c_idx, h_idx, w_idx, dst_offset, ret);
 return ACL_ERROR_GE_MEMORY_OPERATE_FAILED;
@@ -185,7 +182,7 @@ Status FormatTransferFracZNhwc::TransFormat(const TransArgs &args, TransResult &
 ShapeToString(args.src_shape).c_str(),
 TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(),
 ShapeToString(args.dst_shape).c_str(), total_size, ret);
-REPORT_CALL_ERROR("E19999","Failed to get data after trans, src shape %s, data type %s, "
+REPORT_CALL_ERROR("E19999", "Failed to get data after trans, src shape %s, data type %s, "
 "dst shape %s, memory size %ld, error_code %u",
 ShapeToString(args.src_shape).c_str(),
 TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(),
@@ -112,11 +112,10 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in
 total_size, ShapeToString(args.dst_shape).c_str(),
 TypeUtils::FormatToSerialString(args.src_format).c_str(),
 TypeUtils::FormatToSerialString(args.dst_format).c_str());
-REPORT_CALL_ERROR("E19999", "Failed to alloc the memory for dst buf %ld, "
-"shape %s when trans format from %s to %s",
-total_size, ShapeToString(args.dst_shape).c_str(),
-TypeUtils::FormatToSerialString(args.src_format).c_str(),
-TypeUtils::FormatToSerialString(args.dst_format).c_str());
+REPORT_CALL_ERROR("E19999", "Failed to alloc the memory for dst buf %ld, shape %s when trans format from %s to %s",
+total_size, ShapeToString(args.dst_shape).c_str(),
+TypeUtils::FormatToSerialString(args.src_format).c_str(),
+TypeUtils::FormatToSerialString(args.dst_format).c_str());
 return ACL_ERROR_GE_MEMORY_ALLOCATION;
 }
@@ -47,7 +47,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status TransFormat(const TransArg
 GELOGE(ACL_ERROR_GE_PARAM_INVALID, "[Check][Shape]Failed, input data is null "
 "or shape size not euqal to 0, src_shape %s",
 ShapeToString(args.src_shape).c_str());
-REPORT_CALL_ERROR("E19999","Failed to check shape, input data is null "
+REPORT_CALL_ERROR("E19999", "Failed to check shape, input data is null "
 "or shape size not equal to 0, src_shape %s",
 ShapeToString(args.src_shape).c_str());
 return ACL_ERROR_GE_PARAM_INVALID;
@@ -79,7 +79,8 @@ Status ModelHelper::SaveModelPartition(std::shared_ptr<OmFileSaveHelper> &om_fil
 Status ModelHelper::SaveSizeToModelDef(const GeModelPtr &ge_model) {
 vector<int64_t> om_info;
 auto ge_model_weight = ge_model->GetWeight();
-GELOGD("SaveSizeToModelDef weight_data_size is %zu, %p", ge_model_weight.GetSize(), ge_model_weight.GetData());
+GELOGD("SaveSizeToModelDef weight_data_size is %zu, ge_model_weight data is %p", ge_model_weight.GetSize(),
+ge_model_weight.GetData());
 om_info.push_back(ge_model_weight.GetSize());
 TBEKernelStore tbe_kernel_store = ge_model->GetTBEKernelStore();
@@ -284,7 +285,7 @@ Status ModelHelper::SaveAllModelPartiton(std::shared_ptr<OmFileSaveHelper>& om_f
 if (SaveModelWeights(om_file_save_helper, ge_model, model_index) != SUCCESS) {
 GELOGE(FAILED, "[Save][ModelWeights]Failed, model %s, model index %zu",
 ge_model->GetName().c_str(), model_index);
-REPORT_CALL_ERROR("E19999","ModelHelper save mode weights failed, model %s, model index %zu",
+REPORT_CALL_ERROR("E19999", "ModelHelper save mode weights failed, model %s, model index %zu",
 ge_model->GetName().c_str(), model_index);
 return FAILED;
 }
@@ -441,7 +442,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::SaveToOmRoo
 GELOGE(INTERNAL_ERROR, "[Save][AllModelPartition]Failed, model name %s, cur_index %zu",
 model_name.c_str(), cur_index);
 REPORT_CALL_ERROR("E19999", "Save all model %s partition failed, cur_index %zu",
-model_name.c_str(), cur_index);
+model_name.c_str(), cur_index);
 return INTERNAL_ERROR;
 }
 }
@@ -459,7 +460,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::SaveToOmRoo
 GELOGE(FAILED, "[Save][Model]OmFileSaveHelper save model eturn fail, output_file %s",
 output_file.c_str());
 REPORT_CALL_ERROR("E19999", "OmFileSaveHelper save model return fail, output_file %s",
-output_file.c_str());
+output_file.c_str());
 return FAILED;
 }
 return SUCCESS;
@@ -601,7 +602,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::LoadModel(c
 FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::LoadRootModel(const ge::ModelData &model_data) {
 if (model_data.model_data == nullptr || model_data.model_len == 0) {
 GELOGE(ACL_ERROR_GE_EXEC_MODEL_DATA_SIZE_INVALID, "[Load][RootModel] "
-"Model_data is nullptr or model_data_size is 0");
+"Model_data is nullptr or model data is empty.");
 REPORT_INNER_ERROR("E19999", "Load root model failed, model_data is nullptr or its size is 0");
 return ACL_ERROR_GE_EXEC_MODEL_DATA_SIZE_INVALID;
 }
@@ -628,7 +629,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::LoadRootMod
 //model verison 1.0 file header does not have model_num member
 is_unknown_shape_model_ = file_header_->version >= ge::MODEL_VERSION &&
 file_header_->model_num > kStatiOmFileModelNum;
-GELOGD("cur om model is ge root model or no %d, model version %u", is_unknown_shape_model_, file_header_->version);
+GELOGD("Cur om model is ge root model or no %d, model version %u", is_unknown_shape_model_, file_header_->version);
 OmFileLoadHelper om_load_helper;
 if (is_unknown_shape_model_) {
@@ -650,7 +651,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::LoadRootMod
 GELOGE(status, "[Generate][GERootModel]Failed");
 return status;
 }
-GELOGD("in ModelHelper::LoadRootModel, is_assign_model_ is setted to true!");
+GELOGD("In ModelHelper::LoadRootModel, is_assign_model_ is setted to true!");
 is_assign_model_ = true;
 return SUCCESS;
 }
@@ -790,7 +791,7 @@ Status ModelHelper::LoadWeights(OmFileLoadHelper &om_load_helper) {
 if (om_load_helper.GetModelPartition(ModelPartitionType::WEIGHTS_DATA, partition) != SUCCESS) {
 GELOGE(FAILED, "[Get][ModelWeightPartition]Failed, GetWeight size:%u", partition.size);
 REPORT_CALL_ERROR("E19999", "[Get][ModelPartition]Failed, GetWeight size:%u",
-partition.size);
+partition.size);
 return FAILED;
 }
 ge::Buffer weight = ge::Buffer::CopyFrom(partition.data, partition.size);
@@ -805,7 +806,7 @@ Status ModelHelper::LoadWeights(OmFileLoadHelper &om_load_helper, GeModelPtr &cu
 if (om_load_helper.GetModelPartition(ModelPartitionType::WEIGHTS_DATA, partition, mode_index) != SUCCESS) {
 GELOGE(FAILED, "[Get][ModelPartition]Failed, GetWeight size:%u", partition.size);
 REPORT_CALL_ERROR("E19999", "[Get][ModelPartition]Failed, GetWeight size:%u",
-partition.size);
+partition.size);
 return FAILED;
 }
 ge::Buffer weight = ge::Buffer::CopyFrom(partition.data, partition.size);
@@ -444,17 +444,16 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status
 OpUtils::GetShapeDataFromConstTensor(const ConstGeTensorPtr &tensor, DataType type, std::vector<int64_t> &dims) {
 if (tensor == nullptr) {
 GELOGE(PARAM_INVALID, "[Check][Param]Input tensor is nullptr");
-REPORT_INNER_ERROR("E19999","Input tensor is nullptr");
+REPORT_INNER_ERROR("E19999", "Input tensor is nullptr");
 return PARAM_INVALID;
 }
 // If the tensor data is a vector, the shape dimension must be 1
 if (tensor->GetTensorDesc().GetShape().GetDims().size() > 1) {
-GELOGE(PARAM_INVALID, "[Check][Param]The dimension of the input tensor shape "
-"cannot be more than 1, it is %zu",
+GELOGE(PARAM_INVALID, "[Check][Param]The dimension of the input tensor shape cannot be more than 1, it is %zu",
 tensor->GetTensorDesc().GetShape().GetDims().size());
-REPORT_CALL_ERROR("E19999", "The dimension of the input tensor shape %zu invalid, "
-"more than 1", tensor->GetTensorDesc().GetShape().GetDims().size());
+REPORT_CALL_ERROR("E19999", "The dimension of the input tensor shape %zu invalid, more than 1",
+tensor->GetTensorDesc().GetShape().GetDims().size());
 return PARAM_INVALID;
 }
@@ -473,8 +472,8 @@ OpUtils::GetShapeDataFromConstTensor(const ConstGeTensorPtr &tensor, DataType ty
 dims.push_back(shape_data[i]);
 }
 } else {
-GELOGE(PARAM_INVALID, "[Check][DataType]Invalid, type only can be DT_INT32 or DT_INT64, "
-"type is %s", TypeUtils::DataTypeToSerialString(type).c_str());
+GELOGE(PARAM_INVALID, "[Check][DataType]Invalid, type only can be DT_INT32 or DT_INT64, type is %s",
+TypeUtils::DataTypeToSerialString(type).c_str());
 REPORT_INNER_ERROR("E19999", "Data type %s check invalid, only can be DT_INT32 or DT_INT64",
 TypeUtils::DataTypeToSerialString(type).c_str());
 return PARAM_INVALID;
@@ -304,7 +304,7 @@ std::string DNNEngineManager::GetHostCpuEngineName(const std::vector<OpInfo> &op
 GELOGE(FAILED, "[Get][HostCpuEngineName]Failed, HostCpuEngine not support [%s, %s]",
 op_desc->GetName().c_str(), op_desc->GetType().c_str());
 REPORT_INNER_ERROR("E19999", "Get HostCpuEngineName failed, HostCpuEngine not support [%s, %s]",
-op_desc->GetName().c_str(), op_desc->GetType().c_str());
+op_desc->GetName().c_str(), op_desc->GetType().c_str());
 return "";
 }
@@ -436,7 +436,7 @@ Status DNNEngineManager::ParserEngineMessage(const json engines_json, const std:
 GELOGE(FAILED, "[Check][Param]There are the same engine %s message in the json file",
 engine_id.c_str());
 REPORT_INNER_ERROR("E19999", "There are the same engine %s message in the json file",
-engine_id.c_str());
+engine_id.c_str());
 return FAILED;
 }
 engines.emplace(engine_id, engine_conf_ptr);
@@ -684,7 +684,8 @@ Status GraphMemoryAssigner::AssignContinuousInputMemory(const ge::NodePtr &node,
 bool is_allocated_first_input = is_continuous_input_allocated && (in_data_anchor->GetIdx() == 0);
 if (is_allocated_first_input) {
 std::map<int32_t, int32_t> out2ins;
-GE_CHK_STATUS_RET(TryGetNodeRefIndexes(node, out2ins), "[Get][RefIndexes]fail for node: %s", node->GetName().c_str());
+GE_CHK_STATUS_RET(TryGetNodeRefIndexes(node, out2ins), "[Get][RefIndexes]fail for node: %s",
+node->GetName().c_str());
 // output is beginning offset, set offset for input; only support this case now
 if ((out2ins.size() == 1) && (out2ins.begin()->second == 0) && (reverse_refresh)) {
 auto peer_output_offset = output_list.at(peer_out_data_anchor->GetIdx());
@@ -246,7 +246,8 @@ Status ModelBuilder::SetInputOutputDesc() {
 }
 // if user set input node format ND, the expected node for data and netoutput format is ND in
 // final graph.
-if ((compute_graph_->GetParentGraph() == nullptr) && (GetLocalOmgContext().format == domi::DOMI_TENSOR_ND) && (!node_op_desc->HasAttr("_is_single_op")) &&
+if ((compute_graph_->GetParentGraph() == nullptr) && (GetLocalOmgContext().format == domi::DOMI_TENSOR_ND) &&
+(!node_op_desc->HasAttr("_is_single_op")) &&
 ((node_op_desc->GetType() == DATA_TYPE) || (node_op_desc->GetType() == NETOUTPUT))) {
 auto inputDescsPtr = node_op_desc->GetAllInputsDescPtr();
 auto outputDescsPtr = node_op_desc->GetAllOutputsDescPtr();
@@ -193,23 +193,29 @@ Status SetCyclicDependenceFlag(const ge::NodePtr &node) {
 ///
 /// @brief set op next_iteration name
-/// @param [in] node
-/// @param [in] next
+/// @param [in] Merge Node
+/// @param [in] NextIteration Node
 /// @return Status
 ///
-Status SetNextIteration(const ge::NodePtr &node, const std::string &next) {
+Status SetNextIteration(const NodePtr &node, const NodePtr &next) {
 GE_CHECK_NOTNULL(node);
-OpDescPtr tmp_desc = node->GetOpDesc();
-GE_CHECK_NOTNULL(tmp_desc);
+GE_CHECK_NOTNULL(next);
+GE_CHECK_NOTNULL(node->GetOpDesc());
+GE_CHECK_NOTNULL(next->GetOpDesc());
-if (!AttrUtils::SetStr(tmp_desc, ge::ATTR_NAME_NEXT_ITERATION, next)) {
-REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_NEXT_ITERATION.c_str(),
-node->GetName().c_str(), node->GetType().c_str());
-GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_NEXT_ITERATION.c_str(),
-node->GetName().c_str(), node->GetType().c_str());
-return FAILED;
-}
+const auto SetIterationName = [](const OpDescPtr &op_desc, const std::string &name) {
+if (!AttrUtils::SetStr(op_desc, ATTR_NAME_NEXT_ITERATION, name)) {
+REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_NEXT_ITERATION.c_str(),
+op_desc->GetName().c_str(), op_desc->GetType().c_str());
+GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_NEXT_ITERATION.c_str(),
+op_desc->GetName().c_str(), op_desc->GetType().c_str());
+return FAILED;
+}
+return SUCCESS;
+};
+GE_CHK_STATUS_RET_NOLOG(SetIterationName(node->GetOpDesc(), next->GetName()));
+GE_CHK_STATUS_RET_NOLOG(SetIterationName(next->GetOpDesc(), node->GetName()));
 return SUCCESS;
 }
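A minimal call-site sketch of the revised interface (the wrapper function below is illustrative, not part of this patch): the caller now passes the NextIteration node itself rather than its name, and ATTR_NAME_NEXT_ITERATION is written on both op descs, each recording the peer's name.

// Hypothetical caller: link a Merge node with its NextIteration node.
Status LinkMergeToNextIteration(const NodePtr &merge_node, const NodePtr &next_iter_node) {
  // Old form: SetNextIteration(merge_node, next_iter_node->GetName());
  // New form: each op desc records the peer's name under ATTR_NAME_NEXT_ITERATION.
  return SetNextIteration(merge_node, next_iter_node);
}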
@@ -96,11 +96,11 @@ Status SetCyclicDependenceFlag(const ge::NodePtr &node);
 ///
 /// @brief set op next_iteration name
-/// @param [in] node
-/// @param [in] next
+/// @param [in] Merge Node
+/// @param [in] NextIteration Node
 /// @return Status
 ///
-Status SetNextIteration(const ge::NodePtr &node, const std::string &next);
+Status SetNextIteration(const NodePtr &node, const NodePtr &next);
 ///
 /// @brief Align the memory
@@ -704,7 +704,7 @@ Status GraphExecutor::GetCurShape(const uint32_t model_id, std::vector<int64_t>
 }
 Status GraphExecutor::GetOpAttr(uint32_t model_id, const std::string &op_name, const std::string &attr_name,
-std::string &attr_value) {
+std::string &attr_value) {
 auto model_manager = ge::ModelManager::GetInstance();
 GE_CHECK_NOTNULL(model_manager);
 Status ret = model_manager->GetOpAttr(model_id, op_name, attr_name, attr_value);
@@ -886,6 +886,7 @@ Status GraphManager::PreRunOptimizeOriginalGraph(const GraphNodePtr &graph_node,
 GM_RUN_AND_DUMP_PERF("OptimizeSwitchOp", stages.preparer.SwitchOpOptimize, compute_graph);
 }
 GM_RUN_AND_DUMP_PERF("Optimize1", OptimizeStage1, compute_graph);
+GM_RUN_AND_DUMP_PERF("OptimizeAfterStage1", stages.optimizer.OptimizeAfterStage1, compute_graph);
 GM_RUN_AND_DUMP_PERF("InferShape2", compute_graph->InferShapeInNeed);
 PassManager graph_pass;
@@ -3118,7 +3119,7 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) {
 GraphNodePtr graph_node = nullptr;
 Status ret = graph_manager->GetGraphNode(args.graph_id, graph_node);
 if (ret != SUCCESS) {
-ReturnError(graph_manager, args.callback, GE_GRAPH_ALREADY_RUNNING,
+ReturnError(graph_manager, args.callback, GE_GRAPH_GRAPH_NODE_NULL,
 "[RunGraph] graph not exist, graph_id=" + std::to_string(args.graph_id));
 return;
 }
@@ -3143,7 +3144,7 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) {
 graph_node->Lock();
 if (graph_node->GetRunFlag()) {
-ReturnError(graph_manager, args.callback, GE_GRAPH_GRAPH_NODE_NULL,
+ReturnError(graph_manager, args.callback, GE_GRAPH_ALREADY_RUNNING,
 "[RunGraph] graph already running, graph id=" + std::to_string(args.graph_id));
 graph_node->Unlock();
 return;
@@ -489,7 +489,7 @@ Status VarManager::UpdateVarMemSize(rtMemType_t memory_type, int64_t mem_size) {
 mem_resource = MemResource::BuildMemResourceFromType(memory_type);
 if (mem_resource == nullptr) {
 REPORT_CALL_ERROR("E19999", "memory_type:%d invalid or New MemResource fail, session_id:%lu",
-memory_type, session_id_);
+memory_type, session_id_);
 GELOGE(ge::INTERNAL_ERROR, "[Alloc][MemResource] failed, memory_type:%u, session_id:%lu",
 memory_type, session_id_);
 return ge::INTERNAL_ERROR;
@@ -275,7 +275,8 @@ Status HcomOmeUtil::GetHcclOperationType(const ge::ConstOpDescPtr &op_desc, Hccl
 "check invalid", ATTR_HOROVOD_ATTR_REDUCE_TYPE.c_str(),
 op_desc->GetName().c_str(), op_desc->GetType().c_str(), horovod_op_type);
 GELOGE(PARAM_INVALID, "[Check][Param] Attr:%s in Op:%s(%s), horovod_op_type value:%ld is not support now",
-ATTR_HOROVOD_ATTR_REDUCE_TYPE.c_str(), op_desc->GetName().c_str(), op_desc->GetType().c_str(), horovod_op_type);
+ATTR_HOROVOD_ATTR_REDUCE_TYPE.c_str(), op_desc->GetName().c_str(), op_desc->GetType().c_str(),
+horovod_op_type);
 return PARAM_INVALID;
 }
 op_type = iter->second;
@@ -155,7 +155,7 @@ Status GraphOptimize::OptimizeOriginalGraph(ComputeGraphPtr &compute_graph) {
 }
 auto graph_optimizer = instance_ptr->OpsKernelManagerObj().GetAllGraphOptimizerObjsByPriority();
-GELOGI("optimize by opskernel in original graph optimize phase. num of graph_optimizer is %lu.",
+GELOGI("optimize by opskernel in original graph optimize phase. num of graph_optimizer is %zu.",
 graph_optimizer.size());
 string exclude_core_Type = (core_type_ == kVectorCore) ? kAicoreEngine : kVectorEngine;
 GELOGD("[OptimizeOriginalGraph]: engine type will exclude: %s", exclude_core_Type.c_str());
@@ -194,7 +194,7 @@ Status GraphOptimize::OptimizeOriginalGraphJudgeInsert(ComputeGraphPtr &compute_
 }
 auto graph_optimizer = instance_ptr->OpsKernelManagerObj().GetAllGraphOptimizerObjsByPriority();
-GELOGI("optimize by opskernel in original graph optimize phase. num of graph_optimizer is %lu.",
+GELOGI("optimize by opskernel in judging insert phase. num of graph_optimizer is %zu.",
 graph_optimizer.size());
 string exclude_core_Type = (core_type_ == kVectorCore) ? kAicoreEngine : kVectorEngine;
 if (graph_optimizer.size() != 0) {
@@ -294,6 +294,46 @@ Status GraphOptimize::OptimizeGraphBeforeBuildForRts(ComputeGraphPtr &compute_gr
 return ret;
 }
+Status GraphOptimize::OptimizeAfterStage1(ComputeGraphPtr &compute_graph) {
+GE_CHECK_NOTNULL(compute_graph);
+GELOGD("OptimizeAfterStage1 in");
+if (GetContext().GetHostExecFlag()) {
+// graph exec on host, no need OptimizeAfterStage1
+return SUCCESS;
+}
+Status ret = SUCCESS;
+std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance();
+if (instance_ptr == nullptr || !instance_ptr->InitFlag()) {
+REPORT_INNER_ERROR("E19999", "Gelib not init before, check invalid");
+GELOGE(GE_CLI_GE_NOT_INITIALIZED, "OptimizeAfterStage1 failed.");
+return GE_CLI_GE_NOT_INITIALIZED;
+}
+auto graph_optimizer = instance_ptr->OpsKernelManagerObj().GetAllGraphOptimizerObjsByPriority();
+GELOGI("Optimize by ops kernel in after stage1 phase, num of graph_optimizer is %zu.", graph_optimizer.size());
+string exclude_core_type = (core_type_ == kVectorCore) ? kAicoreEngine : kVectorEngine;
+if (graph_optimizer.size() != 0) {
+for (auto iter = graph_optimizer.begin(); iter != graph_optimizer.end(); ++iter) {
+if (iter->first == exclude_core_type) {
+GELOGI("[OptimizeAfterStage1]: engine type will exclude:%s.", exclude_core_type.c_str());
+continue;
+}
+#ifndef ONLY_COMPILE_OPEN_SRC
+GELOGI("Begin to optimize graph after stage1 by engine %s.", iter->first.c_str());
+ret = (iter->second)->OptimizeAfterStage1(*compute_graph);
+#endif
+if (ret != SUCCESS) {
+REPORT_INNER_ERROR("E19999", "Call OptimizeAfterStage1 failed, ret:%d, engine_name:%s, "
+"graph_name:%s.", ret, iter->first.c_str(), compute_graph->GetName().c_str());
+GELOGE(ret, "[OptimizeAfterStage1]: graph optimize failed, ret:%d.", ret);
+return ret;
+}
+}
+}
+return ret;
+}
 Status GraphOptimize::SetOptions(const ge::GraphManagerOptions &options) {
 if (options.framework_type >= static_cast<int32_t>(domi::FrameworkType::FRAMEWORK_RESERVED)) {
 REPORT_INNER_ERROR("E19999", "Param framework_type:%d in option check invalid",
@@ -58,6 +58,9 @@ class GraphOptimize {
 // for rts optimize before build to add attr and insert memcpy op
 Status OptimizeGraphBeforeBuildForRts(ComputeGraphPtr &compute_graph);
+// optimize whole graph, using after stage1
+Status OptimizeAfterStage1(ComputeGraphPtr &graph);
 // set options
 Status SetOptions(const GraphManagerOptions &options);
@@ -1,5 +1,5 @@
 /**
-* Copyright 2020 Huawei Technologies Co., Ltd
+* Copyright 2020-2021 Huawei Technologies Co., Ltd
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
@@ -22,6 +22,7 @@
 #include "graph/optimize/graph_optimize.h"
 #include "graph/utils/graph_utils.h"
 #include "graph/utils/node_utils.h"
+#include "graph/utils/op_desc_utils.h"
 namespace {
 using namespace ge;
@@ -32,12 +33,14 @@ const int kCaseReadOnly = 0;
 const int kCaseScopeWriteable = 2;
 const int kCaseWriteable = 3;
 const int kCaseInvalidRWType = 5;
+// attr _input_mutable = true means node will modify its input in runtime
+const char *const kModifyInput = "_input_mutable";
 // rw type of input.
 enum class InputRWType {
 kReadOnly, // Normal op input only read
 kWriteable, // Op like Assign/ApplyMomentum
-kScopeWriteable, // Op like hcom_allreduce, it will modify input ,but not expect take effect on pre ouput
+kScopeWriteable, // Op like hcom_allreduce/while, it will modify input ,but not expect take effect on pre ouput
 kInvalidRWType
 };
 // rw type of output
@@ -154,7 +157,7 @@ bool IsSubgraphOutputNode(const NodePtr &node) {
 return true;
 }
-NodePtr CreateIdentityAfterSrcNode(const Node &src_node, int out_anchor_idx) {
+NodePtr AddIdentityToGraph(const Node &src_node, int out_anchor_idx) {
 if (src_node.GetOpDesc() == nullptr) {
 return nullptr;
 }
@@ -162,30 +165,19 @@ NodePtr CreateIdentityAfterSrcNode(const Node &src_node, int out_anchor_idx) {
 auto next_num = identity_num.fetch_add(1);
 // 1. create new identity op desc
 string identity_name = src_node.GetName() + "_" + IDENTITY + std::to_string(next_num);
-auto identity_opdesc = MakeShared<OpDesc>(identity_name, IDENTITY);
-if (identity_opdesc == nullptr) {
-GELOGE(OUT_OF_MEMORY, "Failed to insert identity node, name %s", identity_name.c_str());
-return nullptr;
-}
+OpDescBuilder op_desc_builder(identity_name, IDENTITY);
 auto data_desc = src_node.GetOpDesc()->GetOutputDesc(out_anchor_idx);
-// 2. add input_desc & output_desc for new identity
-Status ret = identity_opdesc->AddInputDesc("x", data_desc);
-if (ret != SUCCESS) {
-GELOGE(ret, "Add Input desc failed for new identity %s.", identity_name.c_str());
-return nullptr;
-}
-ret = identity_opdesc->AddOutputDesc("y", data_desc);
-if (ret != SUCCESS) {
-GELOGE(ret, "Add Output desc failed for new Identity %s.", identity_name.c_str());
-return nullptr;
-}
+auto identity_op_desc = op_desc_builder.AddInput("x", data_desc)
+.AddOutput("y", data_desc)
+.Build();
 GELOGI("Insert new Identity node %s.", identity_name.c_str());
 auto graph = src_node.GetOwnerComputeGraph();
 if (graph == nullptr) {
 GELOGE(GRAPH_PARAM_INVALID, "Node %s owner compute graph is null.", src_node.GetName().c_str());
 return nullptr;
 }
-return graph->AddNode(identity_opdesc);
+return graph->AddNode(identity_op_desc);
 }
 OutputRWType GetOutputRWTypeByIndex(const Node &node, uint32_t index) {
@@ -274,8 +266,6 @@ InputRWType GetInputRWTypeByIndex(const Node &node, uint32_t index) {
 // single node without sub graph
 return GetSingleNodeInputRWTypeByIndex(node, index);
 } else {
-// node with sub graph
-std::set<int> node_rw_type_set;
 auto data_node_vec = NodeUtils::GetSubgraphDataNodesByIndex(node, index);
 // get all input data node in subgraph
 std::set<int> anchor_rw_type_set;
@@ -345,12 +335,24 @@ Status MarkRWTypeForSubgraph(const ComputeGraphPtr &sub_graph) {
 auto parent_node = sub_graph->GetParentNode();
 if (pre_output_rw_type == OutputRWType::kWriteable && parent_node->GetType() != PARTITIONEDCALL) {
 // insert identity
-auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_anchor->GetIdx());
+auto identity_node = AddIdentityToGraph(*pre_node, pre_out_anchor->GetIdx());
 GE_CHECK_NOTNULL(identity_node);
-auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node);
-if (ret != SUCCESS) {
-GELOGE(ret, "Fail to insert identity");
-return ret;
+if (GraphUtils::InsertNodeAfter(pre_out_anchor, {in_data_anchor}, identity_node) != GRAPH_SUCCESS) {
+REPORT_CALL_ERROR("E19999", "Insert Identity node %s(%s) between %s(%s) -> %s(%s) failed.",
+identity_node->GetName().c_str(),
+identity_node->GetType().c_str(),
+pre_node->GetName().c_str(),
+pre_node->GetType().c_str(),
+node->GetName().c_str(),
+node->GetType().c_str());
+GELOGE(FAILED, "Insert Identity node %s(%s) between %s(%s) -> %s(%s) failed.",
+identity_node->GetName().c_str(),
+identity_node->GetType().c_str(),
+pre_node->GetName().c_str(),
+pre_node->GetType().c_str(),
+node->GetName().c_str(),
+node->GetType().c_str());
+return FAILED;
 }
 GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(),
 pre_node->GetName().c_str(), node->GetName().c_str());
@@ -505,34 +507,24 @@ Status SplitIdentityAlongAnchor(const OutDataAnchorPtr &out_data_anchor, const I
 auto peer_in_data_node = peer_in_data_anchor->GetOwnerNode();
 GE_CHECK_NOTNULL(peer_in_data_node);
 auto input_rw_type = GetInputRWTypeByIndex(*peer_in_data_node, peer_in_data_anchor->GetIdx());
-auto ret = out_data_anchor->Unlink(peer_in_data_anchor);
 auto old_identity = out_data_anchor->GetOwnerNode();
-if (ret != SUCCESS) {
-GELOGE(ret, "Failed to unlink from %s %dth out to %s.", old_identity->GetName().c_str(), out_data_anchor->GetIdx(),
-peer_in_data_anchor->GetOwnerNode()->GetName().c_str());
-return ret;
-}
 if (input_rw_type == InputRWType::kScopeWriteable || input_rw_type == InputRWType::kWriteable) {
-auto new_identity = CreateIdentityAfterSrcNode(*pre_node, pre_out_data_anchor->GetIdx());
+auto new_identity = AddIdentityToGraph(*pre_node, pre_out_data_anchor->GetIdx());
 GE_CHECK_NOTNULL(new_identity);
-if (GraphUtils::AddEdge(pre_out_data_anchor, new_identity->GetInDataAnchor(kIdentityAnchorIndex)) != SUCCESS
-|| GraphUtils::AddEdge(new_identity->GetOutDataAnchor(kIdentityAnchorIndex), peer_in_data_anchor) != SUCCESS) {
-GELOGE(INTERNAL_ERROR, "Failed to insert Identity between node %s and %s",
-pre_out_data_anchor->GetOwnerNode()->GetName().c_str(),
-peer_in_data_anchor->GetOwnerNode()->GetName().c_str());
-return INTERNAL_ERROR;
-}
-// 2. copy in-control-edge from dst to Identity
-if (GraphUtils::CopyInCtrlEdges(peer_in_data_node, new_identity) != SUCCESS) {
-GELOGE(INTERNAL_ERROR, "Failed to copy in_control edges from node %s to %s", peer_in_data_node->GetName().c_str(),
-new_identity->GetName().c_str());
-return INTERNAL_ERROR;
+auto ret = GraphUtils::InsertNodeBefore(peer_in_data_anchor, new_identity, kIdentityAnchorIndex,
+kIdentityAnchorIndex);
+if (ret != SUCCESS) {
+GELOGE(ret, "Failed to insert Identity %s before %s %dth input.",
+new_identity->GetName().c_str(),
+peer_in_data_anchor->GetOwnerNode()->GetName().c_str(),
+peer_in_data_anchor->GetIdx());
+return ret;
 }
 GELOGI("Node %s intput rw type is %s. Insert Identity between %s and %s.", peer_in_data_node->GetName().c_str(),
 InputRWTypeToSerialString(input_rw_type).c_str(), pre_out_data_anchor->GetOwnerNode()->GetName().c_str(),
 peer_in_data_anchor->GetOwnerNode()->GetName().c_str());
 } else {
+(void) out_data_anchor->Unlink(peer_in_data_anchor);
 // copy control edge to pre and peer node
 if (GraphUtils::CopyInCtrlEdges(old_identity, peer_in_data_node) != SUCCESS
 || GraphUtils::CopyOutCtrlEdges(old_identity, pre_node) != SUCCESS) {
@@ -613,16 +605,14 @@ Status InsertIdentityAsNeeded(const NodePtr &node) {
 GELOGD("No need insert Identity.");
 continue;
 case INSERT_IDENTITY:
-auto identity_node = CreateIdentityAfterSrcNode(*node, out_data_anchor->GetIdx());
-if (identity_node == nullptr) {
-GELOGE(FAILED, "Create identity node failed.");
-return FAILED;
-}
-auto ret = GraphUtils::InsertNodeBetweenDataAnchors(out_data_anchor, peer_in_data_anchor, identity_node);
-if (ret != GRAPH_SUCCESS) {
-GELOGE(INTERNAL_ERROR, "Failed to insert reshape between node %s and %s", node->GetName().c_str(),
-peer_in_node->GetName().c_str());
-return INTERNAL_ERROR;
+auto identity_node = AddIdentityToGraph(*node, out_data_anchor->GetIdx());
+GE_CHECK_NOTNULL(identity_node);
+auto ret = GraphUtils::InsertNodeBefore(peer_in_data_anchor, identity_node, kIdentityAnchorIndex,
+kIdentityAnchorIndex);
+if (ret != SUCCESS) {
+GELOGE(ret, "Fail to insert %s before %s %dth input.", identity_node->GetName().c_str(),
+peer_in_data_anchor->GetOwnerNode()->GetName().c_str(), peer_in_data_anchor->GetIdx());
+return ret;
 }
 GELOGI("Insert Identity between %s and %s to handle memory conflict.", node->GetName().c_str(),
 peer_in_node->GetName().c_str());
@@ -633,28 +623,35 @@ Status InsertIdentityAsNeeded(const NodePtr &node) { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status HandleAllreduceDuplicateInput(ComputeGraphPtr &compute_graph) { | Status HandleAllreduceDuplicateInput(ComputeGraphPtr &compute_graph) { | ||||
for (const auto &node : compute_graph->GetDirectNode()) { | |||||
if (node->GetType() == HCOMALLREDUCE) { | |||||
std::set<OutDataAnchorPtr> pre_out_anchor_set; | |||||
for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||||
auto pre_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||||
GE_CHECK_NOTNULL(pre_out_anchor); | |||||
if (pre_out_anchor_set.find(pre_out_anchor) == pre_out_anchor_set.end()) { | |||||
pre_out_anchor_set.emplace(pre_out_anchor); | |||||
continue; | |||||
} | |||||
// need insert identity | |||||
auto pre_node = pre_out_anchor->GetOwnerNode(); | |||||
auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_anchor->GetIdx()); | |||||
GE_CHECK_NOTNULL(identity_node); | |||||
auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node); | |||||
GE_CHK_STATUS_RET(ret, "Fail to insert identity."); | |||||
GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(), | |||||
pre_node->GetName().c_str(), node->GetName().c_str()); | |||||
} | |||||
} | |||||
} | |||||
return SUCCESS; | |||||
for (const auto &node : compute_graph->GetDirectNode()) { | |||||
bool mutable_input_flag = false; | |||||
(void)AttrUtils::GetBool(node->GetOpDesc(), kModifyInput, mutable_input_flag); | |||||
if (!mutable_input_flag) { | |||||
continue; | |||||
} | |||||
std::set<OutDataAnchorPtr> pre_out_anchor_set; | |||||
for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||||
auto pre_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||||
GE_CHECK_NOTNULL(pre_out_anchor); | |||||
if (pre_out_anchor_set.insert(pre_out_anchor).second) { | |||||
continue; | |||||
} | |||||
// need to insert an Identity node
auto pre_node = pre_out_anchor->GetOwnerNode(); | |||||
auto identity_node = AddIdentityToGraph(*pre_node, pre_out_anchor->GetIdx()); | |||||
GE_CHECK_NOTNULL(identity_node); | |||||
auto ret = | |||||
GraphUtils::InsertNodeBefore(in_data_anchor, identity_node, kIdentityAnchorIndex, kIdentityAnchorIndex); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "Failed to insert node %s before %s %dth input.", identity_node->GetName().c_str(), | |||||
node->GetName().c_str(), in_data_anchor->GetIdx()); | |||||
return ret; | |||||
} | |||||
GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(), | |||||
pre_node->GetName().c_str(), node->GetName().c_str()); | |||||
} | |||||
} | |||||
return SUCCESS; | |||||
} | } | ||||
} // namespace | } // namespace | ||||
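The rewritten loop above hinges on the return value of std::set::insert: the .second member of the returned pair is true only for the first occurrence of a given source anchor, so every repeated occurrence marks a duplicate input that needs an Identity node. A minimal, self-contained sketch of that idiom (toy int anchor ids stand in for the GE anchor pointers; an illustration, not the pass itself):

#include <iostream>
#include <set>
#include <vector>

int main() {
  // Hypothetical anchor ids feeding one HCOM node; 3 and 7 appear twice.
  std::vector<int> peer_out_anchors = {3, 7, 3, 9, 7};
  std::set<int> pre_out_anchor_set;
  for (int anchor : peer_out_anchors) {
    if (pre_out_anchor_set.insert(anchor).second) {
      continue;  // first time this source anchor is seen
    }
    // Duplicate input detected; the pass would insert an Identity node here.
    std::cout << "duplicate input from anchor " << anchor << std::endl;
  }
  return 0;
}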
@@ -387,6 +387,9 @@ void DynamicShapePartitioner::MergeClustersUnknownShape() { | |||||
if (!in_cluster->IsUnknownShape()) { | if (!in_cluster->IsUnknownShape()) { | ||||
continue; | continue; | ||||
} | } | ||||
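// Only merge in_cluster when it directly feeds this cluster.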
if (!cluster->IsAdjoinNodes(in_cluster)) { | |||||
continue; | |||||
} | |||||
auto merged_clusters = cluster->MergeAllPathFrom(in_cluster); | auto merged_clusters = cluster->MergeAllPathFrom(in_cluster); | ||||
GELOGD("Merge all path cluster from %lu to %lu %s.", in_cluster->Id(), cluster->Id(), | GELOGD("Merge all path cluster from %lu to %lu %s.", in_cluster->Id(), cluster->Id(), | ||||
ToString(merged_clusters).c_str()); | ToString(merged_clusters).c_str()); | ||||
@@ -80,6 +80,10 @@ class DynamicShapePartitioner { | |||||
Status BuildPartitionSubgraph(); | Status BuildPartitionSubgraph(); | ||||
// Clear resource and break circular dependency | // Clear resource and break circular dependency | ||||
void Clear(); | void Clear(); | ||||
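// Returns true when this cluster is a direct output (successor) of |other|.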
bool IsAdjoinNodes(const std::shared_ptr<Cluster> &other) const { | |||||
const auto &out_clusters = other->out_clusters_; | |||||
return std::find(out_clusters.begin(), out_clusters.end(), shared_from_this()) != out_clusters.end(); | |||||
} | |||||
private: | private: | ||||
static thread_local size_t unique_id_; | static thread_local size_t unique_id_; | ||||
@@ -451,7 +451,7 @@ Status AtomicAddrCleanPass::CompileUnknownGraphOp(const vector<NodePtr> &atomic_ | |||||
GE_TIMESTAMP_ADD(UnknownGraphCompileOp); | GE_TIMESTAMP_ADD(UnknownGraphCompileOp); | ||||
if (ret != ge::SUCCESS) { | if (ret != ge::SUCCESS) { | ||||
REPORT_CALL_ERROR("E19999", "Call CompileOp failed, kernel_lib_name:%s, ret:%d", | REPORT_CALL_ERROR("E19999", "Call CompileOp failed, kernel_lib_name:%s, ret:%d", | ||||
kernel_lib_name.c_str(), ret); | |||||
kernel_lib_name.c_str(), ret); | |||||
GELOGE(ret, "Compile atomic op failed, kernel lib name is %s", kernel_lib_name.c_str()); | GELOGE(ret, "Compile atomic op failed, kernel lib name is %s", kernel_lib_name.c_str()); | ||||
return ret; | return ret; | ||||
} | } | ||||
@@ -29,7 +29,8 @@ Status AttachStreamLabelPass::Run(ComputeGraphPtr graph) { | |||||
std::map<NodePtr, NodePtr> branch_head_nodes; | std::map<NodePtr, NodePtr> branch_head_nodes; | ||||
FindNodes(graph, need_label_nodes, enter_nodes, branch_head_nodes); | FindNodes(graph, need_label_nodes, enter_nodes, branch_head_nodes); | ||||
for (const auto &node : need_label_nodes) { | for (const auto &node : need_label_nodes) { | ||||
GE_CHK_STATUS_RET(UpdateCondBranch(node, branch_head_nodes), "Update cond branch failed, start node:%s.", node->GetName().c_str()); | |||||
GE_CHK_STATUS_RET(UpdateCondBranch(node, branch_head_nodes), "Update cond branch failed, start node:%s.", | |||||
node->GetName().c_str()); | |||||
} | } | ||||
GE_CHK_STATUS_RET(UpdateEnterNode(enter_nodes), "UpdateEnterNode failed."); | GE_CHK_STATUS_RET(UpdateEnterNode(enter_nodes), "UpdateEnterNode failed."); | ||||
@@ -62,7 +62,7 @@ Status BitcastPass::CheckDstDataType(const OpDescPtr op_desc, ge::DataType &dst_ | |||||
if (!ge::AttrUtils::GetDataType(op_desc, kAttrNameType, dst_data_type)) { | if (!ge::AttrUtils::GetDataType(op_desc, kAttrNameType, dst_data_type)) { | ||||
REPORT_CALL_ERROR("E19999", "Get Attr:%s of op:%s(%s) failed", | REPORT_CALL_ERROR("E19999", "Get Attr:%s of op:%s(%s) failed", | ||||
kAttrNameType, op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||||
kAttrNameType, op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||||
GELOGE(PARAM_INVALID, "Node failed to get attribute type."); | GELOGE(PARAM_INVALID, "Node failed to get attribute type."); | ||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
@@ -166,7 +166,7 @@ Status CondPass::GetCondInfoForWhile(const NodePtr &node, ComputeGraphPtr &graph | |||||
if (iter == subgraph_names_to_index.end()) { | if (iter == subgraph_names_to_index.end()) { | ||||
REPORT_INNER_ERROR("E19999", "subgraph name:%s not exist in SubgraphNameIndexes map of op:%s(%s), " | REPORT_INNER_ERROR("E19999", "subgraph name:%s not exist in SubgraphNameIndexes map of op:%s(%s), " | ||||
"check invalid", ATTR_NAME_WHILE_COND.c_str(), | "check invalid", ATTR_NAME_WHILE_COND.c_str(), | ||||
op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||||
op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||||
GELOGE(FAILED, "Get cond_graph index failed, while_node:%s.", node->GetName().c_str()); | GELOGE(FAILED, "Get cond_graph index failed, while_node:%s.", node->GetName().c_str()); | ||||
return FAILED; | return FAILED; | ||||
} | } | ||||
@@ -65,13 +65,13 @@ Status CtrlEdgeTransferPass::Run(ge::ComputeGraphPtr graph) { | |||||
for (auto &in_control_node : n->GetInControlNodes()) { | for (auto &in_control_node : n->GetInControlNodes()) { | ||||
GE_CHECK_NOTNULL(in_control_node); | GE_CHECK_NOTNULL(in_control_node); | ||||
GE_CHK_GRAPH_STATUS_RET(ge::GraphUtils::RemoveEdge(in_control_node->GetOutControlAnchor(), | GE_CHK_GRAPH_STATUS_RET(ge::GraphUtils::RemoveEdge(in_control_node->GetOutControlAnchor(), | ||||
n->GetInControlAnchor()), "remove edge failed"); | |||||
n->GetInControlAnchor()), "remove edge failed"); | |||||
for (auto &out_node : n->GetOutNodes()) { | for (auto &out_node : n->GetOutNodes()) { | ||||
if (out_node == nullptr) { | if (out_node == nullptr) { | ||||
continue; | continue; | ||||
} | } | ||||
GE_CHK_GRAPH_STATUS_RET(ge::GraphUtils::AddEdge(in_control_node->GetOutControlAnchor(), | GE_CHK_GRAPH_STATUS_RET(ge::GraphUtils::AddEdge(in_control_node->GetOutControlAnchor(), | ||||
out_node->GetInControlAnchor()), "add edge failed."); | |||||
out_node->GetInControlAnchor()), "add edge failed."); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -24,11 +24,12 @@ | |||||
#include "common/ge/ge_util.h" | #include "common/ge/ge_util.h" | ||||
#include "framework/common/types.h" | #include "framework/common/types.h" | ||||
#include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
#include "graph/utils/op_desc_utils.h" | |||||
namespace { | namespace { | ||||
const int kAnchorNum = 0; | |||||
const int32_t kAnchorAssignRefIndex = 0; | const int32_t kAnchorAssignRefIndex = 0; | ||||
const int32_t kAnchorAssignValueIndex = 1; | const int32_t kAnchorAssignValueIndex = 1; | ||||
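// Identity has a single data input "x" and output "y", both at anchor index 0.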
const int32_t kAnchorIdentityIndex = 0; | |||||
} // namespace | } // namespace | ||||
namespace ge { | namespace ge { | ||||
Status HcclContinuousMemcpyPass::Run(ge::ComputeGraphPtr graph) { | Status HcclContinuousMemcpyPass::Run(ge::ComputeGraphPtr graph) { | ||||
@@ -161,41 +162,23 @@ NodePtr HcclContinuousMemcpyPass::CreateIdentityNode(const ComputeGraphPtr &grap | |||||
std::string node_name = pre_node->GetName() + "_" + IDENTITY; | std::string node_name = pre_node->GetName() + "_" + IDENTITY; | ||||
node_name = CheckDuplicateName(node_name); | node_name = CheckDuplicateName(node_name); | ||||
OpDescPtr op_desc = MakeShared<OpDesc>(node_name.c_str(), IDENTITY); | |||||
if (op_desc == nullptr) { | |||||
REPORT_CALL_ERROR("E19999", "New OpDesc failed"); | |||||
GELOGE(INTERNAL_ERROR, "Create Identity op: MakeShared op_desc fail."); | |||||
return nullptr; | |||||
} | |||||
GELOGI("Create Identity op:%s.", op_desc->GetName().c_str()); | |||||
graphStatus ret = op_desc->AddInputDesc("x", pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())); | |||||
if (ret != GRAPH_SUCCESS) { | |||||
REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed", | |||||
op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||||
GELOGE(INTERNAL_ERROR, "Create Identity op: add input desc fail."); | |||||
return nullptr; | |||||
} | |||||
ret = op_desc->AddOutputDesc("y", pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())); | |||||
if (ret != GRAPH_SUCCESS) { | |||||
REPORT_CALL_ERROR("E19999", "Add output desc to op:%s(%s) failed", | |||||
op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||||
GELOGE(INTERNAL_ERROR, "Create Identity op: add output desc fail."); | |||||
OpDescBuilder op_desc_builder(node_name, IDENTITY); | |||||
auto data_desc = pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx()); | |||||
auto identity_op_desc = op_desc_builder.AddInput("x", data_desc).AddOutput("y", data_desc).Build(); | |||||
if (identity_op_desc == nullptr) { | |||||
return nullptr; | return nullptr; | ||||
} | } | ||||
// For historical reasons, this pass cannot run after constant folding, so mark the node accordingly.
(void)AttrUtils::SetBool(op_desc, ATTR_NO_NEED_CONSTANT_FOLDING, false); | |||||
(void)AttrUtils::SetBool(identity_op_desc, ATTR_NO_NEED_CONSTANT_FOLDING, false); | |||||
NodePtr memcpy_node = graph->AddNode(op_desc); | |||||
if (memcpy_node == nullptr) { | |||||
NodePtr identity_node = graph->AddNode(identity_op_desc); | |||||
if (identity_node == nullptr) { | |||||
REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", | REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", | ||||
op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str()); | |||||
identity_op_desc->GetName().c_str(), identity_op_desc->GetType().c_str(), graph->GetName().c_str());
GELOGE(INTERNAL_ERROR, "Insert Identity node fail."); | GELOGE(INTERNAL_ERROR, "Insert Identity node fail."); | ||||
return nullptr; | return nullptr; | ||||
} | } | ||||
return memcpy_node; | |||||
return identity_node; | |||||
} | } | ||||
/// | /// | ||||
@@ -256,50 +239,24 @@ Status HcclContinuousMemcpyPass::ModifyEdgeConnection(const ComputeGraphPtr &gra | |||||
Status HcclContinuousMemcpyPass::InsertIdentityBeforeHccl(const ComputeGraphPtr &graph, | Status HcclContinuousMemcpyPass::InsertIdentityBeforeHccl(const ComputeGraphPtr &graph, | ||||
const OutDataAnchorPtr &src_out_anchor, | const OutDataAnchorPtr &src_out_anchor, | ||||
const InDataAnchorPtr &hccl_in_anchor) { | const InDataAnchorPtr &hccl_in_anchor) { | ||||
GELOGI("Between op %s and op %s need insert memcpy async op.", src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||||
GELOGI("Between op %s and op %s need insert identity op.", src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||||
hccl_in_anchor->GetOwnerNode()->GetName().c_str()); | hccl_in_anchor->GetOwnerNode()->GetName().c_str()); | ||||
NodePtr memcpy_node = CreateIdentityNode(graph, src_out_anchor); | |||||
GE_CHECK_NOTNULL(memcpy_node); | |||||
NodePtr identity_node = CreateIdentityNode(graph, src_out_anchor); | |||||
GE_CHECK_NOTNULL(identity_node); | |||||
Status ret1 = src_out_anchor->Unlink(hccl_in_anchor); | |||||
if (ret1 != SUCCESS) { | |||||
REPORT_CALL_ERROR("E19999", | |||||
"Op:%s(%s) out index:%d unlink from op:%s(%s) in index:%d failed", | |||||
src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||||
src_out_anchor->GetOwnerNode()->GetType().c_str(), src_out_anchor->GetIdx(), | |||||
hccl_in_anchor->GetOwnerNode()->GetName().c_str(), | |||||
hccl_in_anchor->GetOwnerNode()->GetType().c_str(), | |||||
hccl_in_anchor->GetIdx()); | |||||
GELOGE(INTERNAL_ERROR, "The op %s Unlink anchor %s fail.", src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||||
hccl_in_anchor->GetOwnerNode()->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
auto out_data_anchor_0 = memcpy_node->GetOutDataAnchor(kAnchorNum); | |||||
GE_CHECK_NOTNULL(out_data_anchor_0); | |||||
ret1 = out_data_anchor_0->LinkTo(hccl_in_anchor); | |||||
if (ret1 != SUCCESS) { | |||||
auto ret = GraphUtils::InsertNodeBefore(hccl_in_anchor, identity_node, kAnchorIdentityIndex, kAnchorIdentityIndex); | |||||
if (ret != SUCCESS) { | |||||
REPORT_CALL_ERROR("E19999", | REPORT_CALL_ERROR("E19999", | ||||
"Op:%s(%s) out index:%d link to op:%s(%s) in index:%d failed", | |||||
out_data_anchor_0->GetOwnerNode()->GetName().c_str(), | |||||
out_data_anchor_0->GetOwnerNode()->GetType().c_str(), out_data_anchor_0->GetIdx(), | |||||
"Op:Fail to insert %s(%s) before %s(%s) on index:%d input anchor.", | |||||
identity_node->GetName().c_str(), identity_node->GetType().c_str(), | |||||
hccl_in_anchor->GetOwnerNode()->GetName().c_str(), | hccl_in_anchor->GetOwnerNode()->GetName().c_str(), | ||||
hccl_in_anchor->GetOwnerNode()->GetType().c_str(), | hccl_in_anchor->GetOwnerNode()->GetType().c_str(), | ||||
hccl_in_anchor->GetIdx()); | hccl_in_anchor->GetIdx()); | ||||
GELOGE(INTERNAL_ERROR, "The op %s link anchor %s fail.", memcpy_node->GetName().c_str(), | |||||
hccl_in_anchor->GetOwnerNode()->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
Status ret = src_out_anchor->LinkTo(memcpy_node->GetInDataAnchor(kAnchorNum)); | |||||
if (ret != SUCCESS) { | |||||
REPORT_CALL_ERROR("E19999", | |||||
"Op:%s(%s) out index:%d link to op:%s(%s) in index:%u failed", | |||||
src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||||
src_out_anchor->GetOwnerNode()->GetType().c_str(), src_out_anchor->GetIdx(), | |||||
memcpy_node->GetName().c_str(), memcpy_node->GetType().c_str(), | |||||
kAnchorNum); | |||||
GELOGE(INTERNAL_ERROR, "The op %s link anchor %s fail.", src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||||
memcpy_node->GetName().c_str()); | |||||
GELOGE(INTERNAL_ERROR, "Fail to insert %s(%s) before %s(%s) on index:%d input anchor.", | |||||
identity_node->GetName().c_str(), identity_node->GetType().c_str(), | |||||
hccl_in_anchor->GetOwnerNode()->GetName().c_str(), | |||||
hccl_in_anchor->GetOwnerNode()->GetType().c_str(), | |||||
hccl_in_anchor->GetIdx()); | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
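Both memcpy passes now delegate the edge rewiring to GraphUtils::InsertNodeBefore instead of the hand-rolled Unlink/LinkTo sequence removed above. The net effect of that rewiring, sketched with toy anchor structs (an assumption-level simplification, not the GE anchor API):

#include <cassert>

// Toy single-peer anchors; real GE anchors hold peer lists and owner nodes.
struct Anchor { Anchor *peer = nullptr; };

// Re-route src -> dst into src -> identity -> dst, as InsertNodeBefore does.
void InsertBefore(Anchor &src_out, Anchor &dst_in, Anchor &id_in, Anchor &id_out) {
  src_out.peer = &id_in;   // unlink src from dst, relink to the new node input
  id_in.peer = &src_out;
  id_out.peer = &dst_in;   // the new node output now feeds the old destination
  dst_in.peer = &id_out;
}

int main() {
  Anchor src_out, dst_in, id_in, id_out;
  src_out.peer = &dst_in;
  dst_in.peer = &src_out;
  InsertBefore(src_out, dst_in, id_in, id_out);
  assert(src_out.peer == &id_in && dst_in.peer == &id_out && id_out.peer == &dst_in);
  return 0;
}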
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -24,13 +24,15 @@ | |||||
#include "common/ge/ge_util.h" | #include "common/ge/ge_util.h" | ||||
#include "framework/common/types.h" | #include "framework/common/types.h" | ||||
#include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
#include "graph/utils/op_desc_utils.h" | |||||
namespace { | namespace { | ||||
const int32_t kAnchorSize = 1; | const int32_t kAnchorSize = 1; | ||||
const int kAnchorNum = 0; | |||||
const int32_t kAnchorAssignRefIndex = 0; | const int32_t kAnchorAssignRefIndex = 0; | ||||
const int32_t kAnchorAssignValueIndex = 1; | const int32_t kAnchorAssignValueIndex = 1; | ||||
const char *const kInputMutable = "_input_mutable"; | |||||
const int32_t kAnchorIdentityIndex = 0; | |||||
// attr _input_mutable = true means the hccl node will modify its input at runtime
const char *const kModifyInput = "_input_mutable"; | |||||
} // namespace | } // namespace | ||||
namespace ge { | namespace ge { | ||||
Status HcclMemcpyPass::Run(ge::ComputeGraphPtr graph) { | Status HcclMemcpyPass::Run(ge::ComputeGraphPtr graph) { | ||||
@@ -58,24 +60,13 @@ Status HcclMemcpyPass::Run(ge::ComputeGraphPtr graph) { | |||||
// need to insert a memcpy node between them.
// This also covers the case where the input is a variable or const.
Status HcclMemcpyPass::MutableInputProcess(const ComputeGraphPtr &graph, const NodePtr node) { | Status HcclMemcpyPass::MutableInputProcess(const ComputeGraphPtr &graph, const NodePtr node) { | ||||
auto op_desc = node->GetOpDesc(); | |||||
bool node_input_mutable = false; | bool node_input_mutable = false; | ||||
if (!AttrUtils::HasAttr(op_desc, kInputMutable)) { | |||||
return SUCCESS; | |||||
} | |||||
if (!AttrUtils::GetBool(op_desc, kInputMutable, node_input_mutable)) { | |||||
REPORT_CALL_ERROR("E19999", "Get Attr:%s from op:%s(%s) failed", kInputMutable, | |||||
op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||||
GELOGE(INTERNAL_ERROR, "node:%s get attr:_input_mutable failed.", node->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
(void)AttrUtils::GetBool(node->GetOpDesc(), kModifyInput, node_input_mutable); | |||||
if (!node_input_mutable) { | if (!node_input_mutable) { | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
GELOGI("input mutable hcom op is:%s.", op_desc->GetName().c_str()); | |||||
GELOGI("input mutable hcom op is:%s.", node->GetName().c_str()); | |||||
for (auto &hccl_in_anchor : node->GetAllInDataAnchors()) { | for (auto &hccl_in_anchor : node->GetAllInDataAnchors()) { | ||||
if (hccl_in_anchor == nullptr) { | if (hccl_in_anchor == nullptr) { | ||||
continue; | continue; | ||||
@@ -127,41 +118,23 @@ NodePtr HcclMemcpyPass::CreateIdentityNode(const ComputeGraphPtr &graph, const O | |||||
std::string node_name = pre_node->GetName() + "_" + IDENTITY; | std::string node_name = pre_node->GetName() + "_" + IDENTITY; | ||||
node_name = CheckDuplicateName(node_name); | node_name = CheckDuplicateName(node_name); | ||||
OpDescPtr op_desc = MakeShared<OpDesc>(node_name.c_str(), IDENTITY); | |||||
if (op_desc == nullptr) { | |||||
REPORT_CALL_ERROR("E19999", "New OpDesc failed"); | |||||
GELOGE(INTERNAL_ERROR, "Create Identity op: MakeShared op_desc fail."); | |||||
return nullptr; | |||||
} | |||||
GELOGI("Create Identity op:%s.", op_desc->GetName().c_str()); | |||||
graphStatus ret = op_desc->AddInputDesc("x", pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())); | |||||
if (ret != GRAPH_SUCCESS) { | |||||
REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed, name:x", | |||||
op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||||
GELOGE(INTERNAL_ERROR, "Create Identity op: add input desc fail."); | |||||
return nullptr; | |||||
} | |||||
ret = op_desc->AddOutputDesc("y", pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())); | |||||
if (ret != GRAPH_SUCCESS) { | |||||
REPORT_CALL_ERROR("E19999", "Add output desc to op:%s(%s) failed, name:y", | |||||
op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||||
GELOGE(INTERNAL_ERROR, "Create Identity op: add output desc fail."); | |||||
OpDescBuilder op_desc_builder(node_name, IDENTITY); | |||||
auto data_desc = pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx()); | |||||
auto identity_op_desc = op_desc_builder.AddInput("x", data_desc).AddOutput("y", data_desc).Build(); | |||||
if (identity_op_desc == nullptr) { | |||||
return nullptr; | return nullptr; | ||||
} | } | ||||
// For historical reasons, this pass cannot run after constant folding, so mark the node accordingly.
(void)AttrUtils::SetBool(op_desc, ATTR_NO_NEED_CONSTANT_FOLDING, false); | |||||
(void)AttrUtils::SetBool(identity_op_desc, ATTR_NO_NEED_CONSTANT_FOLDING, false); | |||||
NodePtr memcpy_node = graph->AddNode(op_desc); | |||||
if (memcpy_node == nullptr) { | |||||
NodePtr identity_node = graph->AddNode(identity_op_desc); | |||||
if (identity_node == nullptr) { | |||||
REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", | REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", | ||||
op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str()); | |||||
identity_op_desc->GetName().c_str(), identity_op_desc->GetType().c_str(), graph->GetName().c_str());
GELOGE(INTERNAL_ERROR, "Insert Identity node fail."); | GELOGE(INTERNAL_ERROR, "Insert Identity node fail."); | ||||
return nullptr; | return nullptr; | ||||
} | } | ||||
return memcpy_node; | |||||
return identity_node; | |||||
} | } | ||||
/// | /// | ||||
@@ -220,49 +193,24 @@ Status HcclMemcpyPass::ModifyEdgeConnection(const ComputeGraphPtr &graph, const | |||||
/// | /// | ||||
Status HcclMemcpyPass::InsertIdentityBeforeHccl(const ComputeGraphPtr &graph, const OutDataAnchorPtr &src_out_anchor, | Status HcclMemcpyPass::InsertIdentityBeforeHccl(const ComputeGraphPtr &graph, const OutDataAnchorPtr &src_out_anchor, | ||||
const InDataAnchorPtr &hccl_in_anchor) { | const InDataAnchorPtr &hccl_in_anchor) { | ||||
GELOGI("Between op %s and op %s need insert memcpy async op.", src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||||
GELOGI("Between op %s and op %s need insert identity op.", src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||||
hccl_in_anchor->GetOwnerNode()->GetName().c_str()); | hccl_in_anchor->GetOwnerNode()->GetName().c_str()); | ||||
NodePtr memcpy_node = CreateIdentityNode(graph, src_out_anchor); | |||||
GE_CHECK_NOTNULL(memcpy_node); | |||||
NodePtr identity_node = CreateIdentityNode(graph, src_out_anchor); | |||||
GE_CHECK_NOTNULL(identity_node); | |||||
Status ret1 = src_out_anchor->Unlink(hccl_in_anchor); | |||||
if (ret1 != SUCCESS) { | |||||
REPORT_CALL_ERROR("E19999", | |||||
"Op:%s(%s) out index:%d unlink from op:%s(%s) in index:%d failed", | |||||
src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||||
src_out_anchor->GetOwnerNode()->GetType().c_str(), src_out_anchor->GetIdx(), | |||||
hccl_in_anchor->GetOwnerNode()->GetName().c_str(), | |||||
hccl_in_anchor->GetOwnerNode()->GetType().c_str(), hccl_in_anchor->GetIdx()); | |||||
GELOGE(INTERNAL_ERROR, "The op %s Unlink anchor %s fail.", src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||||
hccl_in_anchor->GetOwnerNode()->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
auto out_data_anchor_0 = memcpy_node->GetOutDataAnchor(kAnchorNum); | |||||
GE_CHECK_NOTNULL(out_data_anchor_0); | |||||
ret1 = out_data_anchor_0->LinkTo(hccl_in_anchor); | |||||
if (ret1 != SUCCESS) { | |||||
auto ret = GraphUtils::InsertNodeBefore(hccl_in_anchor, identity_node, kAnchorIdentityIndex, kAnchorIdentityIndex); | |||||
if (ret != SUCCESS) { | |||||
REPORT_CALL_ERROR("E19999", | REPORT_CALL_ERROR("E19999", | ||||
"Op:%s(%s) out index:%d link to op:%s(%s) in index:%d failed", | |||||
out_data_anchor_0->GetOwnerNode()->GetName().c_str(), | |||||
out_data_anchor_0->GetOwnerNode()->GetType().c_str(), out_data_anchor_0->GetIdx(), | |||||
"Op:Fail to insert %s(%s) before %s(%s) on index:%d input anchor.", | |||||
identity_node->GetName().c_str(), identity_node->GetType().c_str(), | |||||
hccl_in_anchor->GetOwnerNode()->GetName().c_str(), | hccl_in_anchor->GetOwnerNode()->GetName().c_str(), | ||||
hccl_in_anchor->GetOwnerNode()->GetType().c_str(), | hccl_in_anchor->GetOwnerNode()->GetType().c_str(), | ||||
hccl_in_anchor->GetIdx()); | hccl_in_anchor->GetIdx()); | ||||
GELOGE(INTERNAL_ERROR, "The op %s link anchor %s fail.", memcpy_node->GetName().c_str(), | |||||
hccl_in_anchor->GetOwnerNode()->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
Status ret = src_out_anchor->LinkTo(memcpy_node->GetInDataAnchor(kAnchorNum)); | |||||
if (ret != SUCCESS) { | |||||
REPORT_CALL_ERROR("E19999", | |||||
"Op:%s(%s) out index:%d link to op:%s(%s) in index:%u failed", | |||||
src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||||
src_out_anchor->GetOwnerNode()->GetType().c_str(), src_out_anchor->GetIdx(), | |||||
memcpy_node->GetName().c_str(), memcpy_node->GetType().c_str(), | |||||
kAnchorNum); | |||||
GELOGE(INTERNAL_ERROR, "The op %s link anchor %s fail.", src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||||
memcpy_node->GetName().c_str()); | |||||
GELOGE(INTERNAL_ERROR, "Fail to insert %s(%s) before %s(%s) on index:%d input anchor.", | |||||
identity_node->GetName().c_str(), identity_node->GetType().c_str(), | |||||
hccl_in_anchor->GetOwnerNode()->GetName().c_str(), | |||||
hccl_in_anchor->GetOwnerNode()->GetType().c_str(), | |||||
hccl_in_anchor->GetIdx()); | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -340,13 +288,13 @@ Status HcclMemcpyPass::InsertAssignAfterBroadcastIfNeed(const ComputeGraphPtr &g | |||||
} | } | ||||
ret = assign_out_control_anchor->LinkTo(in_data_anchor->GetOwnerNode()->GetInControlAnchor()); | ret = assign_out_control_anchor->LinkTo(in_data_anchor->GetOwnerNode()->GetInControlAnchor()); | ||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
REPORT_CALL_ERROR("E19999", | |||||
"Op:%s(%s) out index:%d link to op:%s(%s) in index:%d failed", | |||||
assign_out_control_anchor->GetOwnerNode()->GetName().c_str(), | |||||
assign_out_control_anchor->GetOwnerNode()->GetType().c_str(), assign_out_control_anchor->GetIdx(), | |||||
in_data_anchor->GetOwnerNode()->GetName().c_str(), | |||||
in_data_anchor->GetOwnerNode()->GetType().c_str(), | |||||
in_data_anchor->GetIdx()); | |||||
REPORT_CALL_ERROR("E19999", "Op:%s(%s) out index:%d link to op:%s(%s) in index:%d failed", | |||||
assign_out_control_anchor->GetOwnerNode()->GetName().c_str(), | |||||
assign_out_control_anchor->GetOwnerNode()->GetType().c_str(), | |||||
assign_out_control_anchor->GetIdx(), | |||||
in_data_anchor->GetOwnerNode()->GetName().c_str(), | |||||
in_data_anchor->GetOwnerNode()->GetType().c_str(), | |||||
in_data_anchor->GetIdx()); | |||||
GELOGE(INTERNAL_ERROR, "The op %s link control anchor %s fail.", | GELOGE(INTERNAL_ERROR, "The op %s link control anchor %s fail.", | ||||
assign_out_control_anchor->GetOwnerNode()->GetName().c_str(), | assign_out_control_anchor->GetOwnerNode()->GetName().c_str(), | ||||
in_data_anchor->GetOwnerNode()->GetName().c_str()); | in_data_anchor->GetOwnerNode()->GetName().c_str()); | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -78,8 +78,6 @@ Status InplaceSupportCheckPass::Run(NodePtr &node) { | |||||
AddRePassNode(node); | AddRePassNode(node); | ||||
break; | break; | ||||
} | } | ||||
GELOGD("InplaceSupportCheckPass success"); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
} // namespace ge | } // namespace ge |
@@ -29,7 +29,7 @@ Status MarkGraphUnknownStatusPass::Run(ComputeGraphPtr graph) { | |||||
bool forced_unknown = false; | bool forced_unknown = false; | ||||
for (const auto &node : graph->GetDirectNode()) { | for (const auto &node : graph->GetDirectNode()) { | ||||
GE_CHK_GRAPH_STATUS_RET(ge::NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown_shape), | GE_CHK_GRAPH_STATUS_RET(ge::NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown_shape), | ||||
"Get node[%s] shape status failed!", node->GetName().c_str()); | |||||
"Get node[%s] shape status failed!", node->GetName().c_str()); | |||||
if (is_unknown_shape) { | if (is_unknown_shape) { | ||||
break; | break; | ||||
} | } | ||||
@@ -1020,7 +1020,7 @@ Status MultiBatchClonePass::SetShapeToData(const std::vector<int64_t> &shapes, c | |||||
if (!IsGetNextType(data)) { | if (!IsGetNextType(data)) { | ||||
if (NodeUtils::UpdateInputShape(*data, kDataInIndex, data_shape) != GRAPH_SUCCESS) { | if (NodeUtils::UpdateInputShape(*data, kDataInIndex, data_shape) != GRAPH_SUCCESS) { | ||||
REPORT_CALL_ERROR("E19999", "Update input desc shape to op:%s(%s) failed, index:%u", | REPORT_CALL_ERROR("E19999", "Update input desc shape to op:%s(%s) failed, index:%u", | ||||
data->GetName().c_str(), data->GetType().c_str(), kDataInIndex); | |||||
data->GetName().c_str(), data->GetType().c_str(), kDataInIndex); | |||||
GELOGE(INTERNAL_ERROR, "Failed to update input shape for data %s", data->GetName().c_str()); | GELOGE(INTERNAL_ERROR, "Failed to update input shape for data %s", data->GetName().c_str()); | ||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
@@ -759,7 +759,7 @@ Status MultiBatchPass::AttachStreamLabel(uint32_t batch_idx, const std::string & | |||||
GELOGD("Attach stream_label %s to node %s.", stream_label.c_str(), cur_desc->GetName().c_str()); | GELOGD("Attach stream_label %s to node %s.", stream_label.c_str(), cur_desc->GetName().c_str()); | ||||
if (SetStreamLabel(cur_node, stream_label) != SUCCESS) { | if (SetStreamLabel(cur_node, stream_label) != SUCCESS) { | ||||
REPORT_CALL_ERROR("E19999", "Set stream_label:%s to op:%s(%s) failed", | REPORT_CALL_ERROR("E19999", "Set stream_label:%s to op:%s(%s) failed", | ||||
stream_label.c_str(), cur_node->GetName().c_str(), cur_node->GetType().c_str()); | |||||
stream_label.c_str(), cur_node->GetName().c_str(), cur_node->GetType().c_str()); | |||||
GELOGE(FAILED, "Set stream_label failed, node:%s.", cur_node->GetName().c_str()); | GELOGE(FAILED, "Set stream_label failed, node:%s.", cur_node->GetName().c_str()); | ||||
return FAILED; | return FAILED; | ||||
} | } | ||||
@@ -73,7 +73,7 @@ Status NetOutputPass::GetRetvalOutputInfo(const ge::NodePtr &node, | |||||
if (iter != targets_.end()) { | if (iter != targets_.end()) { | ||||
targets_.erase(iter); | targets_.erase(iter); | ||||
targets_.insert(src_node_ptr); | targets_.insert(src_node_ptr); | ||||
GELOGI("node [%s] is in user def targets, do not output result to user!", node->GetName().c_str()); | |||||
GELOGI("Node [%s] is in user def targets, do not output result to user!", node->GetName().c_str()); | |||||
} | } | ||||
is_include_special_node_ = true; | is_include_special_node_ = true; | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -105,7 +105,7 @@ Status NetOutputPass::GetOutputNode(const ge::ComputeGraphPtr &graph, std::vecto | |||||
for (auto &ele : graph->GetGraphOutNodesInfo()) { | for (auto &ele : graph->GetGraphOutNodesInfo()) { | ||||
auto iter = targets_.find(ele.first); | auto iter = targets_.find(ele.first); | ||||
if (iter != targets_.end()) { | if (iter != targets_.end()) { | ||||
GELOGI("user set out node [%s] is found in user def targets, out node is prio!", ele.first->GetName().c_str()); | |||||
GELOGI("User set out node [%s] is found in user def targets, out node is prior!", ele.first->GetName().c_str()); | |||||
targets_.erase(iter); | targets_.erase(iter); | ||||
} | } | ||||
@@ -213,7 +213,7 @@ Status NetOutputPass::UpdateNetOutputDesc(const ge::NodePtr &net_output) { | |||||
std::vector<bool> is_input_const; | std::vector<bool> is_input_const; | ||||
for (const auto &in_anchor : net_output->GetAllInDataAnchors()) { | for (const auto &in_anchor : net_output->GetAllInDataAnchors()) { | ||||
GE_CHECK_NOTNULL(in_anchor); | GE_CHECK_NOTNULL(in_anchor); | ||||
uint32_t index = static_cast<uint32_t>(in_anchor->GetIdx()); | |||||
auto index = static_cast<uint32_t>(in_anchor->GetIdx()); | |||||
if (index >= net_output_desc->GetAllInputsDesc().size()) { | if (index >= net_output_desc->GetAllInputsDesc().size()) { | ||||
REPORT_INNER_ERROR("E19999", "Node:%s(%s) has in_anchor index:%u >= its input desc num:%zu, check invalid", | REPORT_INNER_ERROR("E19999", "Node:%s(%s) has in_anchor index:%u >= its input desc num:%zu, check invalid", | ||||
net_output_desc->GetName().c_str(), net_output_desc->GetType().c_str(), index, | net_output_desc->GetName().c_str(), net_output_desc->GetType().c_str(), index, | ||||
@@ -369,10 +369,9 @@ Status NetOutputPass::UnLinkDataAnchorOfNetoutput(const ge::ComputeGraphPtr &gra | |||||
if (!CheckNodeIsInOutputNodes(graph, node)) { | if (!CheckNodeIsInOutputNodes(graph, node)) { | ||||
ret = in_data_anchor->Unlink(peer_out_anchor); | ret = in_data_anchor->Unlink(peer_out_anchor); | ||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
REPORT_CALL_ERROR("E19999", | |||||
"Op:%s(%s) out index:%d unlink from op:%s(%s) in index:%d failed", | |||||
net_out_node->GetName().c_str(), net_out_node->GetType().c_str(), in_data_anchor->GetIdx(), | |||||
node->GetName().c_str(), node->GetType().c_str(), peer_out_anchor->GetIdx()); | |||||
REPORT_CALL_ERROR("E19999", "Op:%s(%s) out index:%d unlink from op:%s(%s) in index:%d failed", | |||||
net_out_node->GetName().c_str(), net_out_node->GetType().c_str(), in_data_anchor->GetIdx(), | |||||
node->GetName().c_str(), node->GetType().c_str(), peer_out_anchor->GetIdx()); | |||||
GELOGE(INTERNAL_ERROR, "Unlink peer_out_anchor fail!"); | GELOGE(INTERNAL_ERROR, "Unlink peer_out_anchor fail!"); | ||||
return ret; | return ret; | ||||
} | } | ||||
@@ -565,7 +564,7 @@ Status NetOutputPass::AddNetOutputNodeToGraph(const ge::ComputeGraphPtr &graph, | |||||
GELOGI("[NETOUTPUT PASS] Add net output node succeed"); | GELOGI("[NETOUTPUT PASS] Add net output node succeed"); | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
GELOGI("[NETOUTPUT PASS] Output node size:%lu.", output_nodes_info.size()); | |||||
GELOGI("[NETOUTPUT PASS] Output node size:%zu.", output_nodes_info.size()); | |||||
if (output_nodes_info.empty()) { | if (output_nodes_info.empty()) { | ||||
// because retval node is contained by output_nodes_info, here means targets is non-empty | // because retval node is contained by output_nodes_info, here means targets is non-empty | ||||
output_node = graph->AddNode(net_output_desc); | output_node = graph->AddNode(net_output_desc); | ||||
@@ -354,7 +354,7 @@ Status NextIterationPass::BreakNextIteration(const NodePtr &next_node, NodePtr & | |||||
merge_node->GetName().c_str()); | merge_node->GetName().c_str()); | ||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
if (SetNextIteration(merge_node, next_node->GetName()) != SUCCESS) { | |||||
if (SetNextIteration(merge_node, next_node) != SUCCESS) { | |||||
REPORT_CALL_ERROR("E19999", "Set attr NEXT_ITERATION value:%s to node:%s(%s) failed", | REPORT_CALL_ERROR("E19999", "Set attr NEXT_ITERATION value:%s to node:%s(%s) failed", | ||||
next_node->GetName().c_str(), merge_node->GetName().c_str(), merge_node->GetType().c_str()); | next_node->GetName().c_str(), merge_node->GetName().c_str(), merge_node->GetType().c_str()); | ||||
GELOGE(INTERNAL_ERROR, "Set attr NEXT_ITERATION for node %s failed.", merge_node->GetName().c_str()); | GELOGE(INTERNAL_ERROR, "Set attr NEXT_ITERATION for node %s failed.", merge_node->GetName().c_str()); | ||||
@@ -170,7 +170,7 @@ Status PassUtils::SetOutNodeWeight(const OutDataAnchorPtr &out_data_anchor, cons | |||||
// restore control inputs to dynamically added constant ops, if any | // restore control inputs to dynamically added constant ops, if any | ||||
for (const auto &src_out_control_anchor : src_out_control_anchors) { | for (const auto &src_out_control_anchor : src_out_control_anchors) { | ||||
GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(src_out_control_anchor, dynamic_const_node->GetInControlAnchor()), | GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(src_out_control_anchor, dynamic_const_node->GetInControlAnchor()), | ||||
"add edge failed"); | |||||
"add edge failed"); | |||||
} | } | ||||
} | } | ||||
@@ -51,7 +51,7 @@ std::vector<std::pair<NodePtr, InDataAnchorPtr>> GetOutDataNodesByIndex(const No | |||||
auto out_anchor = node->GetOutDataAnchor(index); | auto out_anchor = node->GetOutDataAnchor(index); | ||||
if (out_anchor == nullptr) { | if (out_anchor == nullptr) { | ||||
REPORT_INNER_ERROR("E19999", "Node:%s(%s) has no index:%d out data anchor, check invalid", | REPORT_INNER_ERROR("E19999", "Node:%s(%s) has no index:%d out data anchor, check invalid", | ||||
node->GetName().c_str(), node->GetType().c_str(), index); | |||||
node->GetName().c_str(), node->GetType().c_str(), index); | |||||
GELOGE(PARAM_INVALID, "Failed to get out data nodes of index %d from node %s, the anchor does not exists", index, | GELOGE(PARAM_INVALID, "Failed to get out data nodes of index %d from node %s, the anchor does not exists", index, | ||||
node->GetName().c_str()); | node->GetName().c_str()); | ||||
return {}; | return {}; | ||||
@@ -1077,9 +1077,9 @@ graphStatus TransOpWithoutReshapeFusionPass::RelinkControlEdge(const int index, | |||||
peer_in_anchor->GetOwnerNode()->GetName().c_str()); | peer_in_anchor->GetOwnerNode()->GetName().c_str()); | ||||
if (GraphUtils::AddEdge(new_trans_nodes.back()->GetOutControlAnchor(), peer_in_anchor) != GRAPH_SUCCESS) { | if (GraphUtils::AddEdge(new_trans_nodes.back()->GetOutControlAnchor(), peer_in_anchor) != GRAPH_SUCCESS) { | ||||
REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed", | REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed", | ||||
new_trans_nodes.back()->GetName().c_str(), new_trans_nodes.back()->GetType().c_str(), | |||||
peer_in_anchor->GetOwnerNode()->GetName().c_str(), | |||||
peer_in_anchor->GetOwnerNode()->GetType().c_str()); | |||||
new_trans_nodes.back()->GetName().c_str(), new_trans_nodes.back()->GetType().c_str(), | |||||
peer_in_anchor->GetOwnerNode()->GetName().c_str(), | |||||
peer_in_anchor->GetOwnerNode()->GetType().c_str()); | |||||
return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
} | } | ||||
} | } | ||||
@@ -1103,9 +1103,9 @@ graphStatus TransOpWithoutReshapeFusionPass::RelinkControlEdge(const int index, | |||||
peer_in_anchor->GetOwnerNode()->GetName().c_str()); | peer_in_anchor->GetOwnerNode()->GetName().c_str()); | ||||
if (GraphUtils::AddEdge(new_trans_nodes.back()->GetOutControlAnchor(), peer_in_anchor) != GRAPH_SUCCESS) { | if (GraphUtils::AddEdge(new_trans_nodes.back()->GetOutControlAnchor(), peer_in_anchor) != GRAPH_SUCCESS) { | ||||
REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed", | REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed", | ||||
new_trans_nodes.back()->GetName().c_str(), new_trans_nodes.back()->GetType().c_str(), | |||||
peer_in_anchor->GetOwnerNode()->GetName().c_str(), | |||||
peer_in_anchor->GetOwnerNode()->GetType().c_str()); | |||||
new_trans_nodes.back()->GetName().c_str(), new_trans_nodes.back()->GetType().c_str(), | |||||
peer_in_anchor->GetOwnerNode()->GetName().c_str(), | |||||
peer_in_anchor->GetOwnerNode()->GetType().c_str()); | |||||
return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
} | } | ||||
} | } | ||||
@@ -87,10 +87,10 @@ Status ByPassTransNode(NodePtr &trans_node, NodePtr &ref_node) { | |||||
ret = GraphUtils::AddEdge(prev_trans_node_out_anchor, ref_in_anchor); | ret = GraphUtils::AddEdge(prev_trans_node_out_anchor, ref_in_anchor); | ||||
if (ret != GRAPH_SUCCESS) { | if (ret != GRAPH_SUCCESS) { | ||||
REPORT_CALL_ERROR("E19999", "Add edge between op:%s(%s)(index:%d) and op:%s(%s)(index:0) failed", | REPORT_CALL_ERROR("E19999", "Add edge between op:%s(%s)(index:%d) and op:%s(%s)(index:0) failed", | ||||
prev_trans_node_out_anchor->GetOwnerNode()->GetName().c_str(), | |||||
prev_trans_node_out_anchor->GetOwnerNode()->GetType().c_str(), | |||||
prev_trans_node_out_anchor->GetIdx(), | |||||
ref_node->GetName().c_str(), ref_node->GetType().c_str()); | |||||
prev_trans_node_out_anchor->GetOwnerNode()->GetName().c_str(), | |||||
prev_trans_node_out_anchor->GetOwnerNode()->GetType().c_str(), | |||||
prev_trans_node_out_anchor->GetIdx(), | |||||
ref_node->GetName().c_str(), ref_node->GetType().c_str()); | |||||
GELOGE(INTERNAL_ERROR, | GELOGE(INTERNAL_ERROR, | ||||
"Failed to add edge between ref node %s " | "Failed to add edge between ref node %s " | ||||
"and the prev node of trans node %s", | "and the prev node of trans node %s", | ||||
@@ -241,7 +241,7 @@ NodePtr CreateTransNode(const std::string &name, const std::string &node_type, c | |||||
ret = op_desc->AddInputDesc(shape_desc->GetOutputDesc(0)); | ret = op_desc->AddInputDesc(shape_desc->GetOutputDesc(0)); | ||||
if (ret != GRAPH_SUCCESS) { | if (ret != GRAPH_SUCCESS) { | ||||
REPORT_CALL_ERROR("E19999", "Add input desc into op:%s(%s) failed", | REPORT_CALL_ERROR("E19999", "Add input desc into op:%s(%s) failed", | ||||
op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||||
op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||||
GELOGE(INTERNAL_ERROR, "Failed to add the first input for reshape %s", name.c_str()); | GELOGE(INTERNAL_ERROR, "Failed to add the first input for reshape %s", name.c_str()); | ||||
return nullptr; | return nullptr; | ||||
} | } | ||||
@@ -837,7 +837,7 @@ Status ProcessInputNC1HWC0DynShape(NodePtr &node_ptr, bool &is_dynamic_batch, No | |||||
old_shape = switchn_output->GetShape(); | old_shape = switchn_output->GetShape(); | ||||
if (ModifyFormatAndShapeForSingleTensor(switchn_output) != SUCCESS) { | if (ModifyFormatAndShapeForSingleTensor(switchn_output) != SUCCESS) { | ||||
REPORT_CALL_ERROR("E19999", "Modify format and shape of output:%u in op:%s(%s) failed", i, | REPORT_CALL_ERROR("E19999", "Modify format and shape of output:%u in op:%s(%s) failed", i, | ||||
switchn_op_desc->GetName().c_str(), switchn_op_desc->GetType().c_str()); | |||||
switchn_op_desc->GetName().c_str(), switchn_op_desc->GetType().c_str()); | |||||
GELOGE(INTERNAL_ERROR, "modify format and shape failed"); | GELOGE(INTERNAL_ERROR, "modify format and shape failed"); | ||||
return FAILED; | return FAILED; | ||||
} | } | ||||
@@ -1266,8 +1266,8 @@ Status MultiBatchGraphCopyer::LinkNodeToMerge(const NodePtr &node, int out_index | |||||
auto ret = GraphUtils::AddEdge(src_node->GetOutDataAnchor(out_index), merge->GetInDataAnchor(i)); | auto ret = GraphUtils::AddEdge(src_node->GetOutDataAnchor(out_index), merge->GetInDataAnchor(i)); | ||||
if (ret != GRAPH_SUCCESS) { | if (ret != GRAPH_SUCCESS) { | ||||
REPORT_CALL_ERROR("E19999", "Add edge between op:%s(%s)(index:%d) and op:%s(%s)(index:%zu) failed", | REPORT_CALL_ERROR("E19999", "Add edge between op:%s(%s)(index:%d) and op:%s(%s)(index:%zu) failed", | ||||
src_node->GetName().c_str(), src_node->GetType().c_str(), out_index, | |||||
merge->GetName().c_str(), merge->GetType().c_str(), i); | |||||
src_node->GetName().c_str(), src_node->GetType().c_str(), out_index, | |||||
merge->GetName().c_str(), merge->GetType().c_str(), i); | |||||
GELOGE(INTERNAL_ERROR, | GELOGE(INTERNAL_ERROR, | ||||
"Failed to add edge between copyed node %s(%d) to inserted merge node %s(%zu), error-code %u", | "Failed to add edge between copyed node %s(%d) to inserted merge node %s(%zu), error-code %u", | ||||
copyed_nodes[i]->GetName().c_str(), out_index, merge->GetName().c_str(), i, ret); | copyed_nodes[i]->GetName().c_str(), out_index, merge->GetName().c_str(), i, ret); | ||||
@@ -306,28 +306,15 @@ std::shared_ptr<TaskContext> NodeState::GetTaskContext() { | |||||
return task_context_; | return task_context_; | ||||
} | } | ||||
void NodeState::ResetContext(int group) { | |||||
SetGroup(group); | |||||
if (loop_count_ == 0) { | |||||
++loop_count_; | |||||
return; | |||||
} | |||||
++loop_count_; | |||||
if (loop_count_ == UINT64_MAX) { | |||||
loop_count_ = 1; | |||||
} | |||||
void NodeState::ResetContext(uint64_t loop_count) { | |||||
loop_count_ = loop_count; | |||||
switch_index_ = -1; | switch_index_ = -1; | ||||
subgraph_context_->ResetContext(node_item_->node); | subgraph_context_->ResetContext(node_item_->node); | ||||
GELOGD("Node[%s] in while loop, current loop: %lu, merge index: %d", GetName().c_str(), loop_count_, merge_index_); | |||||
} | |||||
void NodeState::ResetSchedule() { | |||||
std::lock_guard<std::mutex> lk(mu_); | |||||
data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size()); | data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size()); | ||||
ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size()); | ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size()); | ||||
GELOGD("[%s] set schedule for root nodes, data: %u, ctrl: %u", GetName().c_str(), data_scheduled_, ctrl_scheduled_); | |||||
GELOGD("[%s] in while loop, loop count: %lu, data scheduled: %u, ctrl scheduled: %u, merge index: %d", | |||||
GetName().c_str(), loop_count_, data_scheduled_, ctrl_scheduled_, merge_index_); | |||||
} | } | ||||
Status NodeState::NodeScheduled(const std::function<void(const NodeItem *)> &ready) const { | Status NodeState::NodeScheduled(const std::function<void(const NodeItem *)> &ready) const { | ||||
@@ -335,14 +322,14 @@ Status NodeState::NodeScheduled(const std::function<void(const NodeItem *)> &rea | |||||
for (const auto &node : node_item_->data_send_) { | for (const auto &node : node_item_->data_send_) { | ||||
const auto &dst_node_state = subgraph_context_->GetOrCreateNodeState(node); | const auto &dst_node_state = subgraph_context_->GetOrCreateNodeState(node); | ||||
GE_CHECK_NOTNULL(dst_node_state); | GE_CHECK_NOTNULL(dst_node_state); | ||||
dst_node_state->SetDataSchedule(node_item_, ready); | |||||
dst_node_state->SetDataSchedule(*this, ready); | |||||
} | } | ||||
// Schedule ctrl output. | // Schedule ctrl output. | ||||
for (const auto &node : node_item_->ctrl_send_) { | for (const auto &node : node_item_->ctrl_send_) { | ||||
const auto &dst_node_state = subgraph_context_->GetOrCreateNodeState(node); | const auto &dst_node_state = subgraph_context_->GetOrCreateNodeState(node); | ||||
GE_CHECK_NOTNULL(dst_node_state); | GE_CHECK_NOTNULL(dst_node_state); | ||||
dst_node_state->SetCtrlSchedule(node_item_, ready); | |||||
dst_node_state->SetCtrlSchedule(*this, ready); | |||||
} | } | ||||
// Schedule switch group. | // Schedule switch group. | ||||
@@ -351,7 +338,7 @@ Status NodeState::NodeScheduled(const std::function<void(const NodeItem *)> &rea | |||||
for (const auto &node : node_item_->switch_groups_[switch_index_]) { | for (const auto &node : node_item_->switch_groups_[switch_index_]) { | ||||
const auto &dst_node_state = subgraph_context_->GetOrCreateNodeState(node); | const auto &dst_node_state = subgraph_context_->GetOrCreateNodeState(node); | ||||
GE_CHECK_NOTNULL(dst_node_state); | GE_CHECK_NOTNULL(dst_node_state); | ||||
dst_node_state->SetCtrlSchedule(node_item_, ready); | |||||
dst_node_state->SetCtrlSchedule(*this, ready); | |||||
} | } | ||||
} | } | ||||
@@ -359,36 +346,44 @@ Status NodeState::NodeScheduled(const std::function<void(const NodeItem *)> &rea | |||||
} | } | ||||
bool NodeState::IsScheduleReady() const { | bool NodeState::IsScheduleReady() const { | ||||
GELOGD("[%s] data[input: %zu, scheduled: %u], ctrl[input: %zu, scheduled: %u]", GetName().c_str(), | |||||
node_item_->data_recv_.size(), data_scheduled_, node_item_->ctrl_recv_.size(), ctrl_scheduled_); | |||||
if (ctrl_scheduled_ != node_item_->ctrl_recv_.size()) { | |||||
return false; | |||||
} | |||||
GELOGD("[%s] loop[%lu] data[input: %zu, scheduled: %u], ctrl[input: %zu+%zu, scheduled: %u]", | |||||
GetName().c_str(), loop_count_, node_item_->data_recv_.size(), data_scheduled_, | |||||
node_item_->ctrl_recv_.size(), node_item_->GetMergeCtrl(loop_count_ == 0 ? 0 : 1), ctrl_scheduled_); | |||||
if (node_item_->IsMergeOp()) { | if (node_item_->IsMergeOp()) { | ||||
if (ctrl_scheduled_ != node_item_->GetMergeCtrl(loop_count_ == 0 ? 0 : 1) + node_item_->ctrl_recv_.size()) { | |||||
return false; | |||||
} | |||||
return data_scheduled_ > 0; | return data_scheduled_ > 0; | ||||
} | } | ||||
if (ctrl_scheduled_ != node_item_->ctrl_recv_.size()) { | |||||
return false; | |||||
} | |||||
// Exit may feed loop times... | // Exit may feed loop times... | ||||
return data_scheduled_ >= node_item_->data_recv_.size(); | return data_scheduled_ >= node_item_->data_recv_.size(); | ||||
} | } | ||||
void NodeState::SetDataSchedule(const NodeItem *node_item, const std::function<void(const NodeItem *)> &ready) { | |||||
GELOGD("[%s] data schedule node[%s], data num: %zu, current scheduled: %u, ctrl num: %zu, current scheduled: %u", | |||||
node_item->node_name.c_str(), GetName().c_str(), node_item_->data_recv_.size(), data_scheduled_, | |||||
node_item_->ctrl_recv_.size(), ctrl_scheduled_); | |||||
void NodeState::SetDataSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready) { | |||||
GELOGD("[%s] data schedule node[%s], data num: %zu, current scheduled: %u, ctrl num: %zu+%zu, current scheduled: %u", | |||||
node_state.GetName().c_str(), GetName().c_str(), node_item_->data_recv_.size(), data_scheduled_, | |||||
node_item_->ctrl_recv_.size(), node_item_->GetMergeCtrl(loop_count_ == 0 ? 0 : 1), ctrl_scheduled_); | |||||
std::lock_guard<std::mutex> lk(mu_); | std::lock_guard<std::mutex> lk(mu_); | ||||
if (loop_count_ != node_state.loop_count_) { | |||||
ResetContext(node_state.loop_count_); | |||||
} | |||||
++data_scheduled_; | ++data_scheduled_; | ||||
if (node_item_->IsMergeOp()) { | if (node_item_->IsMergeOp()) { | ||||
const auto it = node_item_->data_recv_.find(node_item); | |||||
const auto it = node_item_->data_recv_.find(node_state.node_item_); | |||||
if (it != node_item_->data_recv_.end()) { | if (it != node_item_->data_recv_.end()) { | ||||
merge_index_ = it->second; | merge_index_ = it->second; | ||||
(void)AttrUtils::SetInt(node_item_->node->GetOpDesc(), ATTR_NAME_MERGE_INPUT_INDEX, it->second); | (void)AttrUtils::SetInt(node_item_->node->GetOpDesc(), ATTR_NAME_MERGE_INPUT_INDEX, it->second); | ||||
GELOGD("[%s] scheduled, [%s] set merge index: %d", node_item->node_name.c_str(), GetName().c_str(), it->second); | |||||
GELOGD("[%s] scheduled, [%s] set merge index: %d", node_state.GetName().c_str(), GetName().c_str(), it->second); | |||||
} else { | } else { | ||||
GELOGW("[%s] scheduled, [%s] not followed", node_item->node_name.c_str(), GetName().c_str()); | |||||
GELOGW("[%s] scheduled, [%s] not followed", node_state.GetName().c_str(), GetName().c_str()); | |||||
} | } | ||||
} | } | ||||
@@ -397,12 +392,15 @@ void NodeState::SetDataSchedule(const NodeItem *node_item, const std::function<v | |||||
} | } | ||||
} | } | ||||
void NodeState::SetCtrlSchedule(const NodeItem *node_item, const std::function<void(const NodeItem *)> &ready) { | |||||
GELOGD("[%s] ctrl schedule node[%s], data num: %zu, current scheduled: %u, ctrl num: %zu, current scheduled: %u", | |||||
node_item->node_name.c_str(), GetName().c_str(), node_item_->data_recv_.size(), data_scheduled_, | |||||
node_item_->ctrl_recv_.size(), ctrl_scheduled_); | |||||
void NodeState::SetCtrlSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready) { | |||||
GELOGD("[%s] ctrl schedule node[%s], data num: %zu, current scheduled: %u, ctrl num: %zu+%zu, current scheduled: %u", | |||||
node_state.GetName().c_str(), GetName().c_str(), node_item_->data_recv_.size(), data_scheduled_, | |||||
node_item_->ctrl_recv_.size(), node_item_->GetMergeCtrl(loop_count_ == 0 ? 0 : 1), ctrl_scheduled_); | |||||
std::lock_guard<std::mutex> lk(mu_); | std::lock_guard<std::mutex> lk(mu_); | ||||
if (loop_count_ != node_state.loop_count_) { | |||||
ResetContext(node_state.loop_count_); | |||||
} | |||||
++ctrl_scheduled_; | ++ctrl_scheduled_; | ||||
if (IsScheduleReady()) { | if (IsScheduleReady()) { | ||||
@@ -410,6 +408,21 @@ void NodeState::SetCtrlSchedule(const NodeItem *node_item, const std::function<v | |||||
} | } | ||||
} | } | ||||
void NodeState::RunLoopNext() { | |||||
GELOGD("Node[%s] run in loop, current count: %lu", GetName().c_str(), loop_count_); | |||||
std::lock_guard<std::mutex> lk(mu_); | |||||
++loop_count_; | |||||
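// loop_count_ == 0 means the node is outside a loop (see RunLoopExit), so skip 0 when wrapping.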
if (loop_count_ == UINT64_MAX) { | |||||
loop_count_ = 1; | |||||
} | |||||
} | |||||
void NodeState::RunLoopExit() { | |||||
GELOGD("Node[%s] run in loop, current count: %lu", GetName().c_str(), loop_count_); | |||||
std::lock_guard<std::mutex> lk(mu_); | |||||
loop_count_ = 0; | |||||
} | |||||
void NodeState::SetScheduleFuture(std::future<Status> &&future) { | void NodeState::SetScheduleFuture(std::future<Status> &&future) { | ||||
schedule_future_ = std::move(future); | schedule_future_ = std::move(future); | ||||
} | } | ||||
@@ -112,9 +112,8 @@ struct NodeState { | |||||
return node_item_->IsControlFlowOp() || node_item_->shape_inference_type >= DEPEND_SHAPE_RANGE; | return node_item_->IsControlFlowOp() || node_item_->shape_inference_type >= DEPEND_SHAPE_RANGE; | ||||
} | } | ||||
void ResetContext(int group); | |||||
void ResetSchedule(); | |||||
void RunLoopNext(); | |||||
void RunLoopExit(); | |||||
Status NodeScheduled(const std::function<void(const NodeItem *)> &ready) const; | Status NodeScheduled(const std::function<void(const NodeItem *)> &ready) const; | ||||
@@ -166,8 +165,9 @@ struct NodeState { | |||||
private: | private: | ||||
bool IsScheduleReady() const; | bool IsScheduleReady() const; | ||||
void SetDataSchedule(const NodeItem *node_item, const std::function<void(const NodeItem *)> &ready); | |||||
void SetCtrlSchedule(const NodeItem *node_item, const std::function<void(const NodeItem *)> &ready); | |||||
void SetDataSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready); | |||||
void SetCtrlSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready); | |||||
void ResetContext(uint64_t loop_count); | |||||
const NodeItem *node_item_ = nullptr; | const NodeItem *node_item_ = nullptr; | ||||
std::shared_ptr<NodeTask> kernel_task_ = nullptr; | std::shared_ptr<NodeTask> kernel_task_ = nullptr; | ||||
@@ -46,6 +46,10 @@ Status SubgraphContext::Init() { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
void SubgraphContext::SetGroup(int group) { | |||||
group_ = group; | |||||
} | |||||
void SubgraphContext::ResetContext(const NodePtr &node) { | void SubgraphContext::ResetContext(const NodePtr &node) { | ||||
node_done_manager_.Reset(node); | node_done_manager_.Reset(node); | ||||
} | } | ||||
@@ -84,7 +88,8 @@ NodeStatePtr SubgraphContext::GetOrCreateNodeState(const NodeItem *node_item) { | |||||
auto &node_state = node_states_[node_item]; | auto &node_state = node_states_[node_item]; | ||||
if (node_state == nullptr) { | if (node_state == nullptr) { | ||||
const auto &guard = node_item->MutexGuard("GetOrCreateNodeState"); | const auto &guard = node_item->MutexGuard("GetOrCreateNodeState"); | ||||
node_state = std::move(std::unique_ptr<NodeState>(new(std::nothrow)NodeState(*node_item, this))); | |||||
node_state.reset(new(std::nothrow)NodeState(*node_item, this));
if (node_state != nullptr) {
  node_state->SetGroup(group_);
}
(void)guard; | (void)guard; | ||||
} | } | ||||
GELOGD("[%s] unlock for write", node_item->NodeName().c_str()); | GELOGD("[%s] unlock for write", node_item->NodeName().c_str()); | ||||
@@ -34,6 +34,7 @@ class SubgraphContext { | |||||
~SubgraphContext(); | ~SubgraphContext(); | ||||
Status Init(); | Status Init(); | ||||
void SetGroup(int group); | |||||
void ResetContext(const NodePtr &node); | void ResetContext(const NodePtr &node); | ||||
void Reset(); | void Reset(); | ||||
NodeStatePtr GetOrCreateNodeState(const NodeItem *node_item); | NodeStatePtr GetOrCreateNodeState(const NodeItem *node_item); | ||||
@@ -58,6 +59,7 @@ class SubgraphContext { | |||||
std::vector<TensorValue> all_outputs_; | std::vector<TensorValue> all_outputs_; | ||||
NodeDoneManager node_done_manager_; | NodeDoneManager node_done_manager_; | ||||
std::unordered_map<const NodeItem *, NodeStatePtr> node_states_; | std::unordered_map<const NodeItem *, NodeStatePtr> node_states_; | ||||
int group_ = -1; | |||||
}; | }; | ||||
} // namespace hybrid | } // namespace hybrid | ||||
} // namespace ge | } // namespace ge | ||||
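
Taken together, the SubgraphContext changes route the schedule group through the context instead of resetting each NodeState with the group in PrepareNode: ScheduleTasks(group) records it via SetGroup(), and GetOrCreateNodeState() stamps it onto every newly created NodeState. A hedged, stand-alone sketch of that plumbing (FakeNodeItem/FakeNodeState/FakeSubgraphContext are illustrative stand-ins, not the GE types):

#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

// Simplified stand-ins; only the group plumbing is shown.
struct FakeNodeItem { std::string name; };

struct FakeNodeState {
  explicit FakeNodeState(const FakeNodeItem &item) : item_(&item) {}
  void SetGroup(int group) { group_ = group; }
  int group_ = -1;
  const FakeNodeItem *item_;
};

class FakeSubgraphContext {
 public:
  void SetGroup(int group) { group_ = group; }
  FakeNodeState *GetOrCreateNodeState(const FakeNodeItem *item) {
    auto &state = states_[item];
    if (state == nullptr) {
      state.reset(new FakeNodeState(*item));
      state->SetGroup(group_);  // newly created states inherit the current group
    }
    return state.get();
  }
 private:
  int group_ = -1;
  std::unordered_map<const FakeNodeItem *, std::unique_ptr<FakeNodeState>> states_;
};

int main() {
  FakeSubgraphContext ctx;
  FakeNodeItem item{"add_1"};
  ctx.SetGroup(2);  // ScheduleTasks(group) would do this
  auto *state = ctx.GetOrCreateNodeState(&item);
  std::cout << item.name << " group=" << state->group_ << std::endl;  // prints group=2
  return 0;
}
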
@@ -242,7 +242,6 @@ Status SubgraphExecutor::PrepareNode(const NodeItem &node_item, int group) { | |||||
auto node_state = subgraph_context_->GetOrCreateNodeState(&node_item); | auto node_state = subgraph_context_->GetOrCreateNodeState(&node_item); | ||||
GE_CHECK_NOTNULL(node_state); | GE_CHECK_NOTNULL(node_state); | ||||
node_state->ResetContext(group); | |||||
auto p_node_state = node_state.get(); | auto p_node_state = node_state.get(); | ||||
if (node_item.node_type == NETOUTPUT) { | if (node_item.node_type == NETOUTPUT) { | ||||
@@ -367,7 +366,6 @@ Status SubgraphExecutor::NodeScheduled(NodeState *node_state) { | |||||
}; | }; | ||||
GE_CHK_STATUS_RET_NOLOG(node_state->NodeScheduled(callback)); | GE_CHK_STATUS_RET_NOLOG(node_state->NodeScheduled(callback)); | ||||
node_state->ResetSchedule(); | |||||
RECORD_CALLBACK_EVENT(context_, node_state->GetName().c_str(), "[NodeScheduled] End"); | RECORD_CALLBACK_EVENT(context_, node_state->GetName().c_str(), "[NodeScheduled] End"); | ||||
return SUCCESS; | return SUCCESS; | ||||
}); | }); | ||||
@@ -539,6 +537,7 @@ Status SubgraphExecutor::LaunchTasks() { | |||||
Status SubgraphExecutor::ScheduleTasks(int group) { | Status SubgraphExecutor::ScheduleTasks(int group) { | ||||
GELOGD("[%s] Start to schedule prepare workers.", graph_item_->GetName().c_str()); | GELOGD("[%s] Start to schedule prepare workers.", graph_item_->GetName().c_str()); | ||||
subgraph_context_->SetGroup(group); | |||||
auto prepare_future = std::async(std::launch::async, [&]() -> Status { | auto prepare_future = std::async(std::launch::async, [&]() -> Status { | ||||
GetContext().SetSessionId(context_->session_id); | GetContext().SetSessionId(context_->session_id); | ||||
GetContext().SetContextId(context_->context_id); | GetContext().SetContextId(context_->context_id); | ||||
@@ -21,6 +21,7 @@ | |||||
#include "graph/ge_context.h" | #include "graph/ge_context.h" | ||||
#include "graph/build/memory/var_mem_assign_util.h" | #include "graph/build/memory/var_mem_assign_util.h" | ||||
#include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
#include "graph/common/omg_util.h" | |||||
#include "graph/load/model_manager/model_utils.h" | #include "graph/load/model_manager/model_utils.h" | ||||
#include "graph/load/model_manager/model_manager.h" | #include "graph/load/model_manager/model_manager.h" | ||||
#include "graph/manager/graph_var_manager.h" | #include "graph/manager/graph_var_manager.h" | ||||
@@ -43,8 +44,9 @@ const uint64_t kProfilingBpEndLogid = 2U; | |||||
const uint64_t kProfilingIterEndLogid = 65535U; | const uint64_t kProfilingIterEndLogid = 65535U; | ||||
const int kBytes = 8; | const int kBytes = 8; | ||||
const int kDecimal = 10; | const int kDecimal = 10; | ||||
const uint8_t kStreamActiveIdx = 0; | |||||
const uint8_t kStreamActiveNum = 1; | |||||
const uint8_t kLoopEnterIdx = 0; | |||||
const uint8_t kLoopIterationIdx = 1; | |||||
const uint8_t kLoopMergeSize = 2; | |||||
const uint8_t kStreamSwitchIdx = 1; | const uint8_t kStreamSwitchIdx = 1; | ||||
const uint8_t kStreamSwitchNum = 2; | const uint8_t kStreamSwitchNum = 2; | ||||
const uint32_t kStringHeadElems = 2; | const uint32_t kStringHeadElems = 2; | ||||
@@ -57,6 +59,10 @@ const char *const kProfilingArNode = "ProfilingAllReduceNode"; | |||||
const char *const kEngineNameRts = "DNN_VM_RTS_OP_STORE"; | const char *const kEngineNameRts = "DNN_VM_RTS_OP_STORE"; | ||||
const char *const kForceInfershape = "_force_infershape_when_running"; | const char *const kForceInfershape = "_force_infershape_when_running"; | ||||
const std::set<std::string> kExecutionDependentTypes{ IF, STATELESSIF, CASE, STREAMSWITCH }; | |||||
const std::set<std::string> kMergeInputSkipTypes{ STREAMACTIVE, STREAMSWITCH, CONSTANT, CONSTANTOP }; | |||||
const std::set<std::string> kStreamActiveTypes{ ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION }; | |||||
Status SetOutputNameAttr(ComputeGraph &graph) { | Status SetOutputNameAttr(ComputeGraph &graph) { | ||||
vector<string> output_names; | vector<string> output_names; | ||||
for (const auto &node : graph.GetDirectNode()) { | for (const auto &node : graph.GetDirectNode()) { | ||||
@@ -389,7 +395,7 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s | |||||
} | } | ||||
// cond or branch need to be prepared before the execution of IF or CASE | // cond or branch need to be prepared before the execution of IF or CASE | ||||
if (node_item.node_type == IF || node_item.node_type == STATELESSIF || node_item.node_type == CASE) { | |||||
if (kExecutionDependentTypes.count(node_item.node_type) > 0) { | |||||
auto src_node = NodeUtils::GetInDataNodeByIndex(*ge_node, 0); // cond input | auto src_node = NodeUtils::GetInDataNodeByIndex(*ge_node, 0); // cond input | ||||
GE_CHECK_NOTNULL(src_node); | GE_CHECK_NOTNULL(src_node); | ||||
auto src_node_item = MutableNodeItem(src_node); | auto src_node_item = MutableNodeItem(src_node); | ||||
@@ -575,7 +581,7 @@ Status HybridModelBuilder::MergeInputNodes(ComputeGraph &graph) { | |||||
auto in_nodes = root_node->GetInAllNodes(); | auto in_nodes = root_node->GetInAllNodes(); | ||||
std::set<NodePtr> in_node_set(in_nodes.begin(), in_nodes.end()); | std::set<NodePtr> in_node_set(in_nodes.begin(), in_nodes.end()); | ||||
for (auto &in_control_node : wrapped_node->GetInControlNodes()) { | for (auto &in_control_node : wrapped_node->GetInControlNodes()) { | ||||
if (in_node_set.count(in_control_node) == 0) { | |||||
if (in_node_set.count(in_control_node) == 0 && kMergeInputSkipTypes.count(root_node->GetType()) == 0) { | |||||
GELOGD("[%s] Restore control edge to [%s]", in_control_node->GetName().c_str(), root_node->GetName().c_str()); | GELOGD("[%s] Restore control edge to [%s]", in_control_node->GetName().c_str(), root_node->GetName().c_str()); | ||||
GE_CHECK_NOTNULL(in_control_node->GetOutControlAnchor()); | GE_CHECK_NOTNULL(in_control_node->GetOutControlAnchor()); | ||||
(void) in_control_node->GetOutControlAnchor()->LinkTo(root_node->GetInControlAnchor()); | (void) in_control_node->GetOutControlAnchor()->LinkTo(root_node->GetInControlAnchor()); | ||||
@@ -2282,8 +2288,6 @@ Status HybridModelBuilder::RelinkNextIteration() { | |||||
} | } | ||||
} | } | ||||
stream_merge_op_nodes_.clear(); | |||||
next_iteration_op_nodes_.clear(); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -2371,10 +2375,12 @@ Status HybridModelBuilder::BuildControlFlowGroup(GraphItem &graph_item, const No | |||||
} | } | ||||
Status HybridModelBuilder::CreateNormalNodeGroup(const NodePtr &node, NodeItem *node_item) { | Status HybridModelBuilder::CreateNormalNodeGroup(const NodePtr &node, NodeItem *node_item) { | ||||
const auto out_ctrl_anchor = node->GetOutControlAnchor(); | |||||
for (const auto &peer_in_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { | |||||
const auto &dst_node = peer_in_anchor->GetOwnerNode(); | |||||
for (const auto &dst_node : node->GetOutControlNodes()) { | |||||
GE_CHECK_NOTNULL(dst_node); | GE_CHECK_NOTNULL(dst_node); | ||||
if ((dst_node->GetType() == STREAMACTIVE) && (kStreamActiveTypes.count(node->GetType()) == 0)) { | |||||
GELOGI("[%s] ignore control to [%s]", node->GetName().c_str(), dst_node->GetName().c_str()); | |||||
continue; | |||||
} | |||||
NodeItem *dst_node_item = nullptr; | NodeItem *dst_node_item = nullptr; | ||||
GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item), | GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item), | ||||
@@ -2384,27 +2390,80 @@ Status HybridModelBuilder::CreateNormalNodeGroup(const NodePtr &node, NodeItem * | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status HybridModelBuilder::CreateMergeEnterGroup(const NodePtr &node, NodeItem *node_item) { | |||||
// Enter --> StreamActive --> StreamMerge | |||||
for (const auto &dst_node : node->GetOutControlNodes()) { | |||||
GE_CHECK_NOTNULL(dst_node); | |||||
NodeItem *dst_node_item = nullptr; | |||||
GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item), | |||||
"[%s] failed to get or create node item", dst_node->GetName().c_str()); | |||||
// Set Enter Control to StreamMerge as Group 0. | |||||
dst_node_item->switch_groups_.resize(kLoopMergeSize); | |||||
dst_node_item->SetMergeCtrl(node_item, kLoopEnterIdx); | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status HybridModelBuilder::CreateMergeIterationGroup(const NodePtr &node, NodeItem *node_item) { | |||||
// NextIteration --> StreamActive {-->} StreamMerge | |||||
std::string node_name; | |||||
for (const auto &src_node : node->GetInControlNodes()) { | |||||
GE_CHECK_NOTNULL(src_node); | |||||
if (kNextIterationOpTypes.count(src_node->GetType()) == 0) { | |||||
GELOGI("[%s] Skip Not NextIteration node [%s]", node->GetName().c_str(), src_node->GetName().c_str()); | |||||
continue; | |||||
} | |||||
if (!AttrUtils::GetStr(src_node->GetOpDesc(), ATTR_NAME_NEXT_ITERATION, node_name)) { | |||||
GELOGE(INTERNAL_ERROR, "[%s] input node [%s] expect attribute[%s] not found", | |||||
node->GetName().c_str(), src_node->GetName().c_str(), ATTR_NAME_NEXT_ITERATION.c_str()); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
const auto it = stream_merge_op_nodes_.find(node_name); | |||||
if (it == stream_merge_op_nodes_.end()) { | |||||
GELOGE(INTERNAL_ERROR, "[%s] expect StreamMerge[%s] not found", node->GetName().c_str(), node_name.c_str()); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
const auto &dst_node = it->second; | |||||
GE_CHECK_NOTNULL(dst_node); | |||||
NodeItem *dst_node_item = nullptr; | |||||
GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item), "[%s] failed to get or create node item", | |||||
dst_node->GetName().c_str()); | |||||
// Set NextIteration Control to StreamMerge as Group 1. | |||||
dst_node_item->SetMergeCtrl(node_item, kLoopIterationIdx); | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
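
To make the grouping concrete: kLoopMergeSize reserves two slots on the StreamMerge item, the Enter-side StreamActive lands in slot kLoopEnterIdx (0) via CreateMergeEnterGroup, and the NextIteration-side StreamActive lands in slot kLoopIterationIdx (1) via CreateMergeIterationGroup. A small illustrative sketch of that layout (Item and the node names are made up; the container shape mirrors switch_groups_):

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

constexpr uint8_t kLoopEnterIdx = 0;
constexpr uint8_t kLoopIterationIdx = 1;
constexpr uint8_t kLoopMergeSize = 2;

struct Item { std::string name; };

int main() {
  Item enter_active{"StreamActive_enter"};
  Item iteration_active{"StreamActive_next_iteration"};

  // switch_groups_ on the StreamMerge item: one slot per loop entry path.
  std::vector<std::vector<const Item *>> switch_groups(kLoopMergeSize);
  switch_groups[kLoopEnterIdx].push_back(&enter_active);          // CreateMergeEnterGroup
  switch_groups[kLoopIterationIdx].push_back(&iteration_active);  // CreateMergeIterationGroup

  for (size_t i = 0; i < switch_groups.size(); ++i) {
    std::cout << "group " << i << ": " << switch_groups[i].size() << " controller(s)" << std::endl;
  }
  return 0;
}
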
Status HybridModelBuilder::CreateStreamActiveGroup(const NodePtr &node, NodeItem *node_item) { | Status HybridModelBuilder::CreateStreamActiveGroup(const NodePtr &node, NodeItem *node_item) { | ||||
if (node_item->node_type != STREAMACTIVE) { | if (node_item->node_type != STREAMACTIVE) { | ||||
GELOGE(INTERNAL_ERROR, "Called by %s is invalid", node_item->node_type.c_str()); | GELOGE(INTERNAL_ERROR, "Called by %s is invalid", node_item->node_type.c_str()); | ||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
node_item->switch_groups_.resize(kStreamActiveNum); | |||||
const auto &out_ctrl_anchor = node->GetOutControlAnchor(); | |||||
for (const auto &peer_in_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { | |||||
const auto &dst_node = peer_in_anchor->GetOwnerNode(); | |||||
GE_CHECK_NOTNULL(dst_node); | |||||
if (dst_node->GetType() == STREAMMERGE) { | |||||
GELOGI("[%s] skip control node: %s", node->GetName().c_str(), dst_node->GetName().c_str()); | |||||
continue; | |||||
} | |||||
const auto ctrl_nodes = node->GetInControlNodes(); | |||||
if (ctrl_nodes.empty()) { | |||||
GELOGW("Skip no in control node: %s", node->GetName().c_str()); | |||||
return SUCCESS; | |||||
} | |||||
NodeItem *dst_node_item = nullptr; | |||||
GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item), | |||||
"[%s] failed to get or create node item", dst_node->GetName().c_str()); | |||||
node_item->SetCtrlSend(dst_node_item, kStreamActiveIdx); | |||||
const auto IsEnterNode = [](const NodePtr &n) { | |||||
return kEnterOpTypes.count(n->GetType()) > 0; | |||||
}; | |||||
const auto IsIterationNode = [](const NodePtr &n) { | |||||
return kNextIterationOpTypes.count(n->GetType()) > 0; | |||||
}; | |||||
if (std::any_of(ctrl_nodes.begin(), ctrl_nodes.end(), IsEnterNode)) { | |||||
// Enter --> StreamActive --> StreamMerge | |||||
return CreateMergeEnterGroup(node, node_item); | |||||
} else if (std::any_of(ctrl_nodes.begin(), ctrl_nodes.end(), IsIterationNode)) { | |||||
// NextIteration --> StreamActive {-->} StreamMerge | |||||
return CreateMergeIterationGroup(node, node_item); | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
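
The routing in CreateStreamActiveGroup boils down to classifying the StreamActive by its in-control nodes. A hedged sketch of that dispatch (the type-name sets below are assumed for illustration and are not the real kEnterOpTypes/kNextIterationOpTypes definitions):

#include <algorithm>
#include <iostream>
#include <set>
#include <string>
#include <vector>

int main() {
  // Assumed contents for illustration only.
  const std::set<std::string> enter_types{"Enter", "RefEnter"};
  const std::set<std::string> next_iteration_types{"NextIteration", "RefNextIteration"};

  std::vector<std::string> in_ctrl_types{"Const", "NextIteration"};

  const auto is_enter = [&](const std::string &t) { return enter_types.count(t) > 0; };
  const auto is_iteration = [&](const std::string &t) { return next_iteration_types.count(t) > 0; };

  if (std::any_of(in_ctrl_types.begin(), in_ctrl_types.end(), is_enter)) {
    std::cout << "route to CreateMergeEnterGroup" << std::endl;
  } else if (std::any_of(in_ctrl_types.begin(), in_ctrl_types.end(), is_iteration)) {
    std::cout << "route to CreateMergeIterationGroup" << std::endl;
  } else {
    std::cout << "plain StreamActive group" << std::endl;
  }
  return 0;
}
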
@@ -2416,11 +2475,8 @@ Status HybridModelBuilder::CreateStreamSwitchGroup(const NodePtr &node, NodeItem | |||||
// Consider as two groups, group[0] set empty for false, group[1] for true. | // Consider as two groups, group[0] set empty for false, group[1] for true. | ||||
node_item->switch_groups_.resize(kStreamSwitchNum); | node_item->switch_groups_.resize(kStreamSwitchNum); | ||||
const auto &out_ctrl_anchor = node->GetOutControlAnchor(); | |||||
for (const auto &peer_in_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { | |||||
const auto &dst_node = peer_in_anchor->GetOwnerNode(); | |||||
for (const auto &dst_node : node->GetOutControlNodes()) { | |||||
GE_CHECK_NOTNULL(dst_node); | GE_CHECK_NOTNULL(dst_node); | ||||
NodeItem *dst_node_item = nullptr; | NodeItem *dst_node_item = nullptr; | ||||
GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item), | GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item), | ||||
"[%s] failed to get or create node item", dst_node->GetName().c_str()); | "[%s] failed to get or create node item", dst_node->GetName().c_str()); | ||||
@@ -2447,20 +2503,17 @@ Status HybridModelBuilder::CreateStreamSwitchNGroup(const NodePtr &node, NodeIte | |||||
} | } | ||||
node_item->switch_groups_.resize(batch_num); | node_item->switch_groups_.resize(batch_num); | ||||
const auto &out_ctrl_anchor = node->GetOutControlAnchor(); | |||||
for (const auto &peer_in_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { | |||||
const auto &dst_node = peer_in_anchor->GetOwnerNode(); | |||||
for (const auto &dst_node : node->GetOutControlNodes()) { | |||||
GE_CHECK_NOTNULL(dst_node); | GE_CHECK_NOTNULL(dst_node); | ||||
std::string batch_label; | std::string batch_label; | ||||
if (!AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label)) { | |||||
GELOGE(INTERNAL_ERROR, "[%s] Get ATTR_NAME_BATCH_LABEL failed", node->GetName().c_str()); | |||||
if (!AttrUtils::GetStr(dst_node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label)) { | |||||
GELOGE(INTERNAL_ERROR, "[%s] Get ATTR_NAME_BATCH_LABEL failed", dst_node->GetName().c_str()); | |||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
std::string::size_type pos = batch_label.rfind("_"); | std::string::size_type pos = batch_label.rfind("_"); | ||||
if (pos == std::string::npos) { | if (pos == std::string::npos) { | ||||
GELOGW("[%s] Separator not found in batch label: %s.", node->GetName().c_str(), batch_label.c_str()); | |||||
GELOGW("[%s] Separator not found in batch label: %s.", dst_node->GetName().c_str(), batch_label.c_str()); | |||||
continue; | continue; | ||||
} | } | ||||
@@ -2486,7 +2539,7 @@ Status HybridModelBuilder::CreateNextIterationGroup(const NodePtr &node, NodeIte | |||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
return SUCCESS; | |||||
return CreateNormalNodeGroup(node, node_item); | |||||
} | } | ||||
Status HybridModelBuilder::CreateSwitchGroup(const NodePtr &node, NodeItem *node_item) { | Status HybridModelBuilder::CreateSwitchGroup(const NodePtr &node, NodeItem *node_item) { | ||||
@@ -2495,11 +2548,8 @@ Status HybridModelBuilder::CreateSwitchGroup(const NodePtr &node, NodeItem *node | |||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
const auto &out_ctrl_anchor = node->GetOutControlAnchor(); | |||||
for (const auto &peer_in_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { | |||||
const auto &dst_node = peer_in_anchor->GetOwnerNode(); | |||||
for (const auto &dst_node : node->GetOutControlNodes()) { | |||||
GE_CHECK_NOTNULL(dst_node); | GE_CHECK_NOTNULL(dst_node); | ||||
NodeItem *dst_node_item = nullptr; | NodeItem *dst_node_item = nullptr; | ||||
GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item), | GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item), | ||||
"[%s] failed to get or create node item", dst_node->GetName().c_str()); | "[%s] failed to get or create node item", dst_node->GetName().c_str()); | ||||
@@ -2509,11 +2559,8 @@ Status HybridModelBuilder::CreateSwitchGroup(const NodePtr &node, NodeItem *node | |||||
// Group switch flow by out put data. | // Group switch flow by out put data. | ||||
node_item->switch_groups_.resize(SWITCH_OUTPUT_NUM); | node_item->switch_groups_.resize(SWITCH_OUTPUT_NUM); | ||||
for (uint32_t i = 0; i < SWITCH_OUTPUT_NUM; ++i) { | for (uint32_t i = 0; i < SWITCH_OUTPUT_NUM; ++i) { | ||||
const auto &out_anchor = node->GetOutDataAnchor(i); | |||||
for (const auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) { | |||||
const auto &dst_node = peer_in_anchor->GetOwnerNode(); | |||||
for (const auto &dst_node : node->GetOutDataNodes()) { | |||||
GE_CHECK_NOTNULL(dst_node); | GE_CHECK_NOTNULL(dst_node); | ||||
NodeItem *dst_node_item = nullptr; | NodeItem *dst_node_item = nullptr; | ||||
GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item), | GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item), | ||||
"[%s] failed to get or create node item", dst_node->GetName().c_str()); | "[%s] failed to get or create node item", dst_node->GetName().c_str()); | ||||
@@ -99,6 +99,8 @@ class HybridModelBuilder { | |||||
Status BuildProfilingControl(GraphItem &graph_item, const std::map<size_t, std::pair<uint32_t, uint32_t>> &nodes); | Status BuildProfilingControl(GraphItem &graph_item, const std::map<size_t, std::pair<uint32_t, uint32_t>> &nodes); | ||||
Status BuildControlFlowGroup(GraphItem &graph_item, const NodePtr &node, NodeItem *node_item); | Status BuildControlFlowGroup(GraphItem &graph_item, const NodePtr &node, NodeItem *node_item); | ||||
Status CreateNormalNodeGroup(const NodePtr &node, NodeItem *node_item); | Status CreateNormalNodeGroup(const NodePtr &node, NodeItem *node_item); | ||||
Status CreateMergeEnterGroup(const NodePtr &node, NodeItem *node_item); | |||||
Status CreateMergeIterationGroup(const NodePtr &node, NodeItem *node_item); | |||||
Status CreateStreamActiveGroup(const NodePtr &node, NodeItem *node_item); | Status CreateStreamActiveGroup(const NodePtr &node, NodeItem *node_item); | ||||
Status CreateStreamSwitchGroup(const NodePtr &node, NodeItem *node_item); | Status CreateStreamSwitchGroup(const NodePtr &node, NodeItem *node_item); | ||||
Status CreateStreamSwitchNGroup(const NodePtr &node, NodeItem *node_item); | Status CreateStreamSwitchNGroup(const NodePtr &node, NodeItem *node_item); | ||||
@@ -34,8 +34,8 @@ const std::set<std::string> kControlOpTypes{ | |||||
}; | }; | ||||
const std::set<std::string> kControlFlowOpTypes{ | const std::set<std::string> kControlFlowOpTypes{ | ||||
STREAMACTIVE, STREAMSWITCH, STREAMSWITCHN, LABELGOTO, LABELGOTOEX, LABELSWITCH, LABELSWITCHBYINDEX, | |||||
NEXTITERATION, REFNEXTITERATION | |||||
STREAMACTIVE, STREAMSWITCH, STREAMSWITCHN, NEXTITERATION, REFNEXTITERATION, EXIT, REFEXIT, | |||||
LABELGOTO, LABELGOTOEX, LABELSWITCH, LABELSWITCHBYINDEX | |||||
}; | }; | ||||
const std::set<std::string> kMergeOpTypes{ | const std::set<std::string> kMergeOpTypes{ | ||||
@@ -401,6 +401,11 @@ void NodeItem::SetDataSend(NodeItem *node_item, int anchor_index) { | |||||
if (is_root_node_) { | if (is_root_node_) { | ||||
node_item->root_data_.emplace(this); | node_item->root_data_.emplace(this); | ||||
} | } | ||||
// If Enter feeds a node other than StreamMerge, take that node as a root node.
if ((kEnterOpTypes.count(node_type) > 0) && (node_item->node_type != STREAMMERGE)) { | |||||
node_item->root_data_.emplace(this); | |||||
node_item->enter_inside_.emplace(anchor_index); | |||||
} | |||||
GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); | GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); | ||||
} | } | ||||
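
The Enter branch added to SetDataSend records which input anchors of the consumer are fed from outside the loop, so those tensors can survive across iterations; the TaskContext::ReleaseInput change later in this diff consumes that set. A simplified sketch of the bookkeeping (FakeItem, MarkEnterInput and ShouldReleaseInput are illustrative names only):

#include <iostream>
#include <set>
#include <string>

// Simplified stand-in: only the bookkeeping that matters for ReleaseInput is shown.
struct FakeItem {
  std::string type;
  std::set<int> enter_inside;  // input indices fed by Enter, kept across iterations
};

// Mirrors the idea of SetDataSend's Enter branch for a non-StreamMerge consumer.
void MarkEnterInput(const FakeItem &src, FakeItem &dst, int anchor_index) {
  if (src.type == "Enter" && dst.type != "StreamMerge") {
    dst.enter_inside.insert(anchor_index);
  }
}

// Mirrors the idea of TaskContext::ReleaseInput: skip tensors fed by Enter.
bool ShouldReleaseInput(const FakeItem &item, int index) {
  return item.enter_inside.count(index) == 0;
}

int main() {
  FakeItem enter{"Enter", {}};
  FakeItem add{"Add", {}};
  MarkEnterInput(enter, add, 0);
  std::cout << "release input 0? " << std::boolalpha << ShouldReleaseInput(add, 0) << std::endl;  // false
  std::cout << "release input 1? " << ShouldReleaseInput(add, 1) << std::endl;                    // true
  return 0;
}
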
@@ -416,10 +421,31 @@ void NodeItem::SetCtrlSend(NodeItem *node_item, uint32_t switch_index) { | |||||
if (is_root_node_) { | if (is_root_node_) { | ||||
node_item->root_ctrl_.emplace(this); | node_item->root_ctrl_.emplace(this); | ||||
} | } | ||||
// If Enter feeds a control signal, take the destination as a root node.
if (kEnterOpTypes.count(node_type) > 0) { | |||||
node_item->root_ctrl_.emplace(this); | |||||
} | |||||
GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); | GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); | ||||
} | } | ||||
void NodeItem::SetMergeCtrl(NodeItem *node_item, uint32_t merge_index) { | |||||
if (merge_index >= switch_groups_.size()) { | |||||
GELOGE(FAILED, "[%s] group size: %zu, merge index: %u", NodeName().c_str(), switch_groups_.size(), merge_index); | |||||
return; | |||||
} | |||||
// Here "this" is the StreamMerge node and node_item is the StreamActive node.
std::vector<const NodeItem *> &switch_group = switch_groups_[merge_index]; | |||||
switch_group.emplace_back(node_item); | |||||
node_item->ctrl_send_.emplace(this); | |||||
GELOGI("Node[%s] will control node[%s]", node_item->NodeName().c_str(), NodeName().c_str()); | |||||
} | |||||
size_t NodeItem::GetMergeCtrl(uint32_t merge_index) const { | |||||
return (merge_index < switch_groups_.size()) ? switch_groups_[merge_index].size() : 0; | |||||
} | |||||
OptionalMutexGuard::OptionalMutexGuard(std::mutex *mutex, const string &name) : mu_(mutex), name_(name) { | OptionalMutexGuard::OptionalMutexGuard(std::mutex *mutex, const string &name) : mu_(mutex), name_(name) { | ||||
if (mu_ != nullptr) { | if (mu_ != nullptr) { | ||||
GELOGD("lock for %s", name_.c_str()); | GELOGD("lock for %s", name_.c_str()); | ||||
@@ -98,6 +98,8 @@ struct NodeItem { | |||||
void SetDataSend(NodeItem *node_item, int anchor_index); | void SetDataSend(NodeItem *node_item, int anchor_index); | ||||
void SetCtrlSend(NodeItem *node_item, uint32_t switch_index); | void SetCtrlSend(NodeItem *node_item, uint32_t switch_index); | ||||
void SetMergeCtrl(NodeItem *node_item, uint32_t merge_index); | |||||
size_t GetMergeCtrl(uint32_t merge_index) const; | |||||
OptionalMutexGuard MutexGuard(const std::string &name) const { | OptionalMutexGuard MutexGuard(const std::string &name) const { | ||||
return OptionalMutexGuard(copy_mu_.get(), name + "_" + node_name); | return OptionalMutexGuard(copy_mu_.get(), name + "_" + node_name); | ||||
@@ -140,6 +142,7 @@ struct NodeItem { | |||||
std::set<const NodeItem *> ctrl_send_; // Send ctrl notify to | std::set<const NodeItem *> ctrl_send_; // Send ctrl notify to | ||||
std::set<const NodeItem *> ctrl_recv_; // Recv ctrl notify from | std::set<const NodeItem *> ctrl_recv_; // Recv ctrl notify from | ||||
std::vector<std::vector<const NodeItem *>> switch_groups_; // Send ctrl notify to | std::vector<std::vector<const NodeItem *>> switch_groups_; // Send ctrl notify to | ||||
std::set<int> enter_inside_; // Input indices fed by Enter inside the loop body, not crossing a Merge.
std::shared_ptr<NodeTask> kernel_task; | std::shared_ptr<NodeTask> kernel_task; | ||||
std::unique_ptr<FusedSubgraph> fused_subgraph; | std::unique_ptr<FusedSubgraph> fused_subgraph; | ||||
@@ -420,9 +420,8 @@ Status AiCoreOpTask::CalcTilingInfo(const NodePtr &node, OpRunInfo &tiling_info) | |||||
} | } | ||||
Status AiCoreOpTask::UpdateArgs(TaskContext &task_context) { | Status AiCoreOpTask::UpdateArgs(TaskContext &task_context) { | ||||
size_t expected_arg_count = task_context.NumInputs() + task_context.NumOutputs() + | |||||
task_context.NumWorkspaces() | |||||
- output_indices_to_skip_.size(); | |||||
size_t expected_arg_count = task_context.NumInputs() + task_context.NumOutputs() + task_context.NumWorkspaces() - | |||||
output_indices_to_skip_.size(); | |||||
if (tiling_buffer_ != nullptr) { | if (tiling_buffer_ != nullptr) { | ||||
++expected_arg_count; | ++expected_arg_count; | ||||
} | } | ||||
@@ -37,7 +37,7 @@ const std::map<std::string, std::vector<uint32_t>> | |||||
{BROADCASTGRADIENTARGS, {}} | {BROADCASTGRADIENTARGS, {}} | ||||
}; | }; | ||||
const std::set<std::string> DependInputShapeTask::depend_input_shape_ops_ = {SHAPE, SHAPEN, RANK, SIZE, NOOP}; | |||||
const std::set<std::string> DependInputShapeTask::depend_input_shape_ops_ = {SHAPE, SHAPEN, RANK, SIZE}; | |||||
Status RefInputTask::UpdateArgs(TaskContext &) { | Status RefInputTask::UpdateArgs(TaskContext &) { | ||||
// no need update args | // no need update args | ||||
@@ -252,9 +252,16 @@ Status GeLocalNodeExecutor::LoadTask(const HybridModel &model, | |||||
GELOGE(INTERNAL_ERROR, "[Get][Tensor] failed for name: %s", node->GetName().c_str()); | GELOGE(INTERNAL_ERROR, "[Get][Tensor] failed for name: %s", node->GetName().c_str()); | ||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
task = MakeShared<ConstantNodeTask>(tensor); | task = MakeShared<ConstantNodeTask>(tensor); | ||||
GE_CHECK_NOTNULL(task); | GE_CHECK_NOTNULL(task); | ||||
} else if (node_type == NOOP) { | |||||
GELOGI("node %s type %s , use NoOpNodeTask.", node->GetName().c_str(), node_type.c_str()); | |||||
task = MakeShared<NoOpNodeTask>(); | |||||
if (task == nullptr) { | |||||
REPORT_CALL_ERROR("E19999", "Create NoOpNodeTask failed for NoOp node %s.", node->GetName().c_str()); | |||||
GELOGE(MEMALLOC_FAILED, "[Create][NoOpNodeTask]failed for NoOp node %s.", node->GetName().c_str()); | |||||
return MEMALLOC_FAILED; | |||||
} | |||||
} else { | } else { | ||||
GELOGE(UNSUPPORTED, "node %s type %s is not support in GeLocalNodeExecutor now.", | GELOGE(UNSUPPORTED, "node %s type %s is not support in GeLocalNodeExecutor now.", | ||||
node->GetName().c_str(), node_type.c_str()); | node->GetName().c_str(), node_type.c_str()); | ||||
@@ -280,5 +287,17 @@ Status ConstantNodeTask::ExecuteAsync(TaskContext &context, std::function<void() | |||||
GELOGD("[%s] Done execute successfully.", context.GetNodeName()); | GELOGD("[%s] Done execute successfully.", context.GetNodeName()); | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status NoOpNodeTask::UpdateArgs(TaskContext &context) { | |||||
// no need to update args | |||||
return SUCCESS; | |||||
} | |||||
Status NoOpNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> done_callback) { | |||||
GELOGD("[%s] Start execute.", context.GetNodeName()); | |||||
GE_CHK_STATUS_RET(context.TryExecuteCallback(done_callback)); | |||||
GELOGD("[%s] Done execute successfully.", context.GetNodeName()); | |||||
return SUCCESS; | |||||
} | |||||
} // namespace hybrid | } // namespace hybrid | ||||
} // namespace ge | } // namespace ge |
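
Even though the task does no work, ExecuteAsync still has to drive the completion callback so the executor can schedule the node's successors; that is essentially all NoOpNodeTask adds. A tiny stand-alone sketch of that contract (MiniNoOpTask is a made-up name, and a plain int stands in for Status):

#include <functional>
#include <iostream>

// Hedged sketch of the contract a no-op task still has to honour:
// do no work, but make sure the completion callback fires.
struct MiniNoOpTask {
  int Execute(const std::function<void()> &done_callback) {
    if (done_callback) {
      done_callback();  // let the executor schedule successors
    }
    return 0;  // SUCCESS
  }
};

int main() {
  MiniNoOpTask task;
  task.Execute([]() { std::cout << "successors can be scheduled now" << std::endl; });
  return 0;
}
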
@@ -80,6 +80,14 @@ class ConstantNodeTask : public NodeTask { | |||||
const TensorValue *tensor_; | const TensorValue *tensor_; | ||||
}; | }; | ||||
class NoOpNodeTask : public NodeTask { | |||||
public: | |||||
explicit NoOpNodeTask() = default; | |||||
~NoOpNodeTask() = default; | |||||
Status UpdateArgs(TaskContext &context) override; | |||||
Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override; | |||||
}; | |||||
class GeLocalNodeExecutor : public NodeExecutor { | class GeLocalNodeExecutor : public NodeExecutor { | ||||
public: | public: | ||||
@@ -20,6 +20,7 @@ | |||||
#include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
#include "graph/utils/tensor_utils.h" | #include "graph/utils/tensor_utils.h" | ||||
#include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
#include "graph/utils/node_utils.h" | |||||
#include "common/ge/ge_util.h" | #include "common/ge/ge_util.h" | ||||
#include "common/op/ge_op_utils.h" | #include "common/op/ge_op_utils.h" | ||||
@@ -201,6 +202,13 @@ Status PassThroughNodeTask::ExecuteAsync(TaskContext &task_context, std::functio | |||||
GE_CHECK_NOTNULL(in_x); | GE_CHECK_NOTNULL(in_x); | ||||
GE_CHK_STATUS_RET_NOLOG(task_context.SetOutput(0, *in_x)); // y | GE_CHK_STATUS_RET_NOLOG(task_context.SetOutput(0, *in_x)); // y | ||||
const auto &node_state = task_context.GetNodeState(); | |||||
if (kNextIterationOpTypes.count(node_state->GetType()) > 0) { | |||||
node_state->RunLoopNext(); | |||||
} else if (kExitOpTypes.count(node_state->GetType()) > 0) { | |||||
node_state->RunLoopExit(); | |||||
} | |||||
if (done_callback) { | if (done_callback) { | ||||
GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback)); | GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback)); | ||||
} | } | ||||
@@ -61,6 +61,6 @@ class RtsTaskFactory { | |||||
REGISTER_RTS_TASK_CREATOR_UNIQ_HELPER(__COUNTER__, task_type, task_clazz) | REGISTER_RTS_TASK_CREATOR_UNIQ_HELPER(__COUNTER__, task_type, task_clazz) | ||||
#define REGISTER_RTS_TASK_CREATOR_UNIQ_HELPER(ctr, type, clazz) \ | #define REGISTER_RTS_TASK_CREATOR_UNIQ_HELPER(ctr, type, clazz) \ | ||||
RtsTaskFactory::RtsTaskRegistrar g_##type##_Creator##ctr(type, []()-> RtsNodeTaskPtr { return MakeShared<clazz>(); }) | |||||
RtsTaskFactory::RtsTaskRegistrar g_##type##_Creator##ctr(type, []()->RtsNodeTaskPtr { return MakeShared<clazz>(); }) | |||||
#endif // GE_HYBRID_NODE_EXECUTOR_RTS_TASK_FACTORY_H_ | #endif // GE_HYBRID_NODE_EXECUTOR_RTS_TASK_FACTORY_H_ |
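
The reformatted macro relies on the usual self-registration idiom: a namespace-scope registrar object whose constructor stores a factory lambda keyed by task type, with __COUNTER__ only making the variable name unique. A hedged, self-contained sketch of that pattern (TaskFactory/TaskRegistrar and the "LabelGoto" key are illustrative, not the real RTS types):

#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>

// Simplified stand-ins; the real factory is keyed by the RTS task type.
struct Task { virtual ~Task() = default; virtual const char *Name() const = 0; };
struct LabelGotoTask : Task { const char *Name() const override { return "LabelGotoTask"; } };

using TaskCreator = std::function<std::shared_ptr<Task>()>;

class TaskFactory {
 public:
  static TaskFactory &Instance() { static TaskFactory factory; return factory; }
  void Register(const std::string &type, TaskCreator creator) { creators_[type] = std::move(creator); }
  std::shared_ptr<Task> Create(const std::string &type) {
    auto it = creators_.find(type);
    return (it == creators_.end()) ? nullptr : it->second();
  }
 private:
  std::map<std::string, TaskCreator> creators_;
};

// The registrar: its constructor runs at static-initialization time and records the creator.
struct TaskRegistrar {
  TaskRegistrar(const std::string &type, TaskCreator creator) {
    TaskFactory::Instance().Register(type, std::move(creator));
  }
};

// Expanding a REGISTER-style macro by hand; __COUNTER__ would only uniquify the variable name.
static TaskRegistrar g_LabelGoto_Creator0("LabelGoto", []() { return std::make_shared<LabelGotoTask>(); });

int main() {
  auto task = TaskFactory::Instance().Create("LabelGoto");
  std::cout << (task ? task->Name() : "not registered") << std::endl;
  return 0;
}
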
@@ -489,6 +489,11 @@ void TaskContext::ReleaseInputsAndOutputs() { | |||||
} | } | ||||
void TaskContext::ReleaseInput(int index) { | void TaskContext::ReleaseInput(int index) { | ||||
if (node_item_->enter_inside_.count(index) > 0) { | |||||
GELOGD("[%s] Tensor of input[%d] is enter, keep it", GetNodeName(), index); | |||||
return; | |||||
} | |||||
auto input_tensor = MutableInput(index); | auto input_tensor = MutableInput(index); | ||||
if (input_tensor != nullptr) { | if (input_tensor != nullptr) { | ||||
input_tensor->Destroy(); | input_tensor->Destroy(); | ||||
@@ -37,6 +37,9 @@ const size_t kMaxNDDimNum = 4; | |||||
const size_t kMinNDDimNum = 1; | const size_t kMinNDDimNum = 1; | ||||
const size_t kSquareBracketsSize = 2; | const size_t kSquareBracketsSize = 2; | ||||
const size_t kRangePairSize = 2; | const size_t kRangePairSize = 2; | ||||
const size_t kShapeRangeSize = 2; | |||||
const size_t kShapeRangeStrIndex = 2; | |||||
const size_t kShapeRangeStrSize = 3; | |||||
// datatype/formats from user to GE, Unified to util interface file later | // datatype/formats from user to GE, Unified to util interface file later | ||||
const std::map<std::string, ge::DataType> kOutputTypeSupportDatatype = { | const std::map<std::string, ge::DataType> kOutputTypeSupportDatatype = { | ||||
{"FP32", ge::DT_FLOAT}, {"FP16", ge::DT_FLOAT16}, {"UINT8", ge::DT_UINT8}}; | {"FP32", ge::DT_FLOAT}, {"FP16", ge::DT_FLOAT16}, {"UINT8", ge::DT_UINT8}}; | ||||
@@ -434,7 +437,7 @@ Status ParseInputShapeRange(const std::string &shape_range, | |||||
std::vector<std::vector<std::pair<int64_t, int64_t>>> &range) { | std::vector<std::vector<std::pair<int64_t, int64_t>>> &range) { | ||||
GELOGD("Input shape range %s", shape_range.c_str()); | GELOGD("Input shape range %s", shape_range.c_str()); | ||||
if (shape_range.size() < 2) { | |||||
if (shape_range.size() < kShapeRangeSize) { | |||||
REPORT_INPUT_ERROR("E10048", std::vector<std::string>({"shape_range", "reason", "sample"}), | REPORT_INPUT_ERROR("E10048", std::vector<std::string>({"shape_range", "reason", "sample"}), | ||||
std::vector<std::string>({shape_range, kInputShapeRangeSizeInvalid, kInputShapeRangeSample4})); | std::vector<std::string>({shape_range, kInputShapeRangeSizeInvalid, kInputShapeRangeSample4})); | ||||
GELOGE(PARAM_INVALID, "[Parse][ShapeRange] str:%s invalid, reason: %s, correct sample is %s.", | GELOGE(PARAM_INVALID, "[Parse][ShapeRange] str:%s invalid, reason: %s, correct sample is %s.", | ||||
@@ -451,7 +454,7 @@ Status ParseInputShapeRange(const std::string &shape_range, | |||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
for (auto &shape_range_str : shape_range_set) { | for (auto &shape_range_str : shape_range_set) { | ||||
if (shape_range_str.size() < 3) { | |||||
if (shape_range_str.size() < kShapeRangeStrSize) { | |||||
// shape_range_str should be "[2~3,1" | // shape_range_str should be "[2~3,1" | ||||
// or ",[2~3,1". because we should trim '[' or ',[' | // or ",[2~3,1". because we should trim '[' or ',[' | ||||
// so shape_range_str.size() < 3 is invalid | // so shape_range_str.size() < 3 is invalid | ||||
@@ -462,7 +465,7 @@ Status ParseInputShapeRange(const std::string &shape_range, | |||||
shape_range_str = shape_range_str.substr(1, shape_range_str.size()); | shape_range_str = shape_range_str.substr(1, shape_range_str.size()); | ||||
} | } | ||||
if (ge::StringUtils::StartWith(shape_range_str, ",")) { | if (ge::StringUtils::StartWith(shape_range_str, ",")) { | ||||
shape_range_str = shape_range_str.substr(2, shape_range_str.size()); | |||||
shape_range_str = shape_range_str.substr(kShapeRangeStrIndex, shape_range_str.size()); | |||||
} | } | ||||
// parse shape_range of single input. eg. "1~20,3,3~6,-1" | // parse shape_range of single input. eg. "1~20,3,3~6,-1" | ||||
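
For reference, the structure ParseInputShapeRange fills is a per-input list of [min, max] pairs. A purely illustrative example of the mapping (the input string follows the samples in the surrounding comments; the exact expansion rules, e.g. fixed dims becoming [d, d], are assumptions here):

#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

int main() {
  // Illustrative only: "[1~20,3,3~6],[2~5,8]" describes two inputs.
  // "a~b" is read as [a, b] and a fixed dim d is shown here as [d, d].
  std::vector<std::vector<std::pair<int64_t, int64_t>>> range = {
      {{1, 20}, {3, 3}, {3, 6}},  // first input
      {{2, 5}, {8, 8}},           // second input
  };
  for (size_t i = 0; i < range.size(); ++i) {
    std::cout << "input " << i << ":";
    for (const auto &p : range[i]) {
      std::cout << " [" << p.first << "," << p.second << "]";
    }
    std::cout << std::endl;
  }
  return 0;
}
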
@@ -0,0 +1,25 @@ | |||||
#!/bin/bash | |||||
# Copyright 2021 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
set -e | |||||
export PROJECT_HOME=${PROJECT_HOME:-$(dirname "$0")/../} | |||||
function main(){ | |||||
${PROJECT_HOME}/build.sh "$@" | |||||
} | |||||
main "$@" | |||||
set +e |
@@ -1,5 +1,5 @@ | |||||
#!/bin/bash | #!/bin/bash | ||||
# Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
# Copyright 2021 Huawei Technologies Co., Ltd | |||||
# | # | ||||
# Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
# you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
@@ -13,7 +13,6 @@ | |||||
# See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
# limitations under the License. | # limitations under the License. | ||||
# ============================================================================ | # ============================================================================ | ||||
set -e | set -e | ||||
CLANG_FORMAT=$(which clang-format) || (echo "Please install 'clang-format' tool first"; exit 1) | CLANG_FORMAT=$(which clang-format) || (echo "Please install 'clang-format' tool first"; exit 1) | ||||
@@ -25,10 +24,10 @@ if [[ "${version}" -lt "8" ]]; then | |||||
fi | fi | ||||
CURRENT_PATH=$(pwd) | CURRENT_PATH=$(pwd) | ||||
SCRIPTS_PATH=$(dirname "$0") | |||||
PROJECT_HOME=${PROJECT_HOME:-$(dirname "$0")/../} | |||||
echo "CURRENT_PATH=$CURRENT_PATH" | echo "CURRENT_PATH=$CURRENT_PATH" | ||||
echo "SCRIPTS_PATH=$SCRIPTS_PATH" | |||||
echo "PROJECT_HOME=$PROJECT_HOME" | |||||
# print usage message | # print usage message | ||||
function usage() | function usage() | ||||
@@ -81,45 +80,46 @@ function checkopts() | |||||
checkopts "$@" | checkopts "$@" | ||||
# switch to project root path, which contains clang-format config file '.clang-format' | # switch to project root path, which contains clang-format config file '.clang-format' | ||||
cd "${SCRIPTS_PATH}/.." || exit 1 | |||||
pushd "${CURRENT_PATH}" | |||||
CHECK_LIST_FILE='__checked_files_list__' | |||||
cd "${PROJECT_HOME}" || exit 1 | |||||
CHECK_LIST_FILE='__checked_files_list__' | |||||
if [ "X${mode}" == "Xall" ]; then | |||||
find src -type f -name "*" | grep "\.h$\|\.cc$" > "${CHECK_LIST_FILE}" || true | |||||
find inc -type f -name "*" | grep "\.h$\|\.cc$" >> "${CHECK_LIST_FILE}" || true | |||||
elif [ "X${mode}" == "Xchanged" ]; then | |||||
# --diff-filter=ACMRTUXB will ignore deleted files in commit | |||||
git diff --diff-filter=ACMRTUXB --name-only | grep "^inc\|^src" | grep "\.h$\|\.cc$" > "${CHECK_LIST_FILE}" || true | |||||
else # "X${mode}" == "Xlastcommit" | |||||
git diff --diff-filter=ACMRTUXB --name-only HEAD~ HEAD | grep "^inc\|^src" | grep "\.h$\|\.cc$" > "${CHECK_LIST_FILE}" || true | |||||
fi | |||||
if [ "X${mode}" == "Xall" ]; then | |||||
find src -type f -name "*" | grep "\.h$\|\.cc$" > "${CHECK_LIST_FILE}" || true | |||||
find inc -type f -name "*" | grep "\.h$\|\.cc$" >> "${CHECK_LIST_FILE}" || true | |||||
elif [ "X${mode}" == "Xchanged" ]; then | |||||
# --diff-filter=ACMRTUXB will ignore deleted files in commit | |||||
git diff --diff-filter=ACMRTUXB --name-only | grep "^inc\|^src" | grep "\.h$\|\.cc$" > "${CHECK_LIST_FILE}" || true | |||||
else # "X${mode}" == "Xlastcommit" | |||||
git diff --diff-filter=ACMRTUXB --name-only HEAD~ HEAD | grep "^inc\|^src" | grep "\.h$\|\.cc$" > "${CHECK_LIST_FILE}" || true | |||||
fi | |||||
CHECK_RESULT_FILE=__code_format_check_result__ | |||||
echo "0" > "$CHECK_RESULT_FILE" | |||||
CHECK_RESULT_FILE=__code_format_check_result__ | |||||
echo "0" > "$CHECK_RESULT_FILE" | |||||
# check format of files modified in the latest commit
while read line; do | |||||
BASE_NAME=$(basename "${line}") | |||||
TEMP_FILE="__TEMP__${BASE_NAME}" | |||||
cp "${line}" "${TEMP_FILE}" | |||||
${CLANG_FORMAT} -i "${TEMP_FILE}" | |||||
set +e | |||||
diff "${TEMP_FILE}" "${line}" | |||||
ret=$? | |||||
set -e | |||||
rm "${TEMP_FILE}" | |||||
if [[ "${ret}" -ne 0 ]]; then | |||||
echo "File ${line} is not formated, please format it." | |||||
echo "1" > "${CHECK_RESULT_FILE}" | |||||
break | |||||
fi | |||||
done < "${CHECK_LIST_FILE}" | |||||
# check format of files modified in the latest commit
while read line; do | |||||
BASE_NAME=$(basename "${line}") | |||||
TEMP_FILE="__TEMP__${BASE_NAME}" | |||||
cp "${line}" "${TEMP_FILE}" | |||||
${CLANG_FORMAT} -i "${TEMP_FILE}" | |||||
set +e | |||||
diff "${TEMP_FILE}" "${line}" | |||||
ret=$? | |||||
set -e | |||||
rm "${TEMP_FILE}" | |||||
if [[ "${ret}" -ne 0 ]]; then | |||||
echo "File ${line} is not formated, please format it." | |||||
echo "1" > "${CHECK_RESULT_FILE}" | |||||
break | |||||
fi | |||||
done < "${CHECK_LIST_FILE}" | |||||
result=$(cat "${CHECK_RESULT_FILE}") | |||||
rm "${CHECK_RESULT_FILE}" | |||||
rm "${CHECK_LIST_FILE}" | |||||
popd | |||||
result=$(cat "${CHECK_RESULT_FILE}") | |||||
rm "${CHECK_RESULT_FILE}" | |||||
rm "${CHECK_LIST_FILE}" | |||||
cd "${CURRENT_PATH}" || exit 1 | |||||
if [[ "X${result}" == "X0" ]]; then | if [[ "X${result}" == "X0" ]]; then | ||||
echo "Check PASS: specified files are well formated!" | echo "Check PASS: specified files are well formated!" | ||||
fi | fi | ||||
@@ -0,0 +1,90 @@ | |||||
#!/bin/bash | |||||
# Copyright 2021 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
set -e | |||||
PROJECT_HOME=${PROJECT_HOME:-$(dirname "$0")/../../} | |||||
function help(){ | |||||
cat <<-EOF | |||||
Usage: ge clean [OPTIONS] | |||||
Options: | |||||
-b, --build Clean build dir | |||||
-d, --docs Clean generated docs
-i, --install Clean dependencies
-a, --all Clean all | |||||
-h, --help | |||||
EOF | |||||
} | |||||
function clean_relative_dir(){ | |||||
rm -rf "${PROJECT_HOME}/${1:-output}" | |||||
} | |||||
function parse_args(){ | |||||
parsed_args=$(getopt -a -o bdiah --long build,docs,install,all,help -- "$@") || { | |||||
help | |||||
exit 1 | |||||
} | |||||
if [ $# -lt 1 ]; then | |||||
clean_relative_dir "build" | |||||
clean_relative_dir "output" | |||||
exit 1 | |||||
fi | |||||
eval set -- "$parsed_args" | |||||
while true; do | |||||
case "$1" in | |||||
-b | --build) | |||||
clean_relative_dir "build" | |||||
clean_relative_dir "output" | |||||
;; | |||||
-d | --docs) | |||||
clean_relative_dir "docs/doxygen" | |||||
;; | |||||
-i | --install) | |||||
clean_relative_dir "deps" | |||||
;; | |||||
-a | --all) | |||||
clean_relative_dir "deps" | |||||
clean_relative_dir "build" | |||||
clean_relative_dir "output" | |||||
clean_relative_dir "docs/doxygen" | |||||
;; | |||||
-h | --help) | |||||
help | |||||
;; | |||||
--) | |||||
shift; break; | |||||
;; | |||||
*) | |||||
help; exit 1 | |||||
;; | |||||
esac | |||||
shift | |||||
done | |||||
} | |||||
function main(){ | |||||
parse_args "$@" | |||||
} | |||||
main "$@" | |||||
set +e |
@@ -0,0 +1,115 @@ | |||||
#!/bin/bash | |||||
# Copyright 2021 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
set -e | |||||
PROJECT_HOME=${PROJECT_HOME:-$(dirname "$0")/../../} | |||||
PROJECT_HOME=$(cd $PROJECT_HOME || return; pwd) | |||||
function help(){ | |||||
cat <<-EOF | |||||
Usage: ge config [OPTIONS] | |||||
Update server config for ge. You need to input all config info (ip, user, password).
Options: | |||||
-i, --ip Config server ip
-u, --user Config user name | |||||
-p, --password Config password | |||||
-h, --help | |||||
Example: ge config -i=121.36.**.** -u=asc** -p=Asc***\#@\!\$ (You need to add the escape character \ before the special characters $, #, !)
EOF | |||||
} | |||||
function write_config_file(){ | |||||
local IP=$1 | |||||
local USER=$2 | |||||
local PASSWORD=$3 | |||||
if [[ -z "$IP" ]] || [[ -z "$USER" ]] || [[ -z "$USER" ]]; then | |||||
echo "You need input all info (ip, user,password)obout server config !!!" | |||||
help | |||||
exit 1 | |||||
fi | |||||
local PASSWORD=${PASSWORD//!/\\!} | |||||
local PASSWORD=${PASSWORD//#/\\#} | |||||
local PASSWORD=${PASSWORD//\$/\\\$}
local SERVER_CONFIG_FILE=${PROJECT_HOME}/scripts/config/server_config.sh | |||||
[ -n "${SERVER_CONFIG_FILE}" ] && rm -rf "${SERVER_CONFIG_FILE}" | |||||
cat>${SERVER_CONFIG_FILE}<<-EOF | |||||
SERVER_PATH=http://${IP}/package/etrans | |||||
DEP_USER=${USER} | |||||
DEP_PASSWORD=${PASSWORD} | |||||
EOF | |||||
} | |||||
function parse_args(){ | |||||
parsed_args=$(getopt -a -o i::u::p::h --long ip::,user::,password::,help -- "$@") || { | |||||
help | |||||
exit 1 | |||||
} | |||||
if [ $# -lt 1 ]; then | |||||
help | |||||
exit 1 | |||||
fi | |||||
local IP= | |||||
local USER= | |||||
local PASSWORD= | |||||
eval set -- "$parsed_args" | |||||
while true; do | |||||
case "$1" in | |||||
-i | --ip) | |||||
IP=$2 | |||||
;; | |||||
-u | --user) | |||||
USER=$2 | |||||
;; | |||||
-p | --password) | |||||
PASSWORD=$2 | |||||
;; | |||||
-h | --help) | |||||
help; exit; | |||||
;; | |||||
--) | |||||
shift; break; | |||||
;; | |||||
*) | |||||
help; exit 1 | |||||
;; | |||||
esac | |||||
shift 2 | |||||
done | |||||
write_config_file "$IP" "$USER" "$PASSWORD"
} | |||||
function main(){ | |||||
parse_args "$@" | |||||
} | |||||
main "$@" | |||||
set +e |
@@ -0,0 +1,136 @@ | |||||
#!/bin/bash | |||||
# Copyright 2021 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
set -e | |||||
function help(){ | |||||
cat <<-EOF | |||||
Usage: ge cov [OPTIONS] | |||||
Options: | |||||
-a, --all Full coverage | |||||
-i, --increment Increment coverage | |||||
-d, --directory Coverage of directory | |||||
-h, --help | |||||
EOF | |||||
} | |||||
PROJECT_HOME=${PROJECT_HOME:-$(dirname "$0")/../../} | |||||
PROJECT_HOME=$(cd $PROJECT_HOME || return; pwd) | |||||
ALL_COV_GEN_PATH=${PROJECT_HOME}/cov/all | |||||
DIFF_FILE_PATH=${PROJECT_HOME}/cov/diff | |||||
DIFF_FILE_NAME=${DIFF_FILE_PATH}/inc_change_diff.txt | |||||
function process_diff_format(){ | |||||
sed -i "s/--- a/--- \/code\/Turing\/graphEngine/g" ${DIFF_FILE_NAME} | |||||
sed -i "s/+++ b/+++ \/code\/Turing\/graphEngine/g" ${DIFF_FILE_NAME} | |||||
} | |||||
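
For illustration (the file path below is made up), process_diff_format turns a header pair such as

  --- a/ge/hybrid/model/hybrid_model_builder.cc
  +++ b/ge/hybrid/model/hybrid_model_builder.cc

into

  --- /code/Turing/graphEngine/ge/hybrid/model/hybrid_model_builder.cc
  +++ /code/Turing/graphEngine/ge/hybrid/model/hybrid_model_builder.cc

presumably so the paths in the diff line up with the absolute source paths recorded in coverage.info when the build runs inside the container mounted at /code/Turing/graphEngine.
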
function add_cov_generate(){ | |||||
addlcov --diff ${ALL_COV_GEN_PATH}/coverage.info ${DIFF_FILE_NAME} -o ${PROJECT_HOME}/cov/diff/inc_coverage.info | |||||
} | |||||
function gen_add_cov_html(){ | |||||
genhtml --prefix ${PROJECT_HOME} -o ${PROJECT_HOME}/cov/diff/html ${PROJECT_HOME}/cov/diff/inc_coverage.info --legend -t CHG --no-branch-coverage --no-function-coverage | |||||
} | |||||
function increment_cov_for_directory(){ | |||||
[ -n "${DIFF_FILE_PATH}" ] && rm -rf "${DIFF_FILE_PATH}" | |||||
mkdir -p ${DIFF_FILE_PATH} | |||||
git diff HEAD -- $1 >>${DIFF_FILE_NAME} | |||||
process_diff_format | |||||
add_cov_generate | |||||
gen_add_cov_html | |||||
} | |||||
function run_all_coverage(){ | |||||
[ -n "${ALL_COV_GEN_PATH}" ] && rm -rf ${ALL_COV_GEN_PATH} | |||||
mkdir -p ${ALL_COV_GEN_PATH} | |||||
pushd "${PWD}" >/dev/null | |||||
cd ${PROJECT_HOME} | |||||
lcov -c -d build/tests/ut/ge -d build/tests/ut/common/graph/ -o ${ALL_COV_GEN_PATH}/tmp.info | |||||
lcov -r ${ALL_COV_GEN_PATH}/tmp.info '*/output/*' '*/build/opensrc/*' '*/build/proto/*' '*/third_party/*' '*/tests/*' '/usr/local/*' '/usr/include/*' '*/metadef/*' '*/parser/*' -o ${ALL_COV_GEN_PATH}/coverage.info | |||||
cd ${ALL_COV_GEN_PATH} | |||||
genhtml coverage.info | |||||
popd >/dev/null | |||||
} | |||||
function do_coverage_run(){ | |||||
local cov_mode=$1 | |||||
local directory_dir=$2 | |||||
run_all_coverage | |||||
if [ "$cov_mode" = "all" ]; then | |||||
exit 1 | |||||
elif [ -n "$directory_dir" ]; then | |||||
increment_cov_for_directory $directory_dir | |||||
else | |||||
increment_cov_for_directory "ge" | |||||
fi | |||||
} | |||||
function parse_args(){ | |||||
parsed_args=$(getopt -a -o aid::h --long all,increment,directory::,help -- "$@") || { | |||||
help | |||||
exit 1 | |||||
} | |||||
if [ $# -lt 1 ]; then | |||||
run_all_coverage | |||||
exit 1 | |||||
fi | |||||
local cov_mode="increment" | |||||
local directory_dir= | |||||
eval set -- "$parsed_args" | |||||
while true; do | |||||
case "$1" in | |||||
-a | --all) | |||||
cov_mode="all" | |||||
;; | |||||
-i | --increment) | |||||
;; | |||||
-d | --directory) | |||||
directory_dir=$2 | |||||
shift | |||||
;; | |||||
-h | --help) | |||||
help; exit 1; | |||||
;; | |||||
--) | |||||
shift; break; | |||||
;; | |||||
*) | |||||
help; exit 1; | |||||
;; | |||||
esac | |||||
shift | |||||
done | |||||
do_coverage_run $cov_mode $directory_dir | |||||
} | |||||
function main(){ | |||||
parse_args "$@" | |||||
} | |||||
main "$@" | |||||
set +e |
@@ -0,0 +1,87 @@ | |||||
#!/bin/bash | |||||
# Copyright 2021 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
set -e | |||||
function help(){ | |||||
cat <<-EOF | |||||
Usage: ge docs [OPTIONS] | |||||
Options: | |||||
-b, --brief Build brief docs | |||||
-a, --all Build all docs | |||||
-h, --help | |||||
EOF | |||||
} | |||||
PROJECT_HOME=${PROJECT_HOME:-$(dirname "$0")/../../} | |||||
PROJECT_HOME=$(cd $PROJECT_HOME || return; pwd) | |||||
BRIEF_DOXYFILE_PATH=${PROJECT_HOME}/scripts/docs/Doxyfile_brief | |||||
ALL_DOXYFILE_PATH=${PROJECT_HOME}/scripts/docs/Doxyfile_all | |||||
function build_brief_docs(){ | |||||
rm -rf "${PROJECT_HOME}/docs/doxygen" | |||||
doxygen ${BRIEF_DOXYFILE_PATH} | |||||
} | |||||
function build_all_docs(){ | |||||
rm -rf "${PROJECT_HOME}/docs/doxygen" | |||||
doxygen ${ALL_DOXYFILE_PATH} | |||||
} | |||||
function parse_args(){ | |||||
parsed_args=$(getopt -a -o bah --long brief,all,help -- "$@") || { | |||||
help | |||||
exit 1 | |||||
} | |||||
if [ $# -lt 1 ]; then | |||||
build_all_docs | |||||
exit 1 | |||||
fi | |||||
eval set -- "$parsed_args" | |||||
while true; do | |||||
case "$1" in | |||||
-b | --brief) | |||||
build_brief_docs | |||||
;; | |||||
-a | --all) | |||||
build_all_docs | |||||
;; | |||||
-h | --help) | |||||
help; exit 1; | |||||
;; | |||||
--) | |||||
shift; break; | |||||
;; | |||||
*) | |||||
help; exit 1; | |||||
;; | |||||
esac | |||||
shift | |||||
done | |||||
} | |||||
function main(){ | |||||
parse_args "$@" | |||||
} | |||||
main "$@" | |||||
set +e |
@@ -0,0 +1,42 @@ | |||||
# This dockerfile is used for the graphengine build
# Copyright 2021 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
FROM ubuntu:18.04 | |||||
RUN apt-get update \ | |||||
&& apt-get install -y git g++ wget unzip clang-format-9 build-essential lcov vim | |||||
# install graphviz and doxygen for docs generation
RUN apt-get install -y graphviz doxygen | |||||
# install Graph::Easy for graph rendering
RUN cpan install -y Graph::Easy | |||||
RUN wget https://cmake.org/files/v3.16/cmake-3.16.7-Linux-x86_64.tar.gz | |||||
RUN mkdir -p /opt/cmake-3.16.7 \ | |||||
&& tar -xvf cmake-3.16.7-Linux-x86_64.tar.gz -C /opt/cmake-3.16.7 --strip-components=1 \ | |||||
&& ln -sf /opt/cmake-3.16.7/bin/* /usr/bin/ \ | |||||
&& mv /usr/bin/clang-format-9 /usr/bin/clang-format | |||||
RUN wget https://github.com/ccup/lcov/archive/refs/tags/add_lcov.tar.gz -O add_lcov.tar.gz \ | |||||
&& mkdir -p /opt/addlcov1.0.0 \ | |||||
&& tar -xvf add_lcov.tar.gz -C /opt/addlcov1.0.0 \ | |||||
&& mv /opt/addlcov1.0.0/lcov-add_lcov/bin/lcov /usr/bin/addlcov | |||||
ENV PROJECT_HOME=/code/Turing/graphEngine | |||||
RUN echo "alias ge=/code/Turing/graphEngine/scripts/ge.sh">>~/.bashrc | |||||
@@ -0,0 +1,146 @@ | |||||
#!/bin/bash | |||||
# Copyright 2021 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
set -e | |||||
PROJECT_HOME=${PROJECT_HOME:-$(dirname "$0")/../../} | |||||
MOUNT_PROJECT_HOME=$(cd $PROJECT_HOME || return; pwd) | |||||
DOCKER_BUILD_ENV_NAME=${MOUNT_PROJECT_HOME#*/} | |||||
DOCKER_BUILD_ENV_NAME=${DOCKER_BUILD_ENV_NAME//\//\_} | |||||
DOCKER_IMAGE_TAG=ge_build_env.1.0.6 | |||||
DOCKER_IMAGE_NAME=joycode2art/turing
DOCKER_FULL_IMAGE_NAME=${DOCKER_IMAGE_NAME}:${DOCKER_IMAGE_TAG}
if [ "$(uname)" == "Darwin" ]; then | |||||
#running on Mac OS | |||||
docker_cmd=docker | |||||
MOUNT_PROJECT_HOME=${MOUNT_PROJECT_HOME} | |||||
docker_work_dir=/code/Turing/graphEngine | |||||
docker_bash_dir=/bin/bash | |||||
elif [ "$(expr substr "$(uname -s)" 1 10)" == "MINGW32_NT" ] || [ "$(expr substr "$(uname -s)" 1 10)" == "MINGW64_NT" ]; then | |||||
#running on Windows | |||||
docker_cmd="winpty docker" | |||||
MOUNT_PROJECT_HOME=/${MOUNT_PROJECT_HOME} | |||||
docker_work_dir=//code/Turing/graphEngine | |||||
docker_bash_dir=//bin/bash | |||||
elif [ "$(expr substr "$(uname -s)" 1 5)" == "Linux" ]; then | |||||
#running on Linux | |||||
docker_cmd=docker | |||||
MOUNT_PROJECT_HOME=${PROJECT_HOME} | |||||
docker_work_dir=/code/Turing/graphEngine | |||||
docker_bash_dir=/bin/bash | |||||
fi | |||||
function build_docker_image(){ | |||||
if test -z "$(docker images |grep ${DOCKER_IAMGE_NAME} | grep ${DOCKER_IMAGE_TAG})"; then | |||||
$docker_cmd build -t ${DOCKER_FULL_IMAGE_NAME} ${PROJECT_HOME}/scripts/env | |||||
else | |||||
echo "docker image for graph engine build is build ok...." | |||||
fi | |||||
} | |||||
function pull_docker_image(){ | |||||
$docker_cmd pull $DOCKER_FULL_IMAGE_NAME | |||||
} | |||||
function enter_docker_env(){ | |||||
if test -z "$(docker images |grep ${DOCKER_IAMGE_NAME} | grep ${DOCKER_IMAGE_TAG})"; then | |||||
echo "please run 'ge env --pull' to download images first!" | |||||
elif test -z "$(docker ps -a |grep ${DOCKER_BUILD_ENV_NAME})"; then | |||||
$docker_cmd run -it -v ${MOUNT_PROJECT_HOME}:/code/Turing/graphEngine --workdir ${docker_work_dir} --name ${DOCKER_BUILD_ENV_NAME} ${DOCKER_FULL_IMAGE_NAME} ${docker_bash_dir} | |||||
elif test -z "$(docker ps |grep ${DOCKER_BUILD_ENV_NAME})"; then | |||||
$docker_cmd start ${DOCKER_BUILD_ENV_NAME} | |||||
$docker_cmd exec -w ${docker_work_dir} -it ${DOCKER_BUILD_ENV_NAME} ${docker_bash_dir} | |||||
else | |||||
$docker_cmd exec -w ${docker_work_dir} -it ${DOCKER_BUILD_ENV_NAME} ${docker_bash_dir} | |||||
fi | |||||
} | |||||
function reset_docker_env(){
if test -z "$(docker ps -a |grep ${DOCKER_BUILD_ENV_NAME})"; then | |||||
echo "no runing container for graphengine build" | |||||
elif test -z "$(docker ps |grep ${DOCKER_BUILD_ENV_NAME})"; then | |||||
$docker_cmd rm ${DOCKER_BUILD_ENV_NAME} | |||||
else | |||||
$docker_cmd stop ${DOCKER_BUILD_ENV_NAME} | |||||
$docker_cmd rm ${DOCKER_BUILD_ENV_NAME} | |||||
fi | |||||
} | |||||
function help(){ | |||||
cat <<-EOF | |||||
Usage: ge env [OPTIONS] | |||||
Prepare the docker env for build and test
Options: | |||||
-b, --build Build docker image | |||||
-p, --pull Pull docker image | |||||
-e, --enter Enter container | |||||
-r, --reset Reset container | |||||
-h, --help | |||||
EOF | |||||
} | |||||
function parse_args(){ | |||||
    parsed_args=$(getopt -a -o bperh --long build,pull,enter,reset,help -- "$@") || {
help | |||||
exit 1 | |||||
} | |||||
if [ $# -lt 1 ]; then | |||||
pull_docker_image | |||||
enter_docker_env | |||||
        exit 0
fi | |||||
eval set -- "$parsed_args" | |||||
while true; do | |||||
case "$1" in | |||||
-b | --build) | |||||
build_docker_image | |||||
;; | |||||
-p | --pull) | |||||
pull_docker_image | |||||
;; | |||||
-e | --enter) | |||||
enter_docker_env | |||||
;; | |||||
-r | --reset) | |||||
                reset_docker_env
;; | |||||
-h | --help) | |||||
help | |||||
;; | |||||
--) | |||||
shift; break; | |||||
;; | |||||
*) | |||||
help; exit 1 | |||||
;; | |||||
esac | |||||
shift | |||||
done | |||||
} | |||||
function main(){ | |||||
parse_args "$@" | |||||
} | |||||
main "$@" | |||||
set +e
@@ -1,5 +1,5 @@ | |||||
#!/bin/bash | #!/bin/bash | ||||
# Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
# Copyright 2021 Huawei Technologies Co., Ltd | |||||
# | # | ||||
# Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
# you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
@@ -24,11 +24,12 @@ if [[ "${version}" -lt "8" ]]; then | |||||
exit 1 | exit 1 | ||||
fi | fi | ||||
CURRENT_PATH=$(pwd) | CURRENT_PATH=$(pwd) | ||||
SCRIPTS_PATH=$(dirname "$0") | |||||
PROJECT_HOME=${PROJECT_HOME:-$(dirname "$0")/../../} | |||||
echo "CURRENT_PATH=${CURRENT_PATH}" | echo "CURRENT_PATH=${CURRENT_PATH}" | ||||
echo "SCRIPTS_PATH=${SCRIPTS_PATH}" | |||||
echo "PROJECT_HOME=${PROJECT_HOME}" | |||||
# print usage message | # print usage message | ||||
function usage() | function usage() | ||||
@@ -81,27 +82,28 @@ function checkopts() | |||||
checkopts "$@" | checkopts "$@" | ||||
# switch to project root path, which contains clang-format config file '.clang-format' | # switch to project root path, which contains clang-format config file '.clang-format' | ||||
cd "${SCRIPTS_PATH}/.." || exit 1 | |||||
FMT_FILE_LIST='__format_files_list__' | |||||
if [[ "X${mode}" == "Xall" ]]; then | |||||
find src -type f -name "*" | grep "\.h$\|\.cc$" > "${FMT_FILE_LIST}" || true | |||||
find inc -type f -name "*" | grep "\.h$\|\.cc$" >> "${FMT_FILE_LIST}" || true | |||||
elif [[ "X${mode}" == "Xchanged" ]]; then | |||||
# --diff-filter=ACMRTUXB will ignore deleted files in commit | |||||
git diff --diff-filter=ACMRTUXB --name-only | grep "^inc\|^src" | grep "\.h$\|\.cc$" >> "${FMT_FILE_LIST}" || true | |||||
else # "X${mode}" == "Xlastcommit" | |||||
git diff --diff-filter=ACMRTUXB --name-only HEAD~ HEAD | grep "^inc\|^src" | grep "\.h$\|\.cc$" > "${FMT_FILE_LIST}" || true | |||||
fi | |||||
while read line; do | |||||
if [ -f "${line}" ]; then | |||||
${CLANG_FORMAT} -i "${line}" | |||||
fi | |||||
done < "${FMT_FILE_LIST}" | |||||
pushd "${CURRENT_PATH}" | |||||
cd "${PROJECT_HOME}" || exit 1 | |||||
FMT_FILE_LIST='__format_files_list__' | |||||
if [[ "X${mode}" == "Xall" ]]; then | |||||
find src -type f -name "*" | grep "\.h$\|\.cc$" > "${FMT_FILE_LIST}" || true | |||||
find inc -type f -name "*" | grep "\.h$\|\.cc$" >> "${FMT_FILE_LIST}" || true | |||||
elif [[ "X${mode}" == "Xchanged" ]]; then | |||||
# --diff-filter=ACMRTUXB will ignore deleted files in commit | |||||
git diff --diff-filter=ACMRTUXB --name-only | grep "^inc\|^src" | grep "\.h$\|\.cc$" >> "${FMT_FILE_LIST}" || true | |||||
else # "X${mode}" == "Xlastcommit" | |||||
git diff --diff-filter=ACMRTUXB --name-only HEAD~ HEAD | grep "^inc\|^src" | grep "\.h$\|\.cc$" > "${FMT_FILE_LIST}" || true | |||||
fi | |||||
while read line; do | |||||
if [ -f "${line}" ]; then | |||||
${CLANG_FORMAT} -i "${line}" | |||||
fi | |||||
done < "${FMT_FILE_LIST}" | |||||
rm "${FMT_FILE_LIST}" | |||||
cd "${CURRENT_PATH}" || exit 1 | |||||
rm "${FMT_FILE_LIST}" | |||||
popd | |||||
echo "Specified cpp source files have been format successfully." | echo "Specified cpp source files have been format successfully." |
@@ -0,0 +1,77 @@ | |||||
#!/bin/bash | |||||
# Copyright 2021 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
GE_BASH_HOME=$(dirname "$0") | |||||
export PROJECT_HOME=${PROJECT_HOME:-${GE_BASH_HOME}/../} | |||||
PROJECT_HOME=$(cd $PROJECT_HOME || return; pwd) | |||||
function help(){ | |||||
cat <<-EOF | |||||
Usage: ge COMMANDS | |||||
Run ge commands | |||||
Commands: | |||||
env Prepare docker env | |||||
config Config dependencies server | |||||
update Update dependencies | |||||
format Format code | |||||
build Build code | |||||
test Run test of UT/ST | |||||
cov Run Coverage | |||||
docs Generate documents | |||||
clean Clean | |||||
EOF | |||||
} | |||||
function ge_error() { | |||||
echo "Error: $*" >&2 | |||||
help | |||||
exit 1 | |||||
} | |||||
function main(){ | |||||
if [ $# -eq 0 ]; then | |||||
help; exit 0 | |||||
fi | |||||
local cmd=$1 | |||||
local shell_cmd= | |||||
shift | |||||
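    # Dispatch: 'build' maps to the top-level build.sh; any other subcommand
    # <cmd> is resolved to scripts/<cmd>/ge_<cmd>.sh relative to this script.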
case "$cmd" in | |||||
-h|--help) | |||||
help; exit 0 | |||||
;; | |||||
build) | |||||
shell_cmd=${PROJECT_HOME}/build.sh | |||||
;; | |||||
*) | |||||
shell_cmd=$GE_BASH_HOME/$cmd/ge_$cmd.sh | |||||
;; | |||||
esac | |||||
    [ -e "$shell_cmd" ] || {
        ge_error "unsupported command '$cmd': $shell_cmd is not found"
} | |||||
$shell_cmd "$@" | |||||
} | |||||
main "$@" | |||||
@@ -0,0 +1,331 @@ | |||||
# Graph Engine Developer Toolchain User Guide
The GE developer toolchain is a set of automation scripts in graph engine aimed at individual developers.
It currently covers the common developer workflows: container-based environment preparation, automatic download, installation and configuration of build dependencies, code formatting, building, testing, code coverage checking and document generation.
## Prerequisites
The following preparation has to be done manually before using the GE developer toolchain.
The configurations below have been verified and are recommended for the best performance:
1. Operating system, any one of the following:
   - A native Linux distribution such as Ubuntu;
   - Windows with the Ubuntu distribution of WSL installed; it is strongly recommended to clone the code directly inside WSL instead of mounting a Windows volume (build performance is poor otherwise);
   - macOS.
2. Docker installation:
   - Docker is installed, the registry mirrors are configured correctly, and external images can be pulled (see the verification sketch after this list).
3. A command-line shell supported by the OS: a native Linux or WSL shell.
A usable but not recommended configuration:
- Installing Docker directly on Windows and running the ge toolchain from a Linux-like bash (Cygwin, MinGW, etc.);
  (all GE toolchain operations still work this way, but the file-access restrictions between Windows and the container's different kernel make builds noticeably slower).
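A quick way to verify the Docker prerequisite (a minimal sketch; the mirror entry is only an illustrative placeholder, use whichever registry mirror is reachable from your network):
```sh
# check that the Docker daemon is running and can pull external images
docker info
docker pull hello-world
# optionally configure a registry mirror in /etc/docker/daemon.json, e.g.
#   { "registry-mirrors": ["https://<your-mirror>"] }
# then restart the daemon (on Linux):
sudo systemctl restart docker
```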
## Quick Start
The GE toolchain scripts live under `scripts` and can be run as follows:
1. Enter the scripts directory:
```sh | |||||
$ cd ./scripts | |||||
``` | |||||
2. Run `ge env` to download the container environment automatically and log into it:
```sh | |||||
$ ./ge.sh env | |||||
``` | |||||
3. Configure the external dependency server information:
```sh | |||||
ge config -i=121.36.**.** -u=asc** -p=Asc***\#@\!$ (Need to add the escape character \ before the special characters $, #, !)
``` | |||||
4. Download and install the external libraries the build depends on:
```sh | |||||
$ ge update | |||||
``` | |||||
(Note: once inside the container, the `ge` command is registered in the system automatically, so the full script name is not needed there.)
5. Run the tests; unit test cases are executed by default, and `ge test` triggers a build automatically:
```sh | |||||
$ ge test | |||||
``` | |||||
## Detailed Usage
In the scripts directory, run `./ge.sh -h` to list the full set of subcommands.
```sh | |||||
$ ./ge.sh -h | |||||
Usage: ge COMMANDS | |||||
Run ge commands | |||||
Commands: | |||||
env Prepare docker env | |||||
config Config dependencies server | |||||
update Update dependencies | |||||
format Format code | |||||
lint Static verify | |||||
build Build code | |||||
test Run test of UT/ST | |||||
cov Run Coverage | |||||
docs Generate documents | |||||
clean Clean | |||||
``` | |||||
Each built-in subcommand represents a standalone feature, and every subcommand provides second-level options to control how it runs.
The options supported by a subcommand can be listed with `-h`.
For example, to query the options of the `env` subcommand:
```sh | |||||
$ ./ge.sh env -h | |||||
``` | |||||
When invoked without options, each subcommand falls back to a default behavior.
### `ge env`
This command prepares the container environment used for building and testing. It supports the following options:
``` | |||||
$ ./ge.sh env -h | |||||
Usage: ge env [OPTIONS] | |||||
Prepare for docker env for build and test | |||||
Options: | |||||
-b, --build Build docker image | |||||
-p, --pull Pull docker image | |||||
-e, --enter Enter container | |||||
-r, --reset Reset container | |||||
-h, --help | |||||
``` | |||||
Option details:
- `-b, --build`: build the container image to run, based on "scripts/env/Dockerfile";
- `-p, --pull` : pull the required container image from the locally configured central registry;
- `-e, --enter`: log into the container environment, provided the image already exists locally;
- `-r, --reset`: remove the locally running container environment;
Default: pull the corresponding image from the central registry, start an instance and log into it.
### `ge config`
Configures the external dependency server. It supports the following options:
```sh | |||||
$ ge config -h | |||||
Usage: ge config [OPTIONS] | |||||
update server config for ge, you need input all config info (ip, user, password) | |||||
Options: | |||||
-i, --ip Config ip config | |||||
-u, --user Config user name | |||||
-p, --password Config password | |||||
-h, --help | |||||
Example: ge config -i=121.36.**.** -u=asc** -p=Asc***\#@\!$ (Need to add the escape character \ before the special characters $, #, !)
``` | |||||
Option details:
- `-i, --ip` : set the IP address of the dependency server;
- `-u, --user` : set the user name for the dependency server;
- `-p, --password` : set the password for the dependency server;
Default: print the help message.
### `ge update`
Installs the external dependency libraries required to build graph engine. It supports the following options:
```sh | |||||
$ ge update -h | |||||
Usage: ge update [OPTIONS] | |||||
update dependencies of build and test | |||||
Options: | |||||
-d, --download Download dependencies | |||||
-i, --install Install dependencies | |||||
-c, --clear Clear dependencies | |||||
-h, --help | |||||
``` | |||||
Option details:
- `-d, --download` : download the external dependency libraries needed by the build;
- `-i, --install` : install the external dependency packages to their target locations;
- `-c, --clear` : remove the downloaded external dependency packages;
Default: download the external dependencies according to "scripts/update/deps_config.sh" and install them into the corresponding directories.
(Note: make sure the server address, user name and password in "scripts/update/server_config.sh" have been configured; a sketch of that file is shown below.)
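A minimal sketch of what that configuration file is expected to provide, inferred from how the update scripts consume it (all values here are placeholders, not real endpoints or credentials):
```sh
# server_config.sh -- dependency server account info used by `ge update`
SERVER_PATH=ftp://121.36.xx.xx/packages   # base URL of the dependency server
DEP_USER=your_user_name                   # user name passed to wget
DEP_PASSWORD=your_password                # password passed to wget
```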
### `ge format`
Formats the code with clang-format. It supports the following options:
```sh | |||||
$ ge format -h | |||||
Options: | |||||
-a format of all files | |||||
-c format of the files changed compared to last commit, default case | |||||
-l format of the files changed in last commit | |||||
-h Print usage | |||||
``` | |||||
Option details:
- `-a` : format all code;
- `-c` : format only the code changed in the current working copy;
- `-l` : format the code changed in the last commit;
Default: format the code changed in the current working copy.
### `ge lint`
Checks the code format with clang-format. It supports the following options:
```sh | |||||
$ ge lint -h | |||||
Options: | |||||
-a Check code format of all files, default case | |||||
-c Check code format of the files changed compared to last commit | |||||
-l Check code format of the files changed in last commit | |||||
-h Print usage | |||||
``` | |||||
Option details:
- `-a` : check the format of all code;
- `-c` : check only the format of the changed code;
- `-l` : check the format of the code changed in the last commit;
Default: check the format of the code changed in the current working copy.
### `ge build`
Runs the build (note: this currently delegates to the existing build.sh script and is being reworked).
### `ge test`
Builds and runs the test cases. The following options are currently supported:
```sh | |||||
$ ge test -h | |||||
Usage: ge test [OPTIONS] | |||||
Options: | |||||
-u, --unit Run unit Test | |||||
-c, --component Run component Test | |||||
-h, --help | |||||
``` | |||||
Option details:
- `-u, --unit` : run the unit tests;
- `-c, --component` : run the component tests;
Default: run the unit tests.
### `ge cov`
Runs the code coverage check, supporting both full and incremental coverage. The test cases must already have been run before using this command. The following options are currently supported:
```sh | |||||
$ ge cov -h | |||||
Usage: ge cov [OPTIONS] | |||||
Options: | |||||
-a, --all Full coverage | |||||
-i, --increment Increment coverage | |||||
-d, --directory Coverage of directory | |||||
-h, --help | |||||
``` | |||||
Option details:
- `-a, --all` : compute full coverage statistics;
- `-i, --increment` : run an incremental coverage check; by default it analyzes the coverage of uncommitted changes (newly added files not yet tracked by git must be added with `git add` first);
- `-d, --directory` : the code path on which to run the incremental coverage check; a path argument can be passed in;
Default: run the incremental coverage check.
The following command shows how to check the incremental coverage of all code under the ge directory:
```sh | |||||
$ ge cov -d=ge | |||||
``` | |||||
### `ge docs`
Generates Doxygen documentation that captures the logical and physical structure of the code and the relationships between its parts, making the code easier to read and understand. The following options are currently supported:
```sh | |||||
$ ge docs -h | |||||
Usage: ge docs [OPTIONS] | |||||
Options: | |||||
-b, --brief Build brief docs | |||||
-a, --all Build all docs | |||||
-h, --help | |||||
``` | |||||
Option details:
- `-b, --brief` : generate brief documentation, skipping some relationship diagrams; this is fast;
- `-a, --all` : generate full documentation, including all code relationship diagrams; this is relatively slow;
Default: generate the full code documentation.
### `ge clean`
Removes downloaded or generated intermediate files. The following options are currently supported:
```sh | |||||
$ ge clean -h | |||||
Usage: ge clean [OPTIONS] | |||||
Options: | |||||
-b, --build Clean build dir | |||||
-d, --docs Clean generate docs | |||||
  -i, --install   Clean dependencies
-a, --all Clean all | |||||
-h, --help | |||||
``` | |||||
Option details:
- `-b, --build` : remove the temporary files produced by the build;
- `-d, --docs` : remove the temporary files generated for the documentation;
- `-i, --install` : remove the installed dependency files;
- `-a, --all` : remove all downloaded and generated temporary files;
Default: remove the temporary files produced by the build.
## Follow us
The toolchain is still being improved. If you run into any problems, please file an issue. Thanks!
@@ -0,0 +1,80 @@ | |||||
#!/bin/bash | |||||
# Copyright 2021 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
set -e | |||||
PROJECT_HOME=${PROJECT_HOME:-$(dirname "$0")/../../} | |||||
function help(){ | |||||
cat <<-EOF | |||||
Usage: ge test [OPTIONS] | |||||
Options: | |||||
-u, --unit Run unit Test | |||||
-c, --component Run component Test | |||||
-h, --help | |||||
EOF | |||||
} | |||||
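# Both modes delegate to the top-level build.sh: -u builds and runs the unit
# tests, -s builds and runs the component (ST) tests.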
function unit_test(){ | |||||
${PROJECT_HOME}/build.sh -u | |||||
} | |||||
function component_test(){ | |||||
${PROJECT_HOME}/build.sh -s | |||||
} | |||||
function parse_args(){ | |||||
parsed_args=$(getopt -a -o uch --long unit,component,help -- "$@") || { | |||||
help | |||||
exit 1 | |||||
} | |||||
if [ $# -lt 1 ]; then | |||||
unit_test | |||||
        exit 0
fi | |||||
eval set -- "$parsed_args" | |||||
while true; do | |||||
case "$1" in | |||||
-u | --unit) | |||||
unit_test | |||||
;; | |||||
-c | --component) | |||||
component_test | |||||
;; | |||||
-h | --help) | |||||
help | |||||
;; | |||||
--) | |||||
shift; break; | |||||
;; | |||||
*) | |||||
help; exit 1 | |||||
;; | |||||
esac | |||||
shift | |||||
done | |||||
} | |||||
function main(){ | |||||
parse_args "$@" | |||||
} | |||||
main "$@" | |||||
set +e |
@@ -0,0 +1,47 @@ | |||||
#!/bin/bash | |||||
# Copyright 2021 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
set -e | |||||
SERVER_CONFIG_FILE=${PROJECT_HOME}/scripts/config/server_config.sh | |||||
[ -e "$SERVER_CONFIG_FILE" ] || {
    echo "The dependencies server account info has not been configured yet!"
${PROJECT_HOME}/scripts/config/ge_config.sh -h | |||||
exit 1; | |||||
} | |||||
source "${SERVER_CONFIG_FILE}"
CPU_ARCH=ubuntu18.04.x86_64 | |||||
DRIVER_VERSION=20.2.0 | |||||
CHIP_NAME=A800-9010 | |||||
PRODUCT_VERSION=driver_C76_TR5 | |||||
DRIVER_NAME=npu-driver | |||||
DRIVER_RUN_NAME=${CHIP_NAME}-${DRIVER_NAME}_${DRIVER_VERSION}_ubuntu18.04-x86_64.run | |||||
DEV_TOOLS_VERSION=1.78.t10.0.b100 | |||||
export ATC_RUN_NAME=Ascend-atc-${DEV_TOOLS_VERSION}-${CPU_ARCH}.run | |||||
export ACL_RUN_NAME=Ascend-acllib-${DEV_TOOLS_VERSION}-${CPU_ARCH}.run | |||||
export FWKACL_RUN_NAME=Ascend-fwkacllib-${DEV_TOOLS_VERSION}-${CPU_ARCH}.run | |||||
DEV_TOOLS_PACKAGE=x86_ubuntu_os_devtoolset_package | |||||
export DRIVER_URL=${SERVER_PATH}/${PRODUCT_VERSION}/${DRIVER_RUN_NAME} | |||||
export DEV_TOOLS_URL=${SERVER_PATH}/20210428/${DEV_TOOLS_PACKAGE}.zip | |||||
set +e |
@@ -0,0 +1,136 @@ | |||||
#!/bin/bash | |||||
# Copyright 2021 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
set -e | |||||
PROJECT_HOME=${PROJECT_HOME:-$(dirname "$0")/../../} | |||||
PROJECT_HOME=$(cd $PROJECT_HOME || return; pwd) | |||||
DOWNLOAD_PATH=${PROJECT_HOME}/deps | |||||
DEP_LIB_DIR=./lib | |||||
DEP_TMP_DIR=./tmp | |||||
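# Workflow: download_runs fetches the driver .run package and the devtools zip
# into ${DOWNLOAD_PATH}; extract_deps_so unpacks them into ${DEP_TMP_DIR}; and
# copy_so_to_target_dir moves the extracted packages into ${DEP_LIB_DIR}.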
function extract_deps_so() | |||||
{ | |||||
echo "begin to extract .run file ........." | |||||
chmod 777 ./${DRIVER_RUN_NAME} | |||||
unzip ${DEV_TOOLS_PACKAGE}.zip | |||||
chmod -R 777 ${DEV_TOOLS_PACKAGE} | |||||
[ -n "${DEP_TMP_DIR}" ] && rm -rf "${DEP_TMP_DIR}" | |||||
./${DRIVER_RUN_NAME} --noexec --extract=${DEP_TMP_DIR}/driver | |||||
./${DEV_TOOLS_PACKAGE}/${ATC_RUN_NAME} --noexec --extract=${DEP_TMP_DIR}/atc | |||||
./${DEV_TOOLS_PACKAGE}/${ACL_RUN_NAME} --noexec --extract=${DEP_TMP_DIR}/acllib | |||||
./${DEV_TOOLS_PACKAGE}/${FWKACL_RUN_NAME} --noexec --extract=${DEP_TMP_DIR}/fwkacllib | |||||
} | |||||
function copy_so_to_target_dir() | |||||
{ | |||||
mkdir -p $DEP_LIB_DIR | |||||
mv ${DEP_TMP_DIR}/driver/driver $DEP_LIB_DIR/driver | |||||
mv ${DEP_TMP_DIR}/atc/atc $DEP_LIB_DIR/atc | |||||
mv ${DEP_TMP_DIR}/acllib/acllib $DEP_LIB_DIR/acllib | |||||
mv ${DEP_TMP_DIR}/fwkacllib/fwkacllib $DEP_LIB_DIR/fwkacllib | |||||
} | |||||
function clear_libs() | |||||
{ | |||||
[ -n "${DOWNLOAD_PATH}" ] && rm -rf "${DOWNLOAD_PATH}" | |||||
} | |||||
function download_runs() | |||||
{ | |||||
    source ${PROJECT_HOME}/scripts/update/deps_config.sh
echo "begin to download .run file ........." | |||||
clear_libs | |||||
    mkdir -p "${DOWNLOAD_PATH}"
    pushd "${DOWNLOAD_PATH}" >/dev/null
wget --user=${DEP_USER} --password=${DEP_PASSWORD} ${DRIVER_URL} | |||||
wget --user=${DEP_USER} --password=${DEP_PASSWORD} ${DEV_TOOLS_URL} | |||||
popd >/dev/null | |||||
} | |||||
function install_deps() | |||||
{ | |||||
    source ${PROJECT_HOME}/scripts/update/deps_config.sh
    mkdir -p "${DOWNLOAD_PATH}"
    pushd "${DOWNLOAD_PATH}" >/dev/null
extract_deps_so | |||||
copy_so_to_target_dir | |||||
popd >/dev/null | |||||
} | |||||
function help(){ | |||||
cat <<-EOF | |||||
Usage: ge update [OPTIONS] | |||||
update dependencies of build and test | |||||
Options: | |||||
-d, --download Download dependencies | |||||
-i, --install Install dependencies | |||||
-c, --clear Clear dependencies | |||||
-h, --help | |||||
EOF | |||||
} | |||||
function parse_args(){ | |||||
parsed_args=$(getopt -a -o dich --long download,install,clear,help -- "$@") || { | |||||
help | |||||
exit 1 | |||||
} | |||||
if [ $# -lt 1 ]; then | |||||
download_runs | |||||
install_deps | |||||
        exit 0
fi | |||||
eval set -- "$parsed_args" | |||||
while true; do | |||||
case "$1" in | |||||
-d | --download) | |||||
download_runs | |||||
;; | |||||
-i | --install) | |||||
install_deps | |||||
;; | |||||
-c | --clear) | |||||
clear_libs | |||||
;; | |||||
-h | --help) | |||||
                help; exit 0;
;; | |||||
--) | |||||
shift; break; | |||||
;; | |||||
*) | |||||
help; exit 1 | |||||
;; | |||||
esac | |||||
shift | |||||
done | |||||
} | |||||
function main(){ | |||||
parse_args "$@" | |||||
} | |||||
main "$@" | |||||
set +e |
@@ -372,6 +372,7 @@ set(COMMON_FORMAT_SRC_FILES | |||||
set(GRAPH_OPTIMIZE_COMMON_SRC_FILES | set(GRAPH_OPTIMIZE_COMMON_SRC_FILES | ||||
"${GE_CODE_DIR}/ge/graph/optimize/graph_optimize.cc" | "${GE_CODE_DIR}/ge/graph/optimize/graph_optimize.cc" | ||||
"${GE_CODE_DIR}/ge/graph/optimize/summary_optimize.cc" | "${GE_CODE_DIR}/ge/graph/optimize/summary_optimize.cc" | ||||
"${GE_CODE_DIR}/ge/graph/optimize/mem_rw_conflict_optimize.cc" | |||||
) | ) | ||||
@@ -715,7 +716,10 @@ set(PASS_TEST_FILES | |||||
"graph/passes/mark_node_unknown_shape_pass_unittest.cc" | "graph/passes/mark_node_unknown_shape_pass_unittest.cc" | ||||
"graph/passes/reshape_recovery_pass_unittest.cc" | "graph/passes/reshape_recovery_pass_unittest.cc" | ||||
"graph/passes/cast_remove_pass_unittest.cc" | "graph/passes/cast_remove_pass_unittest.cc" | ||||
"graph/passes/memcpy_addr_async_unittest.cc" | |||||
"graph/passes/memcpy_addr_async_unittest.cc" | |||||
"graph/passes/hccl_continuous_pass_unittest.cc" | |||||
"graph/passes/hccl_memcpy_pass_unittest.cc" | |||||
) | ) | ||||
set(KERNEL_TEST_FILES | set(KERNEL_TEST_FILES | ||||
@@ -798,6 +802,8 @@ set(MULTI_PARTS_TEST_FILES | |||||
"graph/manager/run_graph_unittest.cc" | "graph/manager/run_graph_unittest.cc" | ||||
"graph/partition/dynamic_shape_partition_unittest.cc" | "graph/partition/dynamic_shape_partition_unittest.cc" | ||||
"graph/manager/graph_manager_unittest.cc" | "graph/manager/graph_manager_unittest.cc" | ||||
"graph/optimize/mem_rw_conflict_optimize_unittest.cc" | |||||
"graph/optimize/graph_optimize_unittest.cc" | |||||
"session/omg_omg_unittest.cc" | "session/omg_omg_unittest.cc" | ||||
"session/ge_api_unittest.cc" | "session/ge_api_unittest.cc" | ||||
"session/inner_session_unittest.cc" | "session/inner_session_unittest.cc" | ||||
@@ -832,6 +838,7 @@ set(HYBRID_TEST_FILES | |||||
"hybrid/executor/worker/execution_engine_unittest.cc" | "hybrid/executor/worker/execution_engine_unittest.cc" | ||||
"hybrid/model/hybrid_model_builder_unittest.cc" | "hybrid/model/hybrid_model_builder_unittest.cc" | ||||
"hybrid/node_executor/rts/rts_node_task_unittest.cc" | "hybrid/node_executor/rts/rts_node_task_unittest.cc" | ||||
"hybrid/node_executor/ge_local/ge_local_node_executor_unittest.cc" | |||||
"hybrid/executor/hybrid_model_async_executor_unittest.cc" | "hybrid/executor/hybrid_model_async_executor_unittest.cc" | ||||
"hybrid/executor/hybrid_model_pipeline_executor_unittest.cc" | "hybrid/executor/hybrid_model_pipeline_executor_unittest.cc" | ||||
"hybrid/node_executor/aicore/aicore_task_compiler_unittest.cc" | "hybrid/node_executor/aicore/aicore_task_compiler_unittest.cc" | ||||
@@ -0,0 +1,239 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include <gtest/gtest.h> | |||||
#include <memory> | |||||
#include <iostream> | |||||
#define protected public | |||||
#define private public | |||||
#include "graph/optimize/graph_optimize.h" | |||||
#include "init/gelib.h" | |||||
#include "ge/ge_api.h" | |||||
#undef private | |||||
#undef protected | |||||
using namespace std; | |||||
using namespace testing; | |||||
using namespace ge; | |||||
namespace { | |||||
const char *const kVectorCore = "VectorCore"; | |||||
const char *const kAicoreEngine = "AIcoreEngine"; | |||||
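// Writes a minimal engine_conf.json under the plugin path so that GELib::Initialize
// can load an engine configuration while these tests are running.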
string CreateEngineConfigJson() { | |||||
GELOGI("Begin to create engine config json file."); | |||||
string base_path = PluginManager::GetPath(); | |||||
GELOGI("Base path is %s.", base_path.c_str()); | |||||
string dir_path = base_path.substr(0, base_path.rfind('/') + 1) + "plugin/nnengine/ge_config"; | |||||
string cmd = "mkdir -p " + dir_path; | |||||
system(cmd.c_str()); | |||||
string file_path = dir_path + "/engine_conf.json"; | |||||
GELOGI("Begin to write into the config file: %s.", file_path.c_str()); | |||||
ofstream ofs(file_path, ios::out); | |||||
EXPECT_EQ(!ofs, false); | |||||
ofs << "{\n" | |||||
" \"schedule_units\" : [ {\n" | |||||
" \"id\" : \"TS_1\",\n" | |||||
" \"name\" : \"1980_hwts\",\n" | |||||
" \"ex_attrs\" : \"\",\n" | |||||
" \"cal_engines\" : [\n" | |||||
" {\"id\" : \"DNN_VM_GE_LOCAL\", \"name\" : \"GE_LOCAL\", \"independent\" : false, \"attch\" : true, \"skip_assign_stream\" : true },\n" | |||||
" {\"id\" : \"AIcoreEngine\", \"name\" : \"AICORE\", \"independent\" : false, \"attch\" : false, \"skip_assign_stream\" : false}\n" | |||||
" ]\n" | |||||
" } ]\n" | |||||
"}"; | |||||
ofs.close(); | |||||
GELOGI("Json config file %s has been written.", file_path.c_str()); | |||||
return file_path; | |||||
} | |||||
void DeleteFile(const string &file_name) { | |||||
auto ret = remove(file_name.c_str()); | |||||
if (ret == 0) { | |||||
GELOGI("Delete file successfully, file:%s.", file_name.c_str()); | |||||
} | |||||
} | |||||
} | |||||
class UtestGraphOptimizeTest : public testing::Test { | |||||
protected: | |||||
void SetUp() { | |||||
config_file_ = CreateEngineConfigJson(); | |||||
} | |||||
void TearDown() { | |||||
DeleteFile(config_file_); | |||||
} | |||||
private: | |||||
string config_file_; | |||||
}; | |||||
class TestGraphOptimizerSuccess : public GraphOptimizer { | |||||
public: | |||||
~TestGraphOptimizerSuccess() override { Finalize(); } | |||||
Status Initialize(const map<string, string> &options) override { return SUCCESS; } | |||||
Status Finalize() override { return SUCCESS; } | |||||
Status OptimizeGraphPrepare(ComputeGraph& graph) override { return SUCCESS; } | |||||
Status OptimizeGraphBeforeBuild(ComputeGraph& graph) override { return SUCCESS; } | |||||
Status OptimizeOriginalGraph(ComputeGraph &graph) override { return SUCCESS; } | |||||
Status OptimizeOriginalGraphJudgeInsert(ComputeGraph &graph) override { return SUCCESS; } | |||||
Status OptimizeFusedGraph(ComputeGraph &graph) override { return SUCCESS; } | |||||
Status OptimizeWholeGraph(ComputeGraph &graph) override { return SUCCESS; } | |||||
Status GetAttributes(GraphOptimizerAttribute &attrs) const override { | |||||
attrs.engineName = "AIcoreEngine"; | |||||
attrs.scope = OPTIMIZER_SCOPE::ENGINE; | |||||
return SUCCESS; | |||||
} | |||||
Status OptimizeStreamGraph(ComputeGraph &graph, const RunContext &context) override { return SUCCESS; } | |||||
Status OptimizeFusedGraphAfterGraphSlice(ComputeGraph &graph) override { return SUCCESS; } | |||||
Status OptimizeAfterStage1(ComputeGraph &graph) override { return SUCCESS; } | |||||
}; | |||||
class TestGraphOptimizerFail : public GraphOptimizer { | |||||
public: | |||||
~TestGraphOptimizerFail() override { Finalize(); } | |||||
Status Initialize(const map<string, string> &options) override { return SUCCESS; } | |||||
Status Finalize() override { return SUCCESS; } | |||||
Status OptimizeGraphPrepare(ComputeGraph& graph) override { return FAILED; } | |||||
Status OptimizeGraphBeforeBuild(ComputeGraph& graph) override { return FAILED; } | |||||
Status OptimizeOriginalGraph(ComputeGraph &graph) override { return FAILED; } | |||||
Status OptimizeOriginalGraphJudgeInsert(ComputeGraph &graph) override { return FAILED; } | |||||
Status OptimizeFusedGraph(ComputeGraph &graph) override { return FAILED; } | |||||
Status OptimizeWholeGraph(ComputeGraph &graph) override { return FAILED; } | |||||
Status GetAttributes(GraphOptimizerAttribute &attrs) const override { | |||||
attrs.engineName = "AIcoreEngine"; | |||||
attrs.scope = OPTIMIZER_SCOPE::ENGINE; | |||||
return SUCCESS; | |||||
} | |||||
Status OptimizeStreamGraph(ComputeGraph &graph, const RunContext &context) override { return FAILED; } | |||||
Status OptimizeFusedGraphAfterGraphSlice(ComputeGraph &graph) override { return FAILED; } | |||||
Status OptimizeAfterStage1(ComputeGraph &graph) override { return FAILED; } | |||||
}; | |||||
TEST_F(UtestGraphOptimizeTest, test_OptimizeAfterStage1_succ) { | |||||
map<string, string> options; | |||||
Status ret = ge::GELib::Initialize(options); | |||||
EXPECT_EQ(ret, SUCCESS); | |||||
shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||||
EXPECT_NE(instance_ptr, nullptr); | |||||
GraphOptimizerPtr graph_opt = MakeShared<TestGraphOptimizerSuccess>(); | |||||
instance_ptr->opsManager_.graph_optimizers_by_priority_.push_back(make_pair("AIcoreEngine", graph_opt)); | |||||
ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("test_graph"); | |||||
GraphOptimize base_optimize; | |||||
ret = base_optimize.OptimizeAfterStage1(compute_graph); | |||||
EXPECT_EQ(ret, SUCCESS); | |||||
base_optimize.core_type_ = kVectorCore; | |||||
ret = base_optimize.OptimizeAfterStage1(compute_graph); | |||||
EXPECT_EQ(ret, SUCCESS); | |||||
ret = instance_ptr->Finalize(); | |||||
EXPECT_EQ(ret, SUCCESS); | |||||
} | |||||
TEST_F(UtestGraphOptimizeTest, test_OptimizeAfterStage1_fail) { | |||||
ComputeGraphPtr compute_graph = nullptr; | |||||
GraphOptimize base_optimize; | |||||
// 1. Input graph is nullptr. | |||||
Status ret = base_optimize.OptimizeAfterStage1(compute_graph); | |||||
EXPECT_EQ(ret, PARAM_INVALID); | |||||
// 2. GELib is not initialized. | |||||
compute_graph = MakeShared<ComputeGraph>("test_graph"); | |||||
ret = base_optimize.OptimizeAfterStage1(compute_graph); | |||||
EXPECT_EQ(ret, GE_CLI_GE_NOT_INITIALIZED); | |||||
// 3. The optimizer registered with the engine returned a failure. | |||||
map<string, string> options; | |||||
ret = ge::GELib::Initialize(options); | |||||
EXPECT_EQ(ret, SUCCESS); | |||||
shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||||
EXPECT_NE(instance_ptr, nullptr); | |||||
GraphOptimizerPtr graph_opt = MakeShared<TestGraphOptimizerFail>(); | |||||
instance_ptr->opsManager_.graph_optimizers_by_priority_.push_back(make_pair("AIcoreEngine", graph_opt)); | |||||
ret = base_optimize.OptimizeAfterStage1(compute_graph); | |||||
EXPECT_EQ(ret, FAILED); | |||||
ret = instance_ptr->Finalize(); | |||||
EXPECT_EQ(ret, SUCCESS); | |||||
} | |||||
TEST_F(UtestGraphOptimizeTest, test_optimizers_succ) { | |||||
map<string, string> options; | |||||
Status ret = ge::GELib::Initialize(options); | |||||
EXPECT_EQ(ret, SUCCESS); | |||||
shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||||
EXPECT_NE(instance_ptr, nullptr); | |||||
GraphOptimizerPtr graph_opt = MakeShared<TestGraphOptimizerSuccess>(); | |||||
instance_ptr->opsManager_.graph_optimizers_by_priority_.push_back(make_pair("AIcoreEngine", graph_opt)); | |||||
ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("test_graph"); | |||||
GraphOptimize base_optimize; | |||||
ret = base_optimize.OptimizeOriginalGraph(compute_graph); | |||||
EXPECT_EQ(ret, SUCCESS); | |||||
ret = base_optimize.OptimizeOriginalGraphJudgeInsert(compute_graph); | |||||
EXPECT_EQ(ret, SUCCESS); | |||||
ret = base_optimize.OptimizeOriginalGraphForQuantize(compute_graph); | |||||
EXPECT_EQ(ret, SUCCESS); | |||||
ret = base_optimize.OptimizeGraphBeforeBuildForRts(compute_graph); | |||||
EXPECT_EQ(ret, SUCCESS); | |||||
ret = base_optimize.OptimizeWholeGraph(compute_graph); | |||||
EXPECT_EQ(ret, SUCCESS); | |||||
ret = instance_ptr->Finalize(); | |||||
EXPECT_EQ(ret, SUCCESS); | |||||
} | |||||
TEST_F(UtestGraphOptimizeTest, test_optimizers_fail) { | |||||
map<string, string> options; | |||||
Status ret = ge::GELib::Initialize(options); | |||||
EXPECT_EQ(ret, SUCCESS); | |||||
shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||||
EXPECT_NE(instance_ptr, nullptr); | |||||
GraphOptimizerPtr graph_opt = MakeShared<TestGraphOptimizerFail>(); | |||||
instance_ptr->opsManager_.graph_optimizers_by_priority_.push_back(make_pair("AIcoreEngine", graph_opt)); | |||||
ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("test_graph"); | |||||
GraphOptimize base_optimize; | |||||
ret = base_optimize.OptimizeOriginalGraph(compute_graph); | |||||
EXPECT_EQ(ret, FAILED); | |||||
ret = base_optimize.OptimizeOriginalGraphJudgeInsert(compute_graph); | |||||
EXPECT_EQ(ret, FAILED); | |||||
ret = base_optimize.OptimizeOriginalGraphForQuantize(compute_graph); | |||||
EXPECT_EQ(ret, FAILED); | |||||
ret = base_optimize.OptimizeGraphBeforeBuildForRts(compute_graph); | |||||
EXPECT_EQ(ret, FAILED); | |||||
ret = base_optimize.OptimizeWholeGraph(compute_graph); | |||||
EXPECT_EQ(ret, FAILED); | |||||
ret = instance_ptr->Finalize(); | |||||
EXPECT_EQ(ret, SUCCESS); | |||||
} |
@@ -0,0 +1,150 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include <cstdint> | |||||
#include <string> | |||||
#include <gtest/gtest.h> | |||||
#define protected public | |||||
#define private public | |||||
#include "graph/optimize/graph_optimize.h" | |||||
#undef protected | |||||
#undef private | |||||
#include "../passes/graph_builder_utils.h" | |||||
#include "graph/debug/ge_attr_define.h" | |||||
namespace ge { | |||||
class UtestGraphMemRWConflictOptimize : public testing::Test {
protected: | |||||
void SetUp() {} | |||||
void TearDown() {} | |||||
}; | |||||
namespace { | |||||
/* | |||||
* Data -cast - netoutput | |||||
*/ | |||||
ComputeGraphPtr BuildGraph_Readonly_Subgraph(const string &subgraph_name) {
  auto sub_builder = ut::GraphBuilder(subgraph_name);
auto data1 = sub_builder.AddNode("data1", DATA, 0,1); | |||||
auto cast = sub_builder.AddNode("cast", CAST, 1,1); | |||||
auto netoutput = sub_builder.AddNode("netoutput",NETOUTPUT, 1,1); | |||||
AttrUtils::SetInt(data1->GetOpDesc(),ATTR_NAME_PARENT_NODE_INDEX, 1); | |||||
AttrUtils::SetInt(netoutput->GetOpDesc(),ATTR_NAME_PARENT_NODE_INDEX,0); | |||||
sub_builder.AddDataEdge(data1,0,cast,0); | |||||
sub_builder.AddDataEdge(cast,0,netoutput,0); | |||||
return sub_builder.GetGraph(); | |||||
} | |||||
/* | |||||
* const - allreduce | |||||
* \ if | |||||
* insert identity | |||||
*/ | |||||
ComputeGraphPtr BuildGraph_Readonly_ScopeWrite() { | |||||
auto builder = ut::GraphBuilder("test"); | |||||
auto const1 = builder.AddNode("const1", CONSTANT, 0, 1); | |||||
auto ctrl_const = builder.AddNode("ctrl_const", CONSTANT, 0, 1); | |||||
auto allreduce = builder.AddNode("allreduce", HCOMALLREDUCE, 1, 1); | |||||
auto if_node = builder.AddNode("if", IF, 1,0); | |||||
builder.AddDataEdge(const1, 0, allreduce, 0); | |||||
builder.AddDataEdge(const1, 0, if_node, 0); | |||||
builder.AddControlEdge(ctrl_const, allreduce); | |||||
auto root_graph = builder.GetGraph(); | |||||
string subgraph_name = "then_branch"; | |||||
ComputeGraphPtr then_branch_graph = BuildGraph_Readonly_Subgraph(subgraph_name); | |||||
then_branch_graph->SetParentNode(if_node); | |||||
then_branch_graph->SetParentGraph(root_graph); | |||||
if_node->GetOpDesc()->AddSubgraphName(subgraph_name); | |||||
if_node->GetOpDesc()->SetSubgraphInstanceName(0,subgraph_name); | |||||
root_graph->AddSubgraph(subgraph_name, then_branch_graph); | |||||
return root_graph; | |||||
} | |||||
/* const1---allreduce const1--identity - allreduce | |||||
* / / | |||||
* var-identity--cast1 ==> var-----cast1 | |||||
* \ \ | |||||
* if if | |||||
*/ | |||||
ComputeGraphPtr BuildGraph_Identity_Split() {
auto builder = ut::GraphBuilder("g1"); | |||||
auto var = builder.AddNode("var", VARIABLE, 0, 1); | |||||
auto identity = builder.AddNode("identity", IDENTITY, 1, 1); | |||||
auto const1 = builder.AddNode("const1", CONSTANT, 0, 1); | |||||
auto allreduce = builder.AddNode("allreduce", HCOMALLREDUCE, 1, 1); | |||||
auto cast1 = builder.AddNode("cast1", CAST, 1, 1); | |||||
auto if_node = builder.AddNode("if", IF, 1,0); | |||||
builder.AddDataEdge(var, 0 , identity, 0); | |||||
builder.AddDataEdge(identity, 0 , allreduce, 0); | |||||
builder.AddDataEdge(identity, 0 , cast1, 0); | |||||
builder.AddDataEdge(identity, 0 , if_node, 0); | |||||
builder.AddControlEdge(const1, allreduce); | |||||
auto root_graph = builder.GetGraph(); | |||||
string subgraph_name = "then_branch"; | |||||
ComputeGraphPtr then_branch_graph = BuildGraph_Readonly_Subgraph(subgraph_name); | |||||
then_branch_graph->SetParentNode(if_node); | |||||
then_branch_graph->SetParentGraph(root_graph); | |||||
if_node->GetOpDesc()->AddSubgraphName(subgraph_name); | |||||
if_node->GetOpDesc()->SetSubgraphInstanceName(0,subgraph_name); | |||||
root_graph->AddSubgraph(subgraph_name, then_branch_graph); | |||||
return root_graph; | |||||
} | |||||
/* | |||||
* mul == allreduce | |||||
* need insert identity | |||||
*/ | |||||
ComputeGraphPtr BuildGraph_mul_1To2_ScopeWrite() { | |||||
auto builder = ut::GraphBuilder("test"); | |||||
auto mul = builder.AddNode("mul", MUL, 2,1); | |||||
auto allreduce = builder.AddNode("allreduce", HCOMALLREDUCE, 2,0); | |||||
AttrUtils::SetBool(allreduce->GetOpDesc(), "_input_mutable", true); | |||||
builder.AddDataEdge(mul,0,allreduce,0); | |||||
builder.AddDataEdge(mul,0,allreduce,1); | |||||
return builder.GetGraph(); | |||||
} | |||||
} // namespace | |||||
// const -> allreduce | |||||
// const -> Identity -> allreduce | |||||
TEST(UtestGraphMemRWConflictOptimize, testReadonlyScopeWriteConflict) {
ComputeGraphPtr graph = BuildGraph_Readonly_ScopeWrite(); | |||||
GraphOptimize graph_optimizer; | |||||
auto ret = graph_optimizer.HandleMemoryRWConflict(graph); | |||||
EXPECT_EQ(ret, SUCCESS); | |||||
auto allreduce = graph->FindNode("allreduce"); | |||||
EXPECT_EQ(allreduce->GetInDataNodes().at(0)->GetType(), IDENTITY); | |||||
} | |||||
TEST(UtestGraphMemRWConflictOptimize, testIdentitySplit) {
  ComputeGraphPtr graph = BuildGraph_Identity_Split();
GraphOptimize graph_optimizer; | |||||
auto ret = graph_optimizer.HandleMemoryRWConflict(graph); | |||||
EXPECT_EQ(ret, SUCCESS); | |||||
auto allreduce = graph->FindNode("allreduce"); | |||||
auto allreduce_in_node = allreduce->GetInDataNodes().at(0); | |||||
EXPECT_EQ(allreduce_in_node->GetType(), IDENTITY); | |||||
EXPECT_EQ(allreduce_in_node->GetInControlNodes().at(0)->GetType(), CONSTANT); | |||||
} | |||||
TEST(UtestGraphMemRWConflictOptimize, testMul_1To2_ScopeWrite) {
ComputeGraphPtr graph = BuildGraph_mul_1To2_ScopeWrite(); | |||||
EXPECT_EQ(graph->GetDirectNodesSize(), 2); | |||||
GraphOptimize graph_optimizer; | |||||
auto ret = graph_optimizer.HandleMemoryRWConflict(graph); | |||||
EXPECT_EQ(ret, SUCCESS); | |||||
EXPECT_EQ(graph->GetDirectNodesSize(), 3); | |||||
} | |||||
} // namespace ge |
@@ -0,0 +1,79 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include <cstdint> | |||||
#include <string> | |||||
#include <gtest/gtest.h> | |||||
#include "common/ge_inner_error_codes.h" | |||||
#define protected public | |||||
#define private public | |||||
#include "graph/passes/hccl_continuous_memcpy_pass.h" | |||||
#undef protected | |||||
#undef private | |||||
#include "graph_builder_utils.h" | |||||
namespace ge { | |||||
class UtestGraphPassesHcclContinuousMemcpyPass : public testing::Test { | |||||
protected: | |||||
void SetUp() {} | |||||
void TearDown() {} | |||||
}; | |||||
namespace { | |||||
/* | |||||
* var var | |||||
* | \ | \ | |||||
* | assign | assign | |||||
* | // =======> | // | |||||
* allreduce identity | |||||
* | | | |||||
* netoutput allreduce | |||||
* | | |||||
* netoutput | |||||
*/ | |||||
ComputeGraphPtr BuildGraph_Allreduce_Read_Var_After_Assign(){ | |||||
auto builder = ut::GraphBuilder("test"); | |||||
auto var = builder.AddNode("var", VARIABLE, 0, 1); | |||||
auto assign = builder.AddNode("assign", ASSIGN, 1, 1); | |||||
auto allreduce = builder.AddNode("allreduce", HCOMALLREDUCE, 1, 1); | |||||
auto netoutput1 = builder.AddNode("netoutput", NETOUTPUT, 1, 0); | |||||
builder.AddDataEdge(var, 0, assign, 0); | |||||
builder.AddDataEdge(var,0,allreduce,0); | |||||
builder.AddControlEdge(assign, allreduce); | |||||
return builder.GetGraph(); | |||||
} | |||||
} // namespace | |||||
// var -> allreduce
// var -> Identity -> allreduce
TEST(UtestGraphPassesHcclContinuousMemcpyPass, testInsertIdentityBeforeHccl) { | |||||
ComputeGraphPtr graph = BuildGraph_Allreduce_Read_Var_After_Assign(); | |||||
auto src_node = graph->FindNode("var"); | |||||
auto dst_node = graph->FindNode("allreduce"); | |||||
// test InsertIdentityBeforeHccl | |||||
HcclContinuousMemcpyPass hccl_continuous_memcpy_pass; | |||||
hccl_continuous_memcpy_pass.InsertIdentityBeforeHccl(graph, src_node->GetOutDataAnchor(0), dst_node->GetInDataAnchor(0)); | |||||
// check | |||||
dst_node = graph->FindNode("allreduce"); | |||||
auto in_node_before_dst_node = dst_node->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); | |||||
EXPECT_EQ(in_node_before_dst_node->GetType(), IDENTITY); | |||||
EXPECT_EQ(in_node_before_dst_node->GetInControlNodes().size(), 1); | |||||
EXPECT_EQ(in_node_before_dst_node->GetInControlNodes().at(0)->GetName(), "assign"); | |||||
} | |||||
} // namespace ge |
@@ -0,0 +1,80 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include <cstdint> | |||||
#include <string> | |||||
#include <gtest/gtest.h> | |||||
#include "common/ge_inner_error_codes.h" | |||||
#define protected public | |||||
#define private public | |||||
#include "graph/passes/hccl_memcpy_pass.h" | |||||
#undef protected | |||||
#undef private | |||||
#include "graph_builder_utils.h" | |||||
namespace ge { | |||||
class UtestGraphPassesHcclMemcpyPass : public testing::Test { | |||||
protected: | |||||
void SetUp() {} | |||||
void TearDown() {} | |||||
}; | |||||
namespace { | |||||
/* | |||||
* var var | |||||
* | \ | \ | |||||
* | assign | assign | |||||
* | // =======> | // | |||||
* allreduce identity | |||||
* | | | |||||
* netoutput allreduce | |||||
* | | |||||
* netoutput | |||||
*/ | |||||
ComputeGraphPtr BuildGraph_Allreduce_Read_Var_After_Assign(){ | |||||
auto builder = ut::GraphBuilder("test"); | |||||
auto var = builder.AddNode("var", VARIABLE, 0, 1); | |||||
auto assign = builder.AddNode("assign", ASSIGN, 1, 1); | |||||
auto allreduce = builder.AddNode("allreduce", HCOMALLREDUCE, 1, 1); | |||||
auto netoutput1 = builder.AddNode("netoutput", NETOUTPUT, 1, 0); | |||||
builder.AddDataEdge(var, 0, assign, 0); | |||||
builder.AddDataEdge(var,0,allreduce,0); | |||||
builder.AddControlEdge(assign, allreduce); | |||||
return builder.GetGraph(); | |||||
} | |||||
} // namespace | |||||
// var -> allreduce
// var -> Identity -> allreduce
TEST(UtestGraphPassesHcclMemcpyPass, testInsertIdentityBeforeHccl) { | |||||
ComputeGraphPtr graph = BuildGraph_Allreduce_Read_Var_After_Assign(); | |||||
auto src_node = graph->FindNode("var"); | |||||
auto dst_node = graph->FindNode("allreduce"); | |||||
// test InsertIdentityBeforeHccl | |||||
HcclMemcpyPass hccl_memcpy_pass; | |||||
hccl_memcpy_pass.InsertIdentityBeforeHccl(graph, src_node->GetOutDataAnchor(0), | |||||
dst_node->GetInDataAnchor(0)); | |||||
// check | |||||
dst_node = graph->FindNode("allreduce"); | |||||
auto in_node_before_dst_node = dst_node->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); | |||||
EXPECT_EQ(in_node_before_dst_node->GetType(), IDENTITY); | |||||
EXPECT_EQ(in_node_before_dst_node->GetInControlNodes().size(), 1); | |||||
EXPECT_EQ(in_node_before_dst_node->GetInControlNodes().at(0)->GetName(), "assign"); | |||||
} | |||||
} // namespace ge |
@@ -86,7 +86,7 @@ static void CreateSimpleCondGraph(ComputeGraph &graph, NodePtr &switch_t, NodePt | |||||
* | | * | | ||||
* Merge | * Merge | ||||
* / \. | * / \. | ||||
* / \. | |||||
* Active / \ Active | |||||
* / \. | * / \. | ||||
* Add Sub | * Add Sub | ||||
* | \ / | | * | \ / | | ||||
@@ -96,8 +96,8 @@ static void CreateSimpleCondGraph(ComputeGraph &graph, NodePtr &switch_t, NodePt | |||||
* Switch Switch | * Switch Switch | ||||
* | \ / | | * | \ / | | ||||
* | \ / | | * | \ / | | ||||
* | \ / | | |||||
* | \ / | | |||||
* | Active | | |||||
* | \ / | | |||||
* | Less | | * | Less | | ||||
* | / \ | | * | / \ | | ||||
* | / \ | | * | / \ | | ||||
@@ -127,7 +127,7 @@ static void CreateSimpleCondGraph(ComputeGraph &graph, NodePtr &switch_t, NodePt | |||||
AttrUtils::SetTensor(op_desc, ATTR_NAME_WEIGHTS, weight); | AttrUtils::SetTensor(op_desc, ATTR_NAME_WEIGHTS, weight); | ||||
} | } | ||||
const auto less1 = CreateNode(graph, "less", ENTER, 2, 1); | |||||
const auto less1 = CreateNode(graph, "less", EXIT, 2, 1); // Mock for less, just pass input0. | |||||
const auto active1 = CreateNode(graph, "active1", STREAMACTIVE, 0, 0); | const auto active1 = CreateNode(graph, "active1", STREAMACTIVE, 0, 0); | ||||
switch_t = CreateNode(graph, "switch_t", STREAMSWITCH, 2, 0); | switch_t = CreateNode(graph, "switch_t", STREAMSWITCH, 2, 0); | ||||
@@ -135,13 +135,14 @@ static void CreateSimpleCondGraph(ComputeGraph &graph, NodePtr &switch_t, NodePt | |||||
AttrUtils::SetInt(switch_t->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, RT_EQUAL); // 101 for true. | AttrUtils::SetInt(switch_t->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, RT_EQUAL); // 101 for true. | ||||
AttrUtils::SetInt(switch_f->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, RT_NOT_EQUAL); | AttrUtils::SetInt(switch_f->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, RT_NOT_EQUAL); | ||||
const auto add1 = CreateNode(graph, "add", ENTER, 2, 1); | |||||
const auto sub1 = CreateNode(graph, "sub", ENTER, 2, 1); | |||||
const auto add1 = CreateNode(graph, "add", EXIT, 2, 1); // Mock for add, just pass input0. | |||||
const auto sub1 = CreateNode(graph, "sub", EXIT, 2, 1); // Mock for sub, just pass input0. | |||||
const auto merge1 = CreateNode(graph, "merge", STREAMMERGE, 2, 2); | const auto merge1 = CreateNode(graph, "merge", STREAMMERGE, 2, 2); | ||||
const auto active2 = CreateNode(graph, "active2", STREAMACTIVE, 0, 0); | const auto active2 = CreateNode(graph, "active2", STREAMACTIVE, 0, 0); | ||||
const auto active3 = CreateNode(graph, "active3", STREAMACTIVE, 0, 0); | const auto active3 = CreateNode(graph, "active3", STREAMACTIVE, 0, 0); | ||||
const auto iteration1 = CreateNode(graph, "iteration1", NEXTITERATION, 1, 1); | |||||
const auto output1 = CreateNode(graph, "net_output", NETOUTPUT, 1, 1); | const auto output1 = CreateNode(graph, "net_output", NETOUTPUT, 1, 1); | ||||
output1->GetOpDesc()->SetOpKernelLibName("DNN_VM_GE_LOCAL_OP_STORE"); | output1->GetOpDesc()->SetOpKernelLibName("DNN_VM_GE_LOCAL_OP_STORE"); | ||||
@@ -170,7 +171,8 @@ static void CreateSimpleCondGraph(ComputeGraph &graph, NodePtr &switch_t, NodePt | |||||
GraphUtils::AddEdge(sub1->GetOutControlAnchor(), active3->GetInControlAnchor()); | GraphUtils::AddEdge(sub1->GetOutControlAnchor(), active3->GetInControlAnchor()); | ||||
GraphUtils::AddEdge(active3->GetOutControlAnchor(), merge1->GetInControlAnchor()); | GraphUtils::AddEdge(active3->GetOutControlAnchor(), merge1->GetInControlAnchor()); | ||||
GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), output1->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), iteration1->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(iteration1->GetOutDataAnchor(0), output1->GetInDataAnchor(0)); | |||||
} | } | ||||
TEST_F(UtestSubgraphExecutor, simple_schedule_tasks) { | TEST_F(UtestSubgraphExecutor, simple_schedule_tasks) { | ||||
@@ -28,6 +28,7 @@ | |||||
#include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
#include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
#include "graph/ge_local_context.h" | #include "graph/ge_local_context.h" | ||||
#include "graph/common/omg_util.h" | |||||
using namespace std; | using namespace std; | ||||
using namespace testing; | using namespace testing; | ||||
@@ -157,7 +158,7 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) { | |||||
GraphUtils::AddEdge(next1->GetOutControlAnchor(), active3->GetInControlAnchor()); | GraphUtils::AddEdge(next1->GetOutControlAnchor(), active3->GetInControlAnchor()); | ||||
GraphUtils::AddEdge(exit1->GetOutDataAnchor(0), output1->GetInDataAnchor(0)); | GraphUtils::AddEdge(exit1->GetOutDataAnchor(0), output1->GetInDataAnchor(0)); | ||||
AttrUtils::SetStr(merge1->GetOpDesc(), ATTR_NAME_NEXT_ITERATION, next1->GetName()); | |||||
SetNextIteration(merge1, next1); | |||||
AttrUtils::SetBool(enter1->GetOpDesc(), ATTR_NAME_INSERT_FP_PROFILILNG_TASK, true); | AttrUtils::SetBool(enter1->GetOpDesc(), ATTR_NAME_INSERT_FP_PROFILILNG_TASK, true); | ||||
AttrUtils::SetBool(output1->GetOpDesc(), ATTR_NAME_INSERT_BP_PROFILILNG_TASK, true); | AttrUtils::SetBool(output1->GetOpDesc(), ATTR_NAME_INSERT_BP_PROFILILNG_TASK, true); | ||||
@@ -0,0 +1,114 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include <gtest/gtest.h> | |||||
#include <gmock/gmock.h> | |||||
#include <vector> | |||||
#define private public | |||||
#define protected public | |||||
#include "hybrid/executor/subgraph_context.h" | |||||
#include "hybrid/node_executor/ge_local/ge_local_node_executor.h" | |||||
#include "model/ge_root_model.h" | |||||
#undef protected | |||||
#undef private | |||||
using namespace std; | |||||
using namespace testing; | |||||
namespace ge { | |||||
using namespace hybrid; | |||||
class UtestGeLocalNodeExecutor : public testing::Test { | |||||
protected: | |||||
void SetUp() {} | |||||
void TearDown() { } | |||||
}; | |||||
static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num) { | |||||
OpDescPtr op_desc = std::make_shared<OpDesc>(name, type); | |||||
op_desc->SetStreamId(0); | |||||
static int32_t index = 0; | |||||
op_desc->SetId(index++); | |||||
GeTensorDesc tensor(GeShape(), FORMAT_ND, DT_INT64); | |||||
TensorUtils::SetSize(tensor, 64); | |||||
vector<int64_t> input_offset; | |||||
for (int i = 0; i < in_num; i++) { | |||||
op_desc->AddInputDesc(tensor); | |||||
input_offset.emplace_back(i * 64); | |||||
} | |||||
op_desc->SetInputOffset(input_offset); | |||||
vector<int64_t> output_offset; | |||||
for (int i = 0; i < out_num; i++) { | |||||
op_desc->AddOutputDesc(tensor); | |||||
output_offset.emplace_back(in_num * 64 + i * 64); | |||||
} | |||||
op_desc->SetOutputOffset(output_offset); | |||||
op_desc->SetWorkspace({}); | |||||
op_desc->SetWorkspaceBytes({}); | |||||
op_desc->SetOpKernelLibName("DNN_VM_RTS_OP_STORE"); | |||||
return graph.AddNode(op_desc); | |||||
} | |||||
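// Builds a minimal hybrid model containing a single NOOP node and checks that the
// GE local node executor can load its (empty) task, update its args and execute it.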
TEST_F(UtestGeLocalNodeExecutor, test_no_op_task) { | |||||
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||||
GeModelPtr ge_sub_model = std::make_shared<GeModel>(); | |||||
GeRootModelPtr ge_root_model = std::make_shared<GeRootModel>(graph); | |||||
ge_root_model->SetModelName("test_name"); | |||||
ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model); | |||||
HybridModel hybrid_model(ge_root_model); | |||||
NodePtr node = CreateNode(*graph, "noop", NOOP, 0, 0); | |||||
std::unique_ptr<NodeItem> new_node; | |||||
ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS); | |||||
NodeItem *node_item = new_node.get(); | |||||
hybrid_model.node_items_[node] = std::move(new_node); | |||||
node_item->input_start = 0; | |||||
node_item->output_start = 0; | |||||
GraphItem graph_item; | |||||
graph_item.node_items_.emplace_back(node_item); | |||||
graph_item.total_inputs_ = 0; | |||||
graph_item.total_outputs_ = 0; | |||||
GraphExecutionContext graph_context; | |||||
SubgraphContext subgraph_context(&graph_item, &graph_context); | |||||
ASSERT_EQ(subgraph_context.Init(), SUCCESS); | |||||
graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager()); | |||||
auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | |||||
ASSERT_NE(node_state, nullptr); | |||||
auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||||
ASSERT_NE(unique_task_context, nullptr); | |||||
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||||
node_state->SetTaskContext(shared_task_context); | |||||
NodeTaskPtr task = nullptr; | |||||
GeLocalNodeExecutor node_executor; | |||||
ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS); | |||||
ASSERT_NE(task, nullptr); | |||||
ASSERT_EQ(task->UpdateArgs(*node_state->GetTaskContext()), SUCCESS); | |||||
std::function<void()> done = []() {}; | |||||
ASSERT_EQ(task->ExecuteAsync(*node_state->GetTaskContext(), done), SUCCESS); | |||||
} | |||||
} // namespace ge |