Browse Source

!1558 fix lhisi cast be deleted question when fp32 input

From: @wan_xuelei
Reviewed-by: @wqtshg,@ji_chen
Signed-off-by:
tags/v1.3.0
mindspore-ci-bot Gitee 3 years ago
parent
commit
0a4f741efe
4 changed files with 155 additions and 30 deletions
  1. +63
    -29
      ge/graph/passes/cast_remove_pass.cc
  2. +3
    -1
      ge/graph/passes/cast_remove_pass.h
  3. +1
    -0
      tests/ut/ge/CMakeLists.txt
  4. +88
    -0
      tests/ut/ge/graph/passes/cast_remove_pass_unittest.cc

+ 63
- 29
ge/graph/passes/cast_remove_pass.cc View File

@@ -21,6 +21,7 @@
#include "graph/common/transop_util.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/utils/type_utils.h"
#include "init/gelib.h"

namespace ge {
Status CastRemovePass::Run(NodePtr &node) {
@@ -61,10 +62,14 @@ Status CastRemovePass::Run(NodePtr &node) {
if (!HasSameDataType(op_desc, end_op_desc, type)) {
return SUCCESS;
}
if (RemoveCast(type, nodes_to_fuse) != SUCCESS) {
auto instance_ptr = ge::GELib::GetInstance();
if ((instance_ptr == nullptr) || (!instance_ptr->InitFlag())) {
GELOGE(GE_CLI_GE_NOT_INITIALIZED, "gelib is not initilized!");
return FAILED;
}
return SUCCESS;

OpsKernelManager &ops_kernel_manager = instance_ptr->OpsKernelManagerObj();
return DoFuse(ops_kernel_manager, type, nodes_to_fuse);
}

bool CastRemovePass::CheckPrecisionLoss(const std::vector<NodePtr> &nodes_to_fuse) {
@@ -95,26 +100,14 @@ bool CastRemovePass::HasSameDataType(OpDescPtr &begin_op_desc, OpDescPtr &end_op
// op1->TransData->Cast->TransposeD->Cast->TransData->op2
// change to be
// op1->TransData->TransposeD->TransData->op2
Status CastRemovePass::RemoveCast(DataType &type, std::vector<NodePtr> &nodes_to_fuse) {
string cast_name;
for (NodePtr &node : nodes_to_fuse) {
if (node->GetType() == CAST) {
GELOGI("CastRemovePass, remove Cast %s.", node->GetName().c_str());
cast_name = node->GetName();
if (IsolateAndDeleteNode(node, {0}) != SUCCESS) {
REPORT_CALL_ERROR("E19999", "Isolate and delete node:%s(%s) failed",
node->GetName().c_str(), node->GetType().c_str());
GELOGE(FAILED, "IsolateAndDeleteNode %s failed.", node->GetName().c_str());
return FAILED;
}
}
}

if (cast_name.empty()) {
return SUCCESS;
}
for (auto &node : nodes_to_fuse) {
Status CastRemovePass::DoFuse(const OpsKernelManager &ops_kernel_manager,
const DataType &type,
std::vector<NodePtr> &nodes_to_fuse) {
std::vector<size_t> to_be_deleted_cast_index;
for (size_t i = 0; i < nodes_to_fuse.size(); i++) {
NodePtr node = nodes_to_fuse[i];
if (node->GetType() == CAST) {
to_be_deleted_cast_index.emplace_back(i);
continue;
}
OpDescPtr op_desc = node->GetOpDesc();
@@ -123,25 +116,66 @@ Status CastRemovePass::RemoveCast(DataType &type, std::vector<NodePtr> &nodes_to
GELOGE(FAILED, "OpDesc must not be null.");
return FAILED;
}
auto in_desc = op_desc->MutableInputDesc(0);
auto out_desc = op_desc->MutableOutputDesc(0);
auto in_desc_org_dtype = in_desc->GetDataType();
auto out_desc_org_dtype = out_desc->GetDataType();
in_desc->SetDataType(type);
out_desc->SetDataType(type);
bool is_supported = false;
string un_supported_reasons;
for (const auto &ops_kernel_store_info : ops_kernel_manager.GetAllOpsKernelInfoStores()) {
map<string, OpInfo> op_infos;
ops_kernel_store_info.second->GetAllOpsKernelInfo(op_infos);
if (op_infos.find(op_desc->GetType()) == op_infos.end()) {
continue;
}
string un_supported_reason;
is_supported = ops_kernel_store_info.second->CheckAccuracySupported(op_desc, un_supported_reason);
if (is_supported) {
break;
}
un_supported_reasons += "{op_store " + ops_kernel_store_info.first + ":" + un_supported_reason + "} ";
}
if (!is_supported) {
// if no operator_info_store supported, do nothing
in_desc->SetDataType(in_desc_org_dtype);
out_desc->SetDataType(out_desc_org_dtype);
to_be_deleted_cast_index.clear();
GELOGI("Fused Op[%s] check supported fail! Reasons is as follows: %s",
op_desc->GetName().c_str(),
un_supported_reasons.c_str());
return SUCCESS;
}

// change node name for recompile cache, will be abandoned in April
string new_node_name = cast_name + op_desc->GetName();
op_desc->SetName(new_node_name);
// add attr to changed TransData, then will be rebuild
if (!AttrUtils::SetBool(op_desc, ATTR_NEED_COMPILE, true)) {
REPORT_CALL_ERROR("E19999", "Set Attr:%s of op:%s(%s) failed",
ATTR_NEED_COMPILE.c_str(),
op_desc->GetName().c_str(), op_desc->GetType().c_str());
op_desc->GetName().c_str(),
op_desc->GetType().c_str());
GELOGE(FAILED, "Set ATTR_NEED_COMPILE Attr fail.");
return FAILED;
}
auto in_desc = op_desc->MutableInputDesc(0);
auto out_desc = op_desc->MutableOutputDesc(0);
in_desc->SetDataType(type);
out_desc->SetDataType(type);
GELOGI("CastRemovePass, change %s %s datatype to be %s.", node->GetType().c_str(), node->GetName().c_str(),
TypeUtils::DataTypeToSerialString(type).c_str());
}
return DoRemoveCast(to_be_deleted_cast_index, nodes_to_fuse);
}

Status CastRemovePass::DoRemoveCast(const std::vector<size_t> &to_be_deleted_cast_index,
std::vector<NodePtr> &nodes_to_fuse) {
for (auto &cast_idx : to_be_deleted_cast_index) {
GELOGI("CastRemovePass, remove Cast %s.", nodes_to_fuse[cast_idx]->GetName().c_str());
if (IsolateAndDeleteNode(nodes_to_fuse[cast_idx], {0}) != SUCCESS) {
REPORT_CALL_ERROR("E19999", "Isolate and delete node:%s(%s) failed when CastRemovePass %s",
nodes_to_fuse[cast_idx]->GetName().c_str(),
nodes_to_fuse[cast_idx]->GetType().c_str(),
__FUNCTION__);
GELOGE(FAILED, "IsolateAndDeleteNode %s failed.", nodes_to_fuse[cast_idx]->GetName().c_str());
return FAILED;
}
}
return SUCCESS;
}



+ 3
- 1
ge/graph/passes/cast_remove_pass.h View File

@@ -19,6 +19,7 @@

#include <vector>
#include "graph/passes/base_pass.h"
#include "opskernel_manager/ops_kernel_manager.h"

namespace ge {
class CastRemovePass : public BaseNodePass {
@@ -28,8 +29,9 @@ class CastRemovePass : public BaseNodePass {
private:
bool CheckPrecisionLoss(const std::vector<NodePtr> &nodes_to_fuse);
bool HasSameDataType(OpDescPtr &begin_op_desc, OpDescPtr &end_op_desc, DataType &type) const;
Status RemoveCast(DataType &type, std::vector<NodePtr> &nodes_to_fuse);
NodePtr GetTheEndNode(NodePtr begin_node, std::vector<NodePtr> &nodes_to_fuse);
Status DoRemoveCast(const std::vector<size_t> &to_be_deleted_cast_index, std::vector<NodePtr> &nodes_to_fuse);
Status DoFuse(const OpsKernelManager &ops_kernel_manager, const DataType &type, std::vector<NodePtr> &nodes_to_fuse);
};
} // namespace ge
#endif // GE_GRAPH_PASSES_CAST_REMOVE_PASS_H_

+ 1
- 0
tests/ut/ge/CMakeLists.txt View File

@@ -709,6 +709,7 @@ set(PASS_TEST_FILES
"graph/passes/buffer_pool_memory_pass_unittest.cc"
"graph/passes/mark_node_unknown_shape_pass_unittest.cc"
"graph/passes/reshape_recovery_pass_unittest.cc"
"graph/passes/cast_remove_pass_unittest.cc"
)

set(KERNEL_TEST_FILES


+ 88
- 0
tests/ut/ge/graph/passes/cast_remove_pass_unittest.cc View File

@@ -0,0 +1,88 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <gtest/gtest.h>
#include <vector>

#define protected public
#define private public
#include "graph/passes/cast_remove_pass.h"
#undef protected
#undef private

#include "anchor.h"
#include "common/debug/log.h"
#include "common/debug/memory_dumper.h"
#include "common/op/attr_value_util.h"
#include "common/types.h"
#include "framework/common/ge_inner_error_codes.h"
#include "graph/attr_value.h"
#include "graph/debug/ge_attr_define.h"
#include "inc/pass_manager.h"
#include "graph_builder_utils.h"
#include <string>
#include <iostream>
#include <vector>
#include "opskernel_manager/ops_kernel_manager.h"
#include "omg/omg_inner_types.h"


using namespace testing;
using namespace ge;
using namespace std;

class UtestGraphPassesCastRemovePass : public testing::Test {
protected:
void SetUp() {}

void TearDown() {}
};

// case1:no net_out_put_node
TEST_F(UtestGraphPassesCastRemovePass, DoFuseProcess) {
std::vector<NodePtr> nodes_to_fuse;

auto builder = ut::GraphBuilder("g1");
auto data = builder.AddNode("data", DATA, 1, 1);
auto cast1 = builder.AddNode("cast1", CAST, 1, 1);
cast1->GetOpDesc()->MutableOutputDesc(0)->SetDataType(DT_FLOAT16);
auto trans = builder.AddNode("trans", TRANSPOSE, 1, 1, FORMAT_NCHW, DT_FLOAT16);
auto cast2 = builder.AddNode("cast2", CAST, 1, 1);
cast2->GetOpDesc()->MutableInputDesc(0)->SetDataType(DT_FLOAT16);
auto net = builder.AddNode("netout", NETOUTPUT, 1, 1);

builder.AddDataEdge(data, 0, cast1, 0);
builder.AddDataEdge(cast1, 0, trans, 0);
builder.AddDataEdge(trans, 0, cast2, 0);
builder.AddDataEdge(cast2, 0, net, 0);
ComputeGraphPtr compute_graph = builder.GetGraph();

map<string, string> options;

CastRemovePass cast_remove_pass;
DataType type = DT_FLOAT;
nodes_to_fuse.emplace_back(cast1);
nodes_to_fuse.emplace_back(trans);
nodes_to_fuse.emplace_back(cast2);
OpsKernelManager ops_kernel_manager;
cast_remove_pass.DoFuse(ops_kernel_manager, type, nodes_to_fuse);
EXPECT_EQ(compute_graph->GetAllNodesSize(),5);
std::vector<size_t> to_be_deleted_cast_index;
to_be_deleted_cast_index.emplace_back(0);
to_be_deleted_cast_index.emplace_back(2);
(void)cast_remove_pass.DoRemoveCast(to_be_deleted_cast_index, nodes_to_fuse);
EXPECT_EQ(compute_graph->GetAllNodesSize(),3);
}

Loading…
Cancel
Save