Browse Source

back cast_remove_pass modify

tags/v1.3.0
wxl 4 years ago
parent
commit
79fd9d1004
3 changed files with 66 additions and 73 deletions
  1. +29
    -63
      ge/graph/passes/cast_remove_pass.cc
  2. +1
    -3
      ge/graph/passes/cast_remove_pass.h
  3. +36
    -7
      tests/ut/ge/graph/passes/cast_remove_pass_unittest.cc

+ 29
- 63
ge/graph/passes/cast_remove_pass.cc View File

@@ -21,7 +21,6 @@
#include "graph/common/transop_util.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/utils/type_utils.h"
#include "init/gelib.h"

namespace ge {
Status CastRemovePass::Run(NodePtr &node) {
@@ -62,14 +61,10 @@ Status CastRemovePass::Run(NodePtr &node) {
if (!HasSameDataType(op_desc, end_op_desc, type)) {
return SUCCESS;
}
auto instance_ptr = ge::GELib::GetInstance();
if ((instance_ptr == nullptr) || (!instance_ptr->InitFlag())) {
GELOGE(GE_CLI_GE_NOT_INITIALIZED, "gelib is not initilized!");
if (RemoveCast(type, nodes_to_fuse) != SUCCESS) {
return FAILED;
}

OpsKernelManager &ops_kernel_manager = instance_ptr->OpsKernelManagerObj();
return DoFuse(ops_kernel_manager, type, nodes_to_fuse);
return SUCCESS;
}

bool CastRemovePass::CheckPrecisionLoss(const std::vector<NodePtr> &nodes_to_fuse) {
@@ -100,14 +95,26 @@ bool CastRemovePass::HasSameDataType(OpDescPtr &begin_op_desc, OpDescPtr &end_op
// op1->TransData->Cast->TransposeD->Cast->TransData->op2
// change to be
// op1->TransData->TransposeD->TransData->op2
Status CastRemovePass::DoFuse(const OpsKernelManager &ops_kernel_manager,
const DataType &type,
std::vector<NodePtr> &nodes_to_fuse) {
std::vector<size_t> to_be_deleted_cast_index;
for (size_t i = 0; i < nodes_to_fuse.size(); i++) {
NodePtr node = nodes_to_fuse[i];
Status CastRemovePass::RemoveCast(DataType &type, std::vector<NodePtr> &nodes_to_fuse) {
string cast_name;
for (NodePtr &node : nodes_to_fuse) {
if (node->GetType() == CAST) {
GELOGI("CastRemovePass, remove Cast %s.", node->GetName().c_str());
cast_name = node->GetName();
if (IsolateAndDeleteNode(node, {0}) != SUCCESS) {
REPORT_CALL_ERROR("E19999", "Isolate and delete node:%s(%s) failed",
node->GetName().c_str(), node->GetType().c_str());
GELOGE(FAILED, "IsolateAndDeleteNode %s failed.", node->GetName().c_str());
return FAILED;
}
}
}

if (cast_name.empty()) {
return SUCCESS;
}
for (auto &node : nodes_to_fuse) {
if (node->GetType() == CAST) {
to_be_deleted_cast_index.emplace_back(i);
continue;
}
OpDescPtr op_desc = node->GetOpDesc();
@@ -116,66 +123,25 @@ Status CastRemovePass::DoFuse(const OpsKernelManager &ops_kernel_manager,
GELOGE(FAILED, "OpDesc must not be null.");
return FAILED;
}
auto in_desc = op_desc->MutableInputDesc(0);
auto out_desc = op_desc->MutableOutputDesc(0);
auto in_desc_org_dtype = in_desc->GetDataType();
auto out_desc_org_dtype = out_desc->GetDataType();
in_desc->SetDataType(type);
out_desc->SetDataType(type);
bool is_supported = false;
string un_supported_reasons;
for (const auto &ops_kernel_store_info : ops_kernel_manager.GetAllOpsKernelInfoStores()) {
map<string, OpInfo> op_infos;
ops_kernel_store_info.second->GetAllOpsKernelInfo(op_infos);
if (op_infos.find(op_desc->GetType()) == op_infos.end()) {
continue;
}
string un_supported_reason;
is_supported = ops_kernel_store_info.second->CheckAccuracySupported(op_desc, un_supported_reason);
if (is_supported) {
break;
}
un_supported_reasons += "{op_store " + ops_kernel_store_info.first + ":" + un_supported_reason + "} ";
}
if (!is_supported) {
// if no operator_info_store supported, do nothing
in_desc->SetDataType(in_desc_org_dtype);
out_desc->SetDataType(out_desc_org_dtype);
to_be_deleted_cast_index.clear();
GELOGI("Fused Op[%s] check supported fail! Reasons is as follows: %s",
op_desc->GetName().c_str(),
un_supported_reasons.c_str());
return SUCCESS;
}

// change node name for recompile cache, will be abandoned in April
string new_node_name = cast_name + op_desc->GetName();
op_desc->SetName(new_node_name);
// add attr to changed TransData, then will be rebuild
if (!AttrUtils::SetBool(op_desc, ATTR_NEED_COMPILE, true)) {
REPORT_CALL_ERROR("E19999", "Set Attr:%s of op:%s(%s) failed",
ATTR_NEED_COMPILE.c_str(),
op_desc->GetName().c_str(),
op_desc->GetType().c_str());
op_desc->GetName().c_str(), op_desc->GetType().c_str());
GELOGE(FAILED, "Set ATTR_NEED_COMPILE Attr fail.");
return FAILED;
}
auto in_desc = op_desc->MutableInputDesc(0);
auto out_desc = op_desc->MutableOutputDesc(0);
in_desc->SetDataType(type);
out_desc->SetDataType(type);
GELOGI("CastRemovePass, change %s %s datatype to be %s.", node->GetType().c_str(), node->GetName().c_str(),
TypeUtils::DataTypeToSerialString(type).c_str());
}
return DoRemoveCast(to_be_deleted_cast_index, nodes_to_fuse);
}

Status CastRemovePass::DoRemoveCast(const std::vector<size_t> &to_be_deleted_cast_index,
std::vector<NodePtr> &nodes_to_fuse) {
for (auto &cast_idx : to_be_deleted_cast_index) {
GELOGI("CastRemovePass, remove Cast %s.", nodes_to_fuse[cast_idx]->GetName().c_str());
if (IsolateAndDeleteNode(nodes_to_fuse[cast_idx], {0}) != SUCCESS) {
REPORT_CALL_ERROR("E19999", "Isolate and delete node:%s(%s) failed when CastRemovePass %s",
nodes_to_fuse[cast_idx]->GetName().c_str(),
nodes_to_fuse[cast_idx]->GetType().c_str(),
__FUNCTION__);
GELOGE(FAILED, "IsolateAndDeleteNode %s failed.", nodes_to_fuse[cast_idx]->GetName().c_str());
return FAILED;
}
}
return SUCCESS;
}



+ 1
- 3
ge/graph/passes/cast_remove_pass.h View File

@@ -19,7 +19,6 @@

#include <vector>
#include "graph/passes/base_pass.h"
#include "opskernel_manager/ops_kernel_manager.h"

namespace ge {
class CastRemovePass : public BaseNodePass {
@@ -29,9 +28,8 @@ class CastRemovePass : public BaseNodePass {
private:
bool CheckPrecisionLoss(const std::vector<NodePtr> &nodes_to_fuse);
bool HasSameDataType(OpDescPtr &begin_op_desc, OpDescPtr &end_op_desc, DataType &type) const;
Status RemoveCast(DataType &type, std::vector<NodePtr> &nodes_to_fuse);
NodePtr GetTheEndNode(NodePtr begin_node, std::vector<NodePtr> &nodes_to_fuse);
Status DoRemoveCast(const std::vector<size_t> &to_be_deleted_cast_index, std::vector<NodePtr> &nodes_to_fuse);
Status DoFuse(const OpsKernelManager &ops_kernel_manager, const DataType &type, std::vector<NodePtr> &nodes_to_fuse);
};
} // namespace ge
#endif // GE_GRAPH_PASSES_CAST_REMOVE_PASS_H_

+ 36
- 7
tests/ut/ge/graph/passes/cast_remove_pass_unittest.cc View File

@@ -52,6 +52,41 @@ class UtestGraphPassesCastRemovePass : public testing::Test {
};

// case1:no net_out_put_node
// TEST_F(UtestGraphPassesCastRemovePass, DoFuseProcess) {
// std::vector<NodePtr> nodes_to_fuse;

// auto builder = ut::GraphBuilder("g1");
// auto data = builder.AddNode("data", DATA, 1, 1);
// auto cast1 = builder.AddNode("cast1", CAST, 1, 1);
// cast1->GetOpDesc()->MutableOutputDesc(0)->SetDataType(DT_FLOAT16);
// auto trans = builder.AddNode("trans", TRANSPOSE, 1, 1, FORMAT_NCHW, DT_FLOAT16);
// auto cast2 = builder.AddNode("cast2", CAST, 1, 1);
// cast2->GetOpDesc()->MutableInputDesc(0)->SetDataType(DT_FLOAT16);
// auto net = builder.AddNode("netout", NETOUTPUT, 1, 1);

// builder.AddDataEdge(data, 0, cast1, 0);
// builder.AddDataEdge(cast1, 0, trans, 0);
// builder.AddDataEdge(trans, 0, cast2, 0);
// builder.AddDataEdge(cast2, 0, net, 0);
// ComputeGraphPtr compute_graph = builder.GetGraph();

// map<string, string> options;

// CastRemovePass cast_remove_pass;
// DataType type = DT_FLOAT;
// nodes_to_fuse.emplace_back(cast1);
// nodes_to_fuse.emplace_back(trans);
// nodes_to_fuse.emplace_back(cast2);
// OpsKernelManager ops_kernel_manager;
// cast_remove_pass.DoFuse(ops_kernel_manager, type, nodes_to_fuse);
// EXPECT_EQ(compute_graph->GetAllNodesSize(),5);
// std::vector<size_t> to_be_deleted_cast_index;
// to_be_deleted_cast_index.emplace_back(0);
// to_be_deleted_cast_index.emplace_back(2);
// (void)cast_remove_pass.DoRemoveCast(to_be_deleted_cast_index, nodes_to_fuse);
// EXPECT_EQ(compute_graph->GetAllNodesSize(),3);
// }

TEST_F(UtestGraphPassesCastRemovePass, DoFuseProcess) {
std::vector<NodePtr> nodes_to_fuse;

@@ -77,12 +112,6 @@ TEST_F(UtestGraphPassesCastRemovePass, DoFuseProcess) {
nodes_to_fuse.emplace_back(cast1);
nodes_to_fuse.emplace_back(trans);
nodes_to_fuse.emplace_back(cast2);
OpsKernelManager ops_kernel_manager;
cast_remove_pass.DoFuse(ops_kernel_manager, type, nodes_to_fuse);
EXPECT_EQ(compute_graph->GetAllNodesSize(),5);
std::vector<size_t> to_be_deleted_cast_index;
to_be_deleted_cast_index.emplace_back(0);
to_be_deleted_cast_index.emplace_back(2);
(void)cast_remove_pass.DoRemoveCast(to_be_deleted_cast_index, nodes_to_fuse);
cast_remove_pass.RemoveCast(type, nodes_to_fuse);
EXPECT_EQ(compute_graph->GetAllNodesSize(),3);
}

Loading…
Cancel
Save