|
@@ -35,7 +35,7 @@ const char *const kAICPUKernelLibName = "aicpu_tf_kernel"; |
|
|
namespace ge { |
|
|
namespace ge { |
|
|
graphStatus CompileNodesPass::Run(ComputeGraphPtr graph) { |
|
|
graphStatus CompileNodesPass::Run(ComputeGraphPtr graph) { |
|
|
GE_TIMESTAMP_START(CompileNodesPass); |
|
|
GE_TIMESTAMP_START(CompileNodesPass); |
|
|
GELOGI("[CompileNodesPass]: optimize begin."); |
|
|
|
|
|
|
|
|
GELOGD("[CompileNodesPass]: optimize begin."); |
|
|
if (graph == nullptr) { |
|
|
if (graph == nullptr) { |
|
|
return GRAPH_SUCCESS; |
|
|
return GRAPH_SUCCESS; |
|
|
} |
|
|
} |
|
@@ -81,7 +81,7 @@ graphStatus CompileNodesPass::Run(ComputeGraphPtr graph) { |
|
|
GELOGE(result, "Compile op failed."); |
|
|
GELOGE(result, "Compile op failed."); |
|
|
return result; |
|
|
return result; |
|
|
} |
|
|
} |
|
|
GELOGI("[CompileNodesPass]: Optimize success."); |
|
|
|
|
|
|
|
|
GELOGD("[CompileNodesPass]: Optimize success."); |
|
|
GE_TIMESTAMP_EVENT_END(CompileNodesPass, "OptimizeStage2::ControlAttrOptimize::CompileNodesPass"); |
|
|
GE_TIMESTAMP_EVENT_END(CompileNodesPass, "OptimizeStage2::ControlAttrOptimize::CompileNodesPass"); |
|
|
return GRAPH_SUCCESS; |
|
|
return GRAPH_SUCCESS; |
|
|
} |
|
|
} |
|
@@ -111,20 +111,28 @@ graphStatus CompileNodesPass::GetSupportedKernel(const NodePtr &node, const std: |
|
|
} |
|
|
} |
|
|
// begin accuracy supported check |
|
|
// begin accuracy supported check |
|
|
if (!CheckAccuracySupport(kernel_info, instance, op_desc)) { |
|
|
if (!CheckAccuracySupport(kernel_info, instance, op_desc)) { |
|
|
// if check accuracy support failed , try to go to aicpu engine |
|
|
|
|
|
string aicpu_kernel_lib_name = kAICPUKernelLibName; |
|
|
|
|
|
OpsKernelInfoStorePtr aicpu_kernel_info = |
|
|
|
|
|
instance->OpsKernelManagerObj().GetOpsKernelInfoStore(aicpu_kernel_lib_name); |
|
|
|
|
|
if (aicpu_kernel_info == nullptr) { |
|
|
|
|
|
GELOGE(ge::GE_GRAPH_PARAM_NULLPTR, "Get aicpu kernel info store failed."); |
|
|
|
|
|
return ge::GE_GRAPH_PARAM_NULLPTR; |
|
|
|
|
|
} |
|
|
|
|
|
if (!CheckAccuracySupport(aicpu_kernel_info, instance, op_desc)) { |
|
|
|
|
|
GELOGE(GRAPH_FAILED, "AICPU engine does not support node:%s, type:%s , get kernel lib failed.", |
|
|
|
|
|
node->GetName().c_str(), op_desc->GetType().c_str()); |
|
|
|
|
|
return GRAPH_FAILED; |
|
|
|
|
|
|
|
|
// if check accuracy support failed , try to go to other engine. |
|
|
|
|
|
GELOGD("Check Accuracy Supported return not support, node name is %s. Try to go to other engine.", |
|
|
|
|
|
op_desc->GetName().c_str()); |
|
|
|
|
|
string kernel_name_origin = kernel_lib_name; |
|
|
|
|
|
OpsKernelManager &ops_kernel_manager = instance->OpsKernelManagerObj(); |
|
|
|
|
|
auto kernel_map = ops_kernel_manager.GetAllOpsKernelInfoStores(); |
|
|
|
|
|
for (auto it = kernel_map.begin(); it != kernel_map.end(); ++it) { |
|
|
|
|
|
string tmp_kernel_name = it->first; |
|
|
|
|
|
if (tmp_kernel_name == kernel_name_origin) { |
|
|
|
|
|
continue; |
|
|
|
|
|
} |
|
|
|
|
|
OpsKernelInfoStorePtr tmp_kernel_info = it->second; |
|
|
|
|
|
if (CheckAccuracySupport(tmp_kernel_info, instance, op_desc)) { |
|
|
|
|
|
kernel_lib_name = tmp_kernel_name; |
|
|
|
|
|
GELOGD("Find kernel lib %s support node:%s, type:%s , get kernel lib success.", tmp_kernel_name.c_str(), |
|
|
|
|
|
node->GetName().c_str(), op_desc->GetType().c_str()); |
|
|
|
|
|
return GRAPH_SUCCESS; |
|
|
|
|
|
} |
|
|
} |
|
|
} |
|
|
kernel_lib_name = kAICPUKernelLibName; |
|
|
|
|
|
|
|
|
GELOGE(GRAPH_FAILED, "Cannot find kernel lib support node:%s, type:%s , get kernel lib failed.", |
|
|
|
|
|
node->GetName().c_str(), op_desc->GetType().c_str()); |
|
|
|
|
|
return GRAPH_FAILED; |
|
|
} |
|
|
} |
|
|
return GRAPH_SUCCESS; |
|
|
return GRAPH_SUCCESS; |
|
|
} |
|
|
} |
|
@@ -138,8 +146,6 @@ bool CompileNodesPass::CheckAccuracySupport(const OpsKernelInfoStorePtr &kernel_ |
|
|
} |
|
|
} |
|
|
string reason; |
|
|
string reason; |
|
|
if (!(kernel_info->CheckAccuracySupported(*ge_desc, reason, true))) { |
|
|
if (!(kernel_info->CheckAccuracySupported(*ge_desc, reason, true))) { |
|
|
GELOGW("Check Accuracy Supported return not support, node name is %s, reason: %s. Try to go to AICPU engine.", |
|
|
|
|
|
op_desc->GetName().c_str(), reason.c_str()); |
|
|
|
|
|
return false; |
|
|
return false; |
|
|
} |
|
|
} |
|
|
return true; |
|
|
return true; |
|
|