/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "graph/passes/dynamic_single_op_reset_shape_pass.h"
#include "common/ge_inner_error_codes.h"
#include "graph/utils/node_utils.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/op_desc_utils.h"
#include "graph/utils/type_utils.h"
#include "graph/debug/ge_attr_define.h"

namespace ge {
namespace {
// Sole dimension written into the input shapes of dynamic single AICPU nodes; this pass uses the
// shape {kDynamicShapeDim} as its "unknown shape".
const int64_t kDynamicShapeDim = -2;
// Kernel lib name of the AICPU kernel; only nodes resolved to this lib get their shapes reset.
const char *const kAICPUKernelLibName = "aicpu_tf_kernel";
}  // namespace

/**
 * Resets the input shapes of dynamic single-op AICPU nodes to an unknown shape.
 *
 * A direct node of @p graph is processed only when all of the following hold:
 *   - it is not a DATA / CONSTANT / CONSTANTOP / NETOUTPUT node;
 *   - its op desc carries ATTR_DYNAMIC_SHAPE_SINGLE_AICPU set to true;
 *   - GetSupportedKernel resolves it to the kAICPUKernelLibName kernel lib.
 * For such a node, every input desc with a non-empty dims vector gets the shape
 * {kDynamicShapeDim}. Inputs whose dims are empty (scalars) are left untouched.
 *
 * @param graph graph whose direct nodes are processed; must not be null.
 * @return SUCCESS when every node was processed;
 *         GE_CLI_GE_NOT_INITIALIZED when GELib is missing or not initialized;
 *         GRAPH_FAILED when the kernel lib of a candidate node cannot be determined;
 *         the GE_CHECK_NOTNULL failure status when the graph, a node's op desc, or an input
 *         desc is null.
 */
Status DynamicSingleOpResetShapePass::Run(ComputeGraphPtr graph) {
  GE_CHECK_NOTNULL(graph);

  std::shared_ptr<GELib> instance = ge::GELib::GetInstance();
  if (instance == nullptr || !instance->InitFlag()) {
    GELOGE(ge::GE_CLI_GE_NOT_INITIALIZED, "Run DynamicSingleOpResetShapePass failed.");
    return ge::GE_CLI_GE_NOT_INITIALIZED;
  }

  // The replacement shape is the same for every node, so build it once instead of per node.
  const std::vector<int64_t> dynamic_shape_dims = {kDynamicShapeDim};
  GeShape dynamic_shape(dynamic_shape_dims);

  for (const auto &node : graph->GetDirectNode()) {
    GE_CHECK_NOTNULL(node->GetOpDesc());
    // skip graph input nodes
    if (node->GetType() == DATA || node->GetType() == CONSTANT || node->GetType() == CONSTANTOP) {
      continue;
    }
    // skip graph output node
    if (node->GetType() == NETOUTPUT) {
      continue;
    }
    // only nodes explicitly marked as dynamic single AICPU ops are candidates
    bool single_aicpu_unknown = false;
    if (!AttrUtils::GetBool(node->GetOpDesc(), ATTR_DYNAMIC_SHAPE_SINGLE_AICPU, single_aicpu_unknown) ||
        !single_aicpu_unknown) {
      continue;
    }

    // skip nodes that are not handled by the AICPU kernel lib
    string kernel_lib_name;
    if (GetSupportedKernel(node, instance, kernel_lib_name) != GRAPH_SUCCESS) {
      GELOGE(GRAPH_FAILED, "Get kernel lib failed of node[%s].", node->GetName().c_str());
      return GRAPH_FAILED;
    }
    if (kernel_lib_name != kAICPUKernelLibName) {
      continue;
    }

    // reset aicpu input shapes to unknown shape
    auto op_desc = node->GetOpDesc();
    // Evaluate the input count once rather than calling GetAllInputsDesc() on every loop check.
    const size_t input_num = op_desc->GetAllInputsDesc().size();
    for (size_t i = 0; i < input_num; ++i) {
      auto input_desc = op_desc->MutableInputDesc(static_cast<uint32_t>(i));
      GE_CHECK_NOTNULL(input_desc);
      // keep scalar inputs (empty dims) as they are
      if (input_desc->GetShape().GetDims().empty()) {
        continue;
      }
      input_desc->SetShape(dynamic_shape);
    }
    GELOGD("Reset dynamic aicpu node [%s] shape success!", node->GetName().c_str());
  }
  GELOGD("Reset dynamic aicpu nodes shape of graph [%s] success!", graph->GetName().c_str());
  return SUCCESS;
}

/**
 * Determines the kernel lib that supports @p node and stores its name in @p kernel_lib_name.
 *
 * The kernel lib already recorded on the op desc is used first. If none is recorded, the DNN
 * engine manager is asked via GetDNNEngineName(node), and the op desc is read again afterwards.
 * NOTE(review): the re-read implies GetDNNEngineName assigns the lib as a side effect — confirm.
 * If the chosen lib's store fails the accuracy-supported check, every other registered kernel
 * info store is tried in map order, and the first one that passes is selected.
 *
 * @param node node to resolve; must not be null.
 * @param instance initialized GELib instance providing the engine and kernel managers.
 * @param kernel_lib_name out: name of the selected kernel lib. It may hold a partially resolved
 *        name when an error is returned.
 * @return GRAPH_SUCCESS on success;
 *         GE_GRAPH_PARAM_NULLPTR when the op desc or the kernel info store is missing;
 *         GRAPH_FAILED when no kernel lib can be assigned or none passes the accuracy check.
 */
graphStatus DynamicSingleOpResetShapePass::GetSupportedKernel(const NodePtr &node,
                                                              const std::shared_ptr<GELib> instance,
                                                              string &kernel_lib_name) {
  auto op_desc = node->GetOpDesc();
  if (op_desc == nullptr) {
    GELOGE(ge::GE_GRAPH_PARAM_NULLPTR, "Get op %s opdesc failed", node->GetName().c_str());
    return ge::GE_GRAPH_PARAM_NULLPTR;
  }

  // reset op kernel lib, find supported kernel
  kernel_lib_name = op_desc->GetOpKernelLibName();
  if (kernel_lib_name.empty()) {
    (void)instance->DNNEngineManagerObj().GetDNNEngineName(node);
    kernel_lib_name = op_desc->GetOpKernelLibName();
    if (kernel_lib_name.empty()) {
      GELOGE(GRAPH_FAILED, "Get node:%s, type:%s kernel lib failed.", node->GetName().c_str(),
             op_desc->GetType().c_str());
      return GRAPH_FAILED;
    }
  }

  OpsKernelInfoStorePtr kernel_info = instance->OpsKernelManagerObj().GetOpsKernelInfoStore(kernel_lib_name);
  if (kernel_info == nullptr) {
    GELOGE(ge::GE_GRAPH_PARAM_NULLPTR, "Get op %s ops kernel info store failed", node->GetName().c_str());
    return ge::GE_GRAPH_PARAM_NULLPTR;
  }

  // begin accuracy supported check
  if (CheckAccuracySupport(kernel_info, instance, op_desc)) {
    return GRAPH_SUCCESS;
  }

  // The accuracy check failed, so try every other engine.
  GELOGD("Check Accuracy Supported return not support, node name is %s. Try to go to other engine.",
         op_desc->GetName().c_str());
  const string kernel_name_origin = kernel_lib_name;
  // Bind by reference: the stores are only read, so the map need not be copied.
  const auto &kernel_map = instance->OpsKernelManagerObj().GetAllOpsKernelInfoStores();
  for (const auto &kernel_store : kernel_map) {
    const string &tmp_kernel_name = kernel_store.first;
    if (tmp_kernel_name == kernel_name_origin) {
      continue;
    }
    if (CheckAccuracySupport(kernel_store.second, instance, op_desc)) {
      kernel_lib_name = tmp_kernel_name;
      GELOGD("Find kernel lib %s support node:%s, type:%s , get kernel lib success.", tmp_kernel_name.c_str(),
             node->GetName().c_str(), op_desc->GetType().c_str());
      return GRAPH_SUCCESS;
    }
  }
  GELOGE(GRAPH_FAILED, "Cannot find kernel lib support node:%s, type:%s , get kernel lib failed.",
         node->GetName().c_str(), op_desc->GetType().c_str());
  return GRAPH_FAILED;
}

/**
 * Asks @p kernel_info whether it supports @p op_desc, via CheckAccuracySupported with a real
 * query (third argument true).
 *
 * @param kernel_info kernel info store to query. A null store is reported as "not supported"
 *        because the fallback loop in GetSupportedKernel passes stores without checking them.
 * @param instance unused by this function.
 * @param op_desc op to check.
 * @return true if the store reports the op as accuracy-supported; false if it does not, if the
 *         store is null, or if the temporary handle cannot be allocated.
 */
bool DynamicSingleOpResetShapePass::CheckAccuracySupport(const OpsKernelInfoStorePtr &kernel_info,
                                                         const std::shared_ptr<GELib> instance,
                                                         OpDescPtr &op_desc) {
  if (kernel_info == nullptr) {
    GELOGW("Ops kernel info store is null, treat op as not supported.");
    return false;
  }
  // The store receives its own copy of the OpDescPtr handle instead of op_desc itself.
  // NOTE(review): presumably this keeps the callee from rebinding the caller's pointer; confirm
  // against the CheckAccuracySupported signature.
  auto ge_desc = MakeShared<OpDescPtr>(op_desc);
  if (ge_desc == nullptr) {
    GELOGE(GE_GRAPH_MEMORY_ALLOC_FAILED, "Fail to malloc op desc.");
    return false;
  }
  string reason;
  if (!(kernel_info->CheckAccuracySupported(*ge_desc, reason, true))) {
    return false;
  }
  return true;
}
}  // namespace ge