| @@ -35,7 +35,7 @@ const char *const kAICPUKernelLibName = "aicpu_tf_kernel"; | |||||
| namespace ge { | namespace ge { | ||||
| graphStatus CompileNodesPass::Run(ComputeGraphPtr graph) { | graphStatus CompileNodesPass::Run(ComputeGraphPtr graph) { | ||||
| GE_TIMESTAMP_START(CompileNodesPass); | GE_TIMESTAMP_START(CompileNodesPass); | ||||
| GELOGI("[CompileNodesPass]: optimize begin."); | |||||
| GELOGD("[CompileNodesPass]: optimize begin."); | |||||
| if (graph == nullptr) { | if (graph == nullptr) { | ||||
| return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
| } | } | ||||
| @@ -81,7 +81,7 @@ graphStatus CompileNodesPass::Run(ComputeGraphPtr graph) { | |||||
| GELOGE(result, "Compile op failed."); | GELOGE(result, "Compile op failed."); | ||||
| return result; | return result; | ||||
| } | } | ||||
| GELOGI("[CompileNodesPass]: Optimize success."); | |||||
| GELOGD("[CompileNodesPass]: Optimize success."); | |||||
| GE_TIMESTAMP_EVENT_END(CompileNodesPass, "OptimizeStage2::ControlAttrOptimize::CompileNodesPass"); | GE_TIMESTAMP_EVENT_END(CompileNodesPass, "OptimizeStage2::ControlAttrOptimize::CompileNodesPass"); | ||||
| return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
| } | } | ||||
| @@ -111,20 +111,28 @@ graphStatus CompileNodesPass::GetSupportedKernel(const NodePtr &node, const std: | |||||
| } | } | ||||
| // begin accuracy supported check | // begin accuracy supported check | ||||
| if (!CheckAccuracySupport(kernel_info, instance, op_desc)) { | if (!CheckAccuracySupport(kernel_info, instance, op_desc)) { | ||||
| // if check accuracy support failed , try to go to aicpu engine | |||||
| string aicpu_kernel_lib_name = kAICPUKernelLibName; | |||||
| OpsKernelInfoStorePtr aicpu_kernel_info = | |||||
| instance->OpsKernelManagerObj().GetOpsKernelInfoStore(aicpu_kernel_lib_name); | |||||
| if (aicpu_kernel_info == nullptr) { | |||||
| GELOGE(ge::GE_GRAPH_PARAM_NULLPTR, "Get aicpu kernel info store failed."); | |||||
| return ge::GE_GRAPH_PARAM_NULLPTR; | |||||
| } | |||||
| if (!CheckAccuracySupport(aicpu_kernel_info, instance, op_desc)) { | |||||
| GELOGE(GRAPH_FAILED, "AICPU engine does not support node:%s, type:%s , get kernel lib failed.", | |||||
| node->GetName().c_str(), op_desc->GetType().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| // if check accuracy support failed , try to go to other engine. | |||||
| GELOGW("Check Accuracy Supported return not support, node name is %s. Try to go to other engine.", | |||||
| op_desc->GetName().c_str()); | |||||
| string kernel_name_origin = kernel_lib_name; | |||||
| OpsKernelManager &ops_kernel_manager = instance->OpsKernelManagerObj(); | |||||
| auto kernel_map = ops_kernel_manager.GetAllOpsKernelInfoStores(); | |||||
| for (auto it = kernel_map.begin(); it != kernel_map.end(); ++it) { | |||||
| string tmp_kernel_name = it->first; | |||||
| if (tmp_kernel_name == kernel_name_origin) { | |||||
| continue; | |||||
| } | |||||
| OpsKernelInfoStorePtr tmp_kernel_info = it->second; | |||||
| if (CheckAccuracySupport(tmp_kernel_info, instance, op_desc)) { | |||||
| kernel_lib_name = tmp_kernel_name; | |||||
| GELOGD("Find kernel lib %s support node:%s, type:%s , get kernel lib success.", tmp_kernel_name.c_str(), | |||||
| node->GetName().c_str(), op_desc->GetType().c_str()); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| } | } | ||||
| kernel_lib_name = kAICPUKernelLibName; | |||||
| GELOGE(GRAPH_FAILED, "Cannot find kernel lib support node:%s, type:%s , get kernel lib failed.", | |||||
| node->GetName().c_str(), op_desc->GetType().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | } | ||||
| return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
| } | } | ||||
| @@ -138,8 +146,6 @@ bool CompileNodesPass::CheckAccuracySupport(const OpsKernelInfoStorePtr &kernel_ | |||||
| } | } | ||||
| string reason; | string reason; | ||||
| if (!(kernel_info->CheckAccuracySupported(*ge_desc, reason, true))) { | if (!(kernel_info->CheckAccuracySupported(*ge_desc, reason, true))) { | ||||
| GELOGW("Check Accuracy Supported return not support, node name is %s, reason: %s. Try to go to AICPU engine.", | |||||
| op_desc->GetName().c_str(), reason.c_str()); | |||||
| return false; | return false; | ||||
| } | } | ||||
| return true; | return true; | ||||