Browse Source

Pre Merge pull request !1911 from lichun/master

pull/1911/MERGE
lichun Gitee 4 years ago
parent
commit
01e5ab4c3f
2 changed files with 40 additions and 6 deletions
  1. +20
    -6
      ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_builder.cc
  2. +20
    -0
      ge/graph/passes/reshape_remove_pass.cc

+ 20
- 6
ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_builder.cc View File

@@ -22,9 +22,11 @@
#include "graph/utils/node_utils.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/type_utils.h"
#include "graph/compute_graph.h"
#include "ge_local_engine/ops_kernel_store/op/op_factory.h"
#include "ge_local_engine/common/constant/constant.h"
#include "register/ops_kernel_builder_registry.h"
#include "framework/common/debug/log.h"

namespace ge {
namespace ge_local {
@@ -34,6 +36,7 @@ namespace {
const char *const kConstantOpType = "Constant";
const char *const kConstantOpAttrName = "value";
const char *const kDataOpType = "Data";
const char *const ATTR_NAME_FORCE_UNKNOWN_SHAPE = "_force_unknown_shape";
} // namespace

GeLocalOpsKernelBuilder::~GeLocalOpsKernelBuilder() {
@@ -161,13 +164,24 @@ Status GeLocalOpsKernelBuilder::CalcConstantStrMemSize(const OpDescPtr &op_desc,
}

Status GeLocalOpsKernelBuilder::GenerateTask(const Node &node, RunContext &context, std::vector<domi::TaskDef> &tasks) {
bool is_shape_unknown = false;
if (NodeUtils::GetNodeUnknownShapeStatus(node, is_shape_unknown) == GRAPH_SUCCESS) {
if (is_shape_unknown) {
GELOGI("op:%s is unknown shape, does not need to generate task",
node.GetName().c_str());
return SUCCESS;
bool is_in_unknown_subgraph = false;
bool forced_unknown = false;
for (const auto &node : node.GetOwnerComputeGraph()->GetDirectNode()) {
GE_CHK_GRAPH_STATUS_RET(ge::NodeUtils::GetNodeUnknownShapeStatus(*node, is_in_unknown_subgraph),
"[Get][ShapeStatus] of node[%s] failed!", node->GetName().c_str());
if (is_in_unknown_subgraph) {
break;
}
if (AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_FORCE_UNKNOWN_SHAPE, forced_unknown) && forced_unknown) {
GELOGD("node %s was marked as unknown shape.", node->GetName().c_str());
is_in_unknown_subgraph = true;
break;
}
}
if (is_in_unknown_subgraph) {
GELOGI("op:%s is in unknown shape subgraph, does not need to generate task",
node.GetName().c_str());
return SUCCESS;
}
string name = node.GetName();
string type = node.GetType();


+ 20
- 0
ge/graph/passes/reshape_remove_pass.cc View File

@@ -27,6 +27,7 @@
namespace ge {
namespace {
const int kReshapeDataIndex = 0;
const char* const ATTR_NAME_FORCE_UNKNOWN_SHAPE = "_force_unknown_shape";
enum OpHashValue {
kReshapeType = 0,
kReformatType = 1,
@@ -45,6 +46,25 @@ Status ReshapeRemovePass::Run(NodePtr &node) {
int key = kToBeDeleteOp.find(node->GetType()) == kToBeDeleteOp.end() ? kOpNoDelete : kToBeDeleteOp[node->GetType()];
switch (key) {
case kReshapeType: {
bool is_in_unknown_shape_graph = false;
bool forced_unknown = false;
for (const auto &node : node->GetOwnerComputeGraph()->GetDirectNode()) {
GE_CHK_GRAPH_STATUS_RET(ge::NodeUtils::GetNodeUnknownShapeStatus(*node, is_in_unknown_shape_graph),
"[Get][ShapeStatus] of node[%s] failed!", node->GetName().c_str());
if (is_in_unknown_shape_graph) {
break;
}
if (AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_FORCE_UNKNOWN_SHAPE, forced_unknown) && forced_unknown) {
GELOGD("node %s was marked as unknown shape.", node->GetName().c_str());
is_in_unknown_shape_graph = true;
break;
}
}
if (is_in_unknown_shape_graph) {
GELOGI("op:%s is in unknown shape graph, can not be deleted.", node->GetName().c_str());
return SUCCESS;
}

bool is_shape_unknown = false;
if (NodeUtils::GetNodeUnknownShapeStatus(*node, is_shape_unknown) == GRAPH_SUCCESS) {
if (is_shape_unknown) {


Loading…
Cancel
Save