| @@ -15,29 +15,45 @@ | |||||
| */ | */ | ||||
| #include "graph/passes/reshape_remove_pass.h" | #include "graph/passes/reshape_remove_pass.h" | ||||
| #include <map> | |||||
| #include <string> | |||||
| #include "framework/common/util.h" | #include "framework/common/util.h" | ||||
| #include "framework/common/types.h" | |||||
| #include "graph/passes/pass_utils.h" | #include "graph/passes/pass_utils.h" | ||||
| #include "graph/utils/node_utils.h" | #include "graph/utils/node_utils.h" | ||||
| namespace ge { | namespace ge { | ||||
| namespace { | namespace { | ||||
| const int kReshapeDataIndex = 0; | const int kReshapeDataIndex = 0; | ||||
| const int kReshapeType = 0; | |||||
| const int kReformatType = 1; | |||||
| std::map<const char *, int> kOpTypeHash = { | |||||
| {RESHAPE, kReshapeType}, | |||||
| {REFORMAT, kReformatType} | |||||
| }; | |||||
| } | } | ||||
| Status ReshapeRemovePass::Run(NodePtr &node) { | Status ReshapeRemovePass::Run(NodePtr &node) { | ||||
| GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
| GE_CHECK_NOTNULL(node->GetOpDesc()); | GE_CHECK_NOTNULL(node->GetOpDesc()); | ||||
| if (node->GetType() != RESHAPE && node->GetType() != REFORMAT) { | |||||
| return SUCCESS; | |||||
| } | |||||
| bool is_shape_unknown = false; | |||||
| if (NodeUtils::GetNodeUnknownShapeStatus(*node, is_shape_unknown) == GRAPH_SUCCESS) { | |||||
| if (is_shape_unknown) { | |||||
| GELOGI("op:%s is unknown shape, can not be deleted.", | |||||
| node->GetName().c_str()); | |||||
| switch(kOpTypeHash.find(node->GetType())) { | |||||
| case kReshapeType: | |||||
| bool is_shape_unknown = false; | |||||
| if (NodeUtils::GetNodeUnknownShapeStatus(*node, is_shape_unknown) == GRAPH_SUCCESS) { | |||||
| if (is_shape_unknown) { | |||||
| GELOGI("op:%s is unknown shape, can not be deleted.", | |||||
| node->GetName().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| } | |||||
| break; | |||||
| case kReformatType: | |||||
| break; | |||||
| default: | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | |||||
| } | } | ||||
| GELOGI("Remove %s node %s", node->GetType().c_str(), node->GetName().c_str()); | GELOGI("Remove %s node %s", node->GetType().c_str(), node->GetName().c_str()); | ||||