|
|
@@ -21,6 +21,7 @@ |
|
|
|
#include "graph/common/transop_util.h" |
|
|
|
#include "graph/debug/ge_attr_define.h" |
|
|
|
#include "graph/utils/type_utils.h" |
|
|
|
#include "init/gelib.h" |
|
|
|
|
|
|
|
namespace ge { |
|
|
|
Status CastRemovePass::Run(NodePtr &node) { |
|
|
@@ -61,10 +62,14 @@ Status CastRemovePass::Run(NodePtr &node) { |
|
|
|
if (!HasSameDataType(op_desc, end_op_desc, type)) { |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
if (RemoveCast(type, nodes_to_fuse) != SUCCESS) { |
|
|
|
auto instance_ptr = ge::GELib::GetInstance(); |
|
|
|
if ((instance_ptr == nullptr) || (!instance_ptr->InitFlag())) { |
|
|
|
GELOGE(GE_CLI_GE_NOT_INITIALIZED, "gelib is not initilized!"); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
|
|
|
|
OpsKernelManager &ops_kernel_manager = instance_ptr->OpsKernelManagerObj(); |
|
|
|
return DoFuse(ops_kernel_manager, type, nodes_to_fuse); |
|
|
|
} |
|
|
|
|
|
|
|
bool CastRemovePass::CheckPrecisionLoss(const std::vector<NodePtr> &nodes_to_fuse) { |
|
|
@@ -95,26 +100,14 @@ bool CastRemovePass::HasSameDataType(OpDescPtr &begin_op_desc, OpDescPtr &end_op |
|
|
|
// op1->TransData->Cast->TransposeD->Cast->TransData->op2 |
|
|
|
// change to be |
|
|
|
// op1->TransData->TransposeD->TransData->op2 |
|
|
|
Status CastRemovePass::RemoveCast(DataType &type, std::vector<NodePtr> &nodes_to_fuse) { |
|
|
|
string cast_name; |
|
|
|
for (NodePtr &node : nodes_to_fuse) { |
|
|
|
if (node->GetType() == CAST) { |
|
|
|
GELOGI("CastRemovePass, remove Cast %s.", node->GetName().c_str()); |
|
|
|
cast_name = node->GetName(); |
|
|
|
if (IsolateAndDeleteNode(node, {0}) != SUCCESS) { |
|
|
|
REPORT_CALL_ERROR("E19999", "Isolate and delete node:%s(%s) failed", |
|
|
|
node->GetName().c_str(), node->GetType().c_str()); |
|
|
|
GELOGE(FAILED, "IsolateAndDeleteNode %s failed.", node->GetName().c_str()); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
if (cast_name.empty()) { |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
for (auto &node : nodes_to_fuse) { |
|
|
|
Status CastRemovePass::DoFuse(const OpsKernelManager &ops_kernel_manager, |
|
|
|
const DataType &type, |
|
|
|
std::vector<NodePtr> &nodes_to_fuse) { |
|
|
|
std::vector<size_t> to_be_deleted_cast_index; |
|
|
|
for (size_t i = 0; i < nodes_to_fuse.size(); i++) { |
|
|
|
NodePtr node = nodes_to_fuse[i]; |
|
|
|
if (node->GetType() == CAST) { |
|
|
|
to_be_deleted_cast_index.emplace_back(i); |
|
|
|
continue; |
|
|
|
} |
|
|
|
OpDescPtr op_desc = node->GetOpDesc(); |
|
|
@@ -123,25 +116,66 @@ Status CastRemovePass::RemoveCast(DataType &type, std::vector<NodePtr> &nodes_to |
|
|
|
GELOGE(FAILED, "OpDesc must not be null."); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
auto in_desc = op_desc->MutableInputDesc(0); |
|
|
|
auto out_desc = op_desc->MutableOutputDesc(0); |
|
|
|
auto in_desc_org_dtype = in_desc->GetDataType(); |
|
|
|
auto out_desc_org_dtype = out_desc->GetDataType(); |
|
|
|
in_desc->SetDataType(type); |
|
|
|
out_desc->SetDataType(type); |
|
|
|
bool is_supported = false; |
|
|
|
string un_supported_reasons; |
|
|
|
for (const auto &ops_kernel_store_info : ops_kernel_manager.GetAllOpsKernelInfoStores()) { |
|
|
|
map<string, OpInfo> op_infos; |
|
|
|
ops_kernel_store_info.second->GetAllOpsKernelInfo(op_infos); |
|
|
|
if (op_infos.find(op_desc->GetType()) == op_infos.end()) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
string un_supported_reason; |
|
|
|
is_supported = ops_kernel_store_info.second->CheckAccuracySupported(op_desc, un_supported_reason); |
|
|
|
if (is_supported) { |
|
|
|
break; |
|
|
|
} |
|
|
|
un_supported_reasons += "{op_store " + ops_kernel_store_info.first + ":" + un_supported_reason + "} "; |
|
|
|
} |
|
|
|
if (!is_supported) { |
|
|
|
// if no operator_info_store supported, do nothing |
|
|
|
in_desc->SetDataType(in_desc_org_dtype); |
|
|
|
out_desc->SetDataType(out_desc_org_dtype); |
|
|
|
to_be_deleted_cast_index.clear(); |
|
|
|
GELOGI("Fused Op[%s] check supported fail! Reasons is as follows: %s", |
|
|
|
op_desc->GetName().c_str(), |
|
|
|
un_supported_reasons.c_str()); |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
// change node name for recompile cache, will be abandoned in April |
|
|
|
string new_node_name = cast_name + op_desc->GetName(); |
|
|
|
op_desc->SetName(new_node_name); |
|
|
|
// add attr to changed TransData, then will be rebuild |
|
|
|
if (!AttrUtils::SetBool(op_desc, ATTR_NEED_COMPILE, true)) { |
|
|
|
REPORT_CALL_ERROR("E19999", "Set Attr:%s of op:%s(%s) failed", |
|
|
|
ATTR_NEED_COMPILE.c_str(), |
|
|
|
op_desc->GetName().c_str(), op_desc->GetType().c_str()); |
|
|
|
op_desc->GetName().c_str(), |
|
|
|
op_desc->GetType().c_str()); |
|
|
|
GELOGE(FAILED, "Set ATTR_NEED_COMPILE Attr fail."); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
auto in_desc = op_desc->MutableInputDesc(0); |
|
|
|
auto out_desc = op_desc->MutableOutputDesc(0); |
|
|
|
in_desc->SetDataType(type); |
|
|
|
out_desc->SetDataType(type); |
|
|
|
GELOGI("CastRemovePass, change %s %s datatype to be %s.", node->GetType().c_str(), node->GetName().c_str(), |
|
|
|
TypeUtils::DataTypeToSerialString(type).c_str()); |
|
|
|
} |
|
|
|
return DoRemoveCast(to_be_deleted_cast_index, nodes_to_fuse); |
|
|
|
} |
|
|
|
|
|
|
|
Status CastRemovePass::DoRemoveCast(const std::vector<size_t> &to_be_deleted_cast_index, |
|
|
|
std::vector<NodePtr> &nodes_to_fuse) { |
|
|
|
for (auto &cast_idx : to_be_deleted_cast_index) { |
|
|
|
GELOGI("CastRemovePass, remove Cast %s.", nodes_to_fuse[cast_idx]->GetName().c_str()); |
|
|
|
if (IsolateAndDeleteNode(nodes_to_fuse[cast_idx], {0}) != SUCCESS) { |
|
|
|
REPORT_CALL_ERROR("E19999", "Isolate and delete node:%s(%s) failed when CastRemovePass %s", |
|
|
|
nodes_to_fuse[cast_idx]->GetName().c_str(), |
|
|
|
nodes_to_fuse[cast_idx]->GetType().c_str(), |
|
|
|
__FUNCTION__); |
|
|
|
GELOGE(FAILED, "IsolateAndDeleteNode %s failed.", nodes_to_fuse[cast_idx]->GetName().c_str()); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|