Browse Source

slice tensor no originshape

tags/v1.1.0
wangwenhua1@huawei.com 3 years ago
parent
commit
593ddd7c7c
4 changed files with 41 additions and 1 deletions
  1. +8
    -0
      ge/graph/common/transop_util.cc
  2. +2
    -0
      ge/graph/common/transop_util.h
  3. +29
    -0
      ge/graph/preprocess/graph_preprocess.cc
  4. +2
    -1
      ge/host_kernels/slice_kernel.cc

+ 8
- 0
ge/graph/common/transop_util.cc View File

@@ -81,5 +81,13 @@ bool TransOpUtil::CheckPrecisionLoss(const ge::NodePtr &src_node) {
return false; return false;
} }
return true; return true;

std::string TransOpUtil::TransopMapToString() {
std::string buffer;
for (auto it = transop_index_map_.begin(); it != transop_index_map_.end(); ++it) {
buffer += it->first + ",";
}
return buffer.substr(0, buffer.size() -1);
}
} }
} // namespace ge } // namespace ge

+ 2
- 0
ge/graph/common/transop_util.h View File

@@ -35,6 +35,8 @@ class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY TransOpUtil {


static bool CheckPrecisionLoss(const NodePtr &src_node); static bool CheckPrecisionLoss(const NodePtr &src_node);


static std::string TransopMapToString();

private: private:
TransOpUtil(); TransOpUtil();




+ 29
- 0
ge/graph/preprocess/graph_preprocess.cc View File

@@ -218,6 +218,9 @@ NodePtr CreateTransNode(const std::string &name, const std::string &node_type, c


auto index = TransOpUtil::GetTransOpDataIndex(node_type); auto index = TransOpUtil::GetTransOpDataIndex(node_type);
if (index < 0) { if (index < 0) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E19025", {"situation", "reason"},
{"The trans node type[" + node_type + "]", "it must be " + TransOpUtil::TransopMapToString()});
GELOGE(INTERNAL_ERROR, "The trans node type %s does not exists", node_type.c_str()); GELOGE(INTERNAL_ERROR, "The trans node type %s does not exists", node_type.c_str());
return nullptr; return nullptr;
} }
@@ -386,6 +389,8 @@ Status RecoverTransRoadForVar(const NodePtr &var, const VarTransRoad &road) {
auto trans_name = var->GetName() + "_trans_" + std::to_string(index++); auto trans_name = var->GetName() + "_trans_" + std::to_string(index++);
auto ret = RecoverOneTransNodeForVar(trans_name, *iter, last_node, last_node); auto ret = RecoverOneTransNodeForVar(trans_name, *iter, last_node, last_node);
if (ret != SUCCESS) { if (ret != SUCCESS) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E15001", {"variable", "index", "type"}, {var->GetName(), std::to_string(index), iter->node_type});
GELOGE(INTERNAL_ERROR, "Failed to recover trans node for variable %s, index %d, type %s", var->GetName().c_str(), GELOGE(INTERNAL_ERROR, "Failed to recover trans node for variable %s, index %d, type %s", var->GetName().c_str(),
index, iter->node_type.c_str()); index, iter->node_type.c_str());
return INTERNAL_ERROR; return INTERNAL_ERROR;
@@ -418,6 +423,9 @@ Status RecoverTransRoadForVarRef(const std::set<NodePtr> &nodes, const VarTransR
auto trans_name = var->GetName() + "_trans_" + std::to_string(index++); auto trans_name = var->GetName() + "_trans_" + std::to_string(index++);
auto ret = RecoverOneTransNodeForVarRef(trans_name, *iter, last_node, last_node); auto ret = RecoverOneTransNodeForVarRef(trans_name, *iter, last_node, last_node);
if (ret != SUCCESS) { if (ret != SUCCESS) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E15001", {"variable", "index", "type"}, {var->GetName(), std::to_string(index), iter->node_type});
GELOGE(INTERNAL_ERROR, "Failed to recover trans node for variable %s, index %d, type %s", GELOGE(INTERNAL_ERROR, "Failed to recover trans node for variable %s, index %d, type %s",
var->GetName().c_str(), index, iter->node_type.c_str()); var->GetName().c_str(), index, iter->node_type.c_str());
return INTERNAL_ERROR; return INTERNAL_ERROR;
@@ -570,6 +578,8 @@ Status CheckIfDynamicBatchScene(NodePtr &data_node, bool &is_dynamic_batch, Node
std::string related_node_name; std::string related_node_name;
if (AttrUtils::GetStr(data_node->GetOpDesc(), kMbatchSwitchnName, related_node_name)) { if (AttrUtils::GetStr(data_node->GetOpDesc(), kMbatchSwitchnName, related_node_name)) {
if (related_node_name.empty()) { if (related_node_name.empty()) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E15002", {"opname", "value", "reason"}, {data_node->GetName(), "flag", "but the value is empty"});
GELOGE(INTERNAL_ERROR, "The data node %s has switchn node flag, but the value is empty", GELOGE(INTERNAL_ERROR, "The data node %s has switchn node flag, but the value is empty",
data_node->GetName().c_str()); data_node->GetName().c_str());
return INTERNAL_ERROR; return INTERNAL_ERROR;
@@ -581,6 +591,9 @@ Status CheckIfDynamicBatchScene(NodePtr &data_node, bool &is_dynamic_batch, Node
} }
} }
if (switchn_node == nullptr) { if (switchn_node == nullptr) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E15002", {"opname", "value", "reason"},
{data_node->GetName(), related_node_name, "but the value is empty"});
GELOGE(INTERNAL_ERROR, "The data node %s has switchn node %s, but can not find it on the graph", GELOGE(INTERNAL_ERROR, "The data node %s has switchn node %s, but can not find it on the graph",
data_node->GetName().c_str(), related_node_name.c_str()); data_node->GetName().c_str(), related_node_name.c_str());
return INTERNAL_ERROR; return INTERNAL_ERROR;
@@ -681,6 +694,10 @@ Status ProcessInputNC1HWC0DynShape(NodePtr &node_ptr, bool &is_dynamic_batch, No
ge::GeShape old_shape = input->GetShape(); ge::GeShape old_shape = input->GetShape();
bool support = ((old_format == FORMAT_NC1HWC0) || (old_format == FORMAT_NCHW) || (old_format == FORMAT_NHWC)); bool support = ((old_format == FORMAT_NC1HWC0) || (old_format == FORMAT_NCHW) || (old_format == FORMAT_NHWC));
if (!support) { if (!support) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E19014", {"opname", "value", "reason"},
{op_desc->GetName(), "format[" + TypeUtils::FormatToSerialString(old_format) + "]",
"only support FORMAT_NC1HWC0,FORMAT_NCHW,FORMAT_NHWC"});
GELOGE(INTERNAL_ERROR, "The format [%s] is unsupported", TypeUtils::FormatToSerialString(old_format).c_str()); GELOGE(INTERNAL_ERROR, "The format [%s] is unsupported", TypeUtils::FormatToSerialString(old_format).c_str());
return FAILED; return FAILED;
} }
@@ -761,6 +778,9 @@ Status GetStorageFormatAndShape(OpDescPtr &op_desc, const GeTensorDescPtr &tenso
op_desc->GetName().c_str(), TypeUtils::FormatToSerialString(storage_format).c_str(), op_desc->GetName().c_str(), TypeUtils::FormatToSerialString(storage_format).c_str(),
formats::JoinToString(storage_shape).c_str()); formats::JoinToString(storage_shape).c_str());
} else { } else {
ErrorManager::GetInstance().ATCReportErrMessage(
"15003", {"opname", "format"},
{op_desc->GetName(), TypeUtils::FormatToSerialString(storage_format)});
GELOGE(PARAM_INVALID, "Update node by storage format failed, storage_shape not set. " GELOGE(PARAM_INVALID, "Update node by storage format failed, storage_shape not set. "
"node: [%s], storage_format [%s]", "node: [%s], storage_format [%s]",
op_desc->GetName().c_str(), TypeUtils::FormatToSerialString(storage_format).c_str()); op_desc->GetName().c_str(), TypeUtils::FormatToSerialString(storage_format).c_str());
@@ -899,9 +919,14 @@ Status ProcessNetoutputNodeDynShape(NodePtr &node) {
// check if is_output_adjust_hw_layout is set // check if is_output_adjust_hw_layout is set
if (NeedUpdateFormatByOutputTypeParm(op_desc, index)) { if (NeedUpdateFormatByOutputTypeParm(op_desc, index)) {
if ((old_format != FORMAT_NCHW) && (old_format != FORMAT_NHWC) && (old_format != FORMAT_NC1HWC0)) { if ((old_format != FORMAT_NCHW) && (old_format != FORMAT_NHWC) && (old_format != FORMAT_NC1HWC0)) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E19014", {"opname", "value", "reason"},
{op_desc->GetName(), "format[" + TypeUtils::FormatToSerialString(old_format) + "]",
"only support FORMAT_NC1HWC0,FORMAT_NCHW,FORMAT_NHWC"});
GELOGE(INTERNAL_ERROR, "Format is not one of NCHW, NHWC, NC1HWC0."); GELOGE(INTERNAL_ERROR, "Format is not one of NCHW, NHWC, NC1HWC0.");
return FAILED; return FAILED;
} }

GeTensorDesc old_desc(old_shape, old_format, old_dtype); GeTensorDesc old_desc(old_shape, old_format, old_dtype);
if (ProcessNetoutputNodeFp16Nc1hwc0DynShape(old_desc, net_output_input_desc, src_node) != SUCCESS) { if (ProcessNetoutputNodeFp16Nc1hwc0DynShape(old_desc, net_output_input_desc, src_node) != SUCCESS) {
GELOGE(INTERNAL_ERROR, "Process netoutput fp16 nc1hwc0."); GELOGE(INTERNAL_ERROR, "Process netoutput fp16 nc1hwc0.");
@@ -1034,6 +1059,10 @@ Status GraphPrepare::CheckRefInputNode(const NodePtr &node, const std::string &i
} }
bool is_acceptable = (acceptable_types.find(input_type) != acceptable_types.end()); bool is_acceptable = (acceptable_types.find(input_type) != acceptable_types.end());
if (!is_acceptable) { if (!is_acceptable) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E19014", {"opname", "value", "reason"},
{op_desc->GetName(), "format[" + TypeUtils::FormatToSerialString(old_format) + "]",
"only support FORMAT_NC1HWC0,FORMAT_NCHW,FORMAT_NHWC"});
GELOGE(PARAM_INVALID, "The ref input of ref node %s[%s] must be ref node or variable, but %s[%s]isn't.", GELOGE(PARAM_INVALID, "The ref input of ref node %s[%s] must be ref node or variable, but %s[%s]isn't.",
node->GetName().c_str(), node->GetType().c_str(), input_op_desc->GetName().c_str(), node->GetName().c_str(), node->GetType().c_str(), input_op_desc->GetName().c_str(),
input_op_desc->GetType().c_str()); input_op_desc->GetType().c_str());


+ 2
- 1
ge/host_kernels/slice_kernel.cc View File

@@ -99,9 +99,10 @@ Status SliceKernel::Compute(const OpDescPtr attr, const std::vector<ConstGeTenso
stride_vec.push_back(1); stride_vec.push_back(1);
} }
// construct tensorDesc // construct tensorDesc
ge::GeShape output_shape(output_dims);
auto attr_output_tensor_desc = attr->GetOutputDesc(0); auto attr_output_tensor_desc = attr->GetOutputDesc(0);
GeTensorDesc output_tensor_desc(attr_output_tensor_desc); GeTensorDesc output_tensor_desc(attr_output_tensor_desc);
output_tensor_desc.SetShape(output_dims);
output_tensor_desc.SetShape(output_shape);
GeTensorPtr output_ptr = MakeShared<GeTensor>(output_tensor_desc); GeTensorPtr output_ptr = MakeShared<GeTensor>(output_tensor_desc);
if (output_ptr == nullptr) { if (output_ptr == nullptr) {
GELOGW("make_shared ge::GeTensor failed, node name %s.", attr->GetName().c_str()); GELOGW("make_shared ge::GeTensor failed, node name %s.", attr->GetName().c_str());


Loading…
Cancel
Save