Browse Source

!1976 Feature: V1 control flow infershape support infer multi-times

Merge pull request !1976 from zhaoxinxin/master
tags/v1.5.1
i-robot Gitee 3 years ago
parent
commit
4722b72fb4
9 changed files with 2187 additions and 1292 deletions
  1. +475
    -374
      ge/graph/passes/base_pass.cc
  2. +101
    -20
      ge/graph/passes/base_pass.h
  3. +3
    -0
      ge/graph/passes/infer_base_pass.cc
  4. +370
    -175
      ge/graph/passes/infershape_pass.cc
  5. +56
    -38
      ge/graph/passes/infershape_pass.h
  6. +16
    -0
      ge/graph/preprocess/graph_preprocess.cc
  7. +1
    -1
      tests/ut/ge/graph/passes/addn_pass_unittest.cc
  8. +903
    -523
      tests/ut/ge/graph/passes/base_pass_unittest.cc
  9. +262
    -161
      tests/ut/ge/graph/passes/infershape_pass_unittest.cc

+ 475
- 374
ge/graph/passes/base_pass.cc View File

@@ -1,374 +1,475 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "graph/passes/base_pass.h"

#include <queue>
#include <unordered_set>

#include "framework/common/debug/log.h"
#include "framework/common/debug/ge_log.h"
#include "graph/compute_graph.h"
#include "graph/utils/graph_utils.h"

namespace ge {
namespace {
constexpr int kMaxRePassTimes = 10000;
constexpr size_t kMaxOneInNodes = 1000;
// Each iteration, we take about 0.3k memory on the stack, we should change the recursion to loop later
constexpr int kMaxRecursiveDepth = 20;
struct DuringPassNodeSets {
std::unordered_set<Node *> nodes_seen;
std::unordered_set<NodePtr> nodes_deleted;
std::unordered_set<NodePtr> nodes_re_pass;
std::unordered_set<NodePtr> nodes_re_pass_immediately;
std::unordered_set<NodePtr> nodes_last;
std::unordered_set<NodePtr> nodes_suspend;
std::unordered_set<NodePtr> nodes_resume;
};

// Seed the pass queue with the graph's entry points.
// Nodes with zero inputs go straight into `input_edge_nodes` and are marked
// seen; nodes with a very large fan-in (> kMaxOneInNodes) are parked in
// `nodes_last` so they are only processed after the rest of the graph.
void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::deque<NodePtr> &input_edge_nodes,
                            std::unordered_set<Node *> &nodes_seen, std::unordered_set<NodePtr> &nodes_last) {
  nodes_last.clear();
  for (const auto &candidate : graph->GetDirectNode()) {
    if (candidate == nullptr) {
      continue;
    }
    const size_t fan_in = candidate->GetInNodes().size();
    if (fan_in == 0) {
      input_edge_nodes.push_back(candidate);
      nodes_seen.insert(candidate.get());
      continue;
    }
    if (fan_in > kMaxOneInNodes) {
      nodes_last.insert(candidate);
    }
  }
}

// True iff none of the given input nodes is currently suspended.
bool IsAllInNodesAlive(const Node::Vistor<NodePtr> &nodes, const std::unordered_set<NodePtr> &nodes_suspend) {
  return std::none_of(nodes.begin(), nodes.end(),
                      [&nodes_suspend](const NodePtr &n) { return nodes_suspend.count(n) > 0; });
}

// Push eligible successor nodes onto the traversal queue.
// A node is enqueued only when it is not in the deferred `nodes_last` set,
// it is not suspended, all of its input nodes have been seen, none of its
// input nodes is suspended, and it has not been queued before.
void AddNextIterNodes(const Node::Vistor<NodePtr> &nodes, std::deque<NodePtr> &nodes_to_pass,
DuringPassNodeSets &during_pass_node_set) {
auto &nodes_seen = during_pass_node_set.nodes_seen;
const auto &nodes_last = during_pass_node_set.nodes_last;
const auto &nodes_suspend = during_pass_node_set.nodes_suspend;
for (auto &node : nodes) {
if (node == nullptr) {
continue;
}
// High fan-in nodes are deliberately deferred; they are flushed after the
// main traversal loop in RunPassesOneGraph.
if (nodes_last.count(node) != 0) {
continue;
}
if (nodes_suspend.count(node) > 0) {
GELOGD("The node %s has suspend by pass, skip it.", node->GetName().c_str());
continue;
}

bool all_in_nodes_alive = IsAllInNodesAlive(node->GetInAllNodes(), nodes_suspend);
bool all_in_nodes_seen = node->IsAllInNodesSeen(nodes_seen);
// insert().second is true only on the first insertion, so each node is
// queued at most once even if reached via several predecessors.
if (all_in_nodes_seen && all_in_nodes_alive && nodes_seen.insert(node.get()).second) {
nodes_to_pass.push_back(node);
}
}
}

// Move every node scheduled for immediate re-pass to the FRONT of the queue,
// so it is revisited before any regular successor, then reset the set.
void AddRepassNodes(DuringPassNodeSets &during_pass_node_set, std::deque<NodePtr> &nodes) {
  auto &immediate_nodes = during_pass_node_set.nodes_re_pass_immediately;
  for (const auto &repass_node : immediate_nodes) {
    GELOGD("The node %s will be re-pass immediately.", repass_node->GetName().c_str());
    nodes.push_front(repass_node);
  }
  immediate_nodes.clear();
}

// Re-queue nodes that a pass resumed: a node is appended to the queue only if
// it really was suspended; otherwise the resume request is dropped with a
// warning. The resume set is cleared afterwards.
void AddResumeNodes(DuringPassNodeSets &during_pass_node_set, std::deque<NodePtr> &nodes) {
  for (auto &node : during_pass_node_set.nodes_resume) {
    // erase() returns the number of removed elements, so a single hash lookup
    // both tests membership and removes the node (the original did a find()
    // followed by a second erase() lookup).
    if (during_pass_node_set.nodes_suspend.erase(node) > 0) {
      GELOGD("The node %s resumed by pass.", node->GetName().c_str());
      nodes.push_back(node);
    } else {
      GELOGW("The node %s not suspend, drop from resumed", node->GetName().c_str());
    }
  }
  during_pass_node_set.nodes_resume.clear();
}

// Record the suspend/resume requests collected from one pass run into the
// graph-wide bookkeeping sets, logging which pass issued each request.
void PushToSuspendNodes(DuringPassNodeSets &during_pass_node_set, const std::string &pass_name,
                        const std::unordered_set<NodePtr> &nodes_suspend,
                        const std::unordered_set<NodePtr> &nodes_resume) {
  for (const auto &suspend_node : nodes_suspend) {
    GELOGD("The iteration suspend of node %s has been set by pass %s", suspend_node->GetName().c_str(),
           pass_name.c_str());
    during_pass_node_set.nodes_suspend.emplace(suspend_node);
  }
  for (const auto &resume_node : nodes_resume) {
    GELOGD("The iteration suspend of node %s has been resumed by pass %s", resume_node->GetName().c_str(),
           pass_name.c_str());
    during_pass_node_set.nodes_resume.emplace(resume_node);
  }
}

// Filter the re-pass candidates reported by one pass and record the eligible
// ones into `nodes_re_pass`. A candidate is eligible when it has already been
// seen, or when all of its input nodes have been seen (i.e. it would be
// reached by the normal traversal anyway). `node` and `name_to_pass` are used
// for diagnostics only.
void PushToRePassIfSeen(NodePtr &node, const std::pair<std::string, BaseNodePass *> &name_to_pass,
std::unordered_set<Node *> &nodes_seen, const std::unordered_set<NodePtr> &nodes_to_re_pass,
std::unordered_set<NodePtr> &nodes_re_pass) {
for (const auto &node_to_re_pass : nodes_to_re_pass) {
if (node_to_re_pass == nullptr) {
GELOGW("Found null re-pass node when executing %s on node %s type %s", name_to_pass.first.c_str(),
node->GetName().c_str(), node->GetType().c_str());
continue;
}
if (nodes_seen.count(node_to_re_pass.get()) > 0 || node_to_re_pass->IsAllInNodesSeen(nodes_seen)) {
GELOGD("The node %s will be re-pass.", node_to_re_pass->GetName().c_str());
nodes_re_pass.insert(node_to_re_pass);
} else {
// Not all inputs seen yet: the node will be visited later by the normal
// traversal, so there is no need to schedule a re-pass now.
GELOGD("The node %s are not all seen, don't set repass this time", node_to_re_pass->GetName().c_str());
}
}
}

// Run every pass in `names_to_passes` on `node`, in order.
// After each pass runs, its bookkeeping output (re-pass requests, immediate
// re-pass requests, suspend/resume requests, deleted nodes) is merged into
// `during_pass_node_set`. If a pass deletes `node` itself, the remaining
// passes are skipped. Returns FAILED for a null node, or the first failing
// pass's status.
Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, DuringPassNodeSets &during_pass_node_set) {
if (node == nullptr) {
REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid.");
GELOGE(FAILED, "[Check][Param] parameter node is nullptr.");
return FAILED;
}
GELOGD("Begin to run pass for node %s", node->GetName().c_str());
for (const auto &name_to_pass : names_to_passes) {
if (name_to_pass.second == nullptr) {
GELOGE(INTERNAL_ERROR, "[Check][Param] There is null pointer in passes(%s), skip it", name_to_pass.first.c_str());
continue;
}

GELOGD("Begin to run pass %s for node %s", name_to_pass.first.c_str(), node->GetName().c_str());
// init() clears the pass's per-run result sets before Run().
name_to_pass.second->init();
auto result = name_to_pass.second->Run(node);
if (result != SUCCESS) {
REPORT_CALL_ERROR("E19999", "process pass %s on node:%s failed, ret:%u",
name_to_pass.first.c_str(), node->GetName().c_str(), result);
GELOGE(INTERNAL_ERROR, "[Process][Pass] %s on node %s failed, result "
"%u, the passes will be terminated immediately.",
name_to_pass.first.c_str(), node->GetName().c_str(), result);
return result;
}

// Nodes the pass wants re-passed in a later iteration.
const auto &nodes_to_re_pass = name_to_pass.second->GetNodesNeedRePass();
PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass,
during_pass_node_set.nodes_re_pass);

// Nodes the pass wants re-passed before the traversal continues.
const auto &nodes_to_re_pass_immediately = name_to_pass.second->GetNodesNeedRePassImmediately();
PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass_immediately,
during_pass_node_set.nodes_re_pass_immediately);

PushToSuspendNodes(during_pass_node_set, name_to_pass.first,
name_to_pass.second->GetNodesSuspend(), name_to_pass.second->GetNodesResume());

// Merge deletions; if this very node was deleted, later passes must not
// touch it.
const auto &nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted();
during_pass_node_set.nodes_deleted.insert(nodes_deleted_by_pass.begin(), nodes_deleted_by_pass.end());
if (nodes_deleted_by_pass.count(node) > 0) {
GELOGD("The node %s was deleted by pass %s, stop the remain passes", node->GetName().c_str(),
name_to_pass.first.c_str());
break;
}
}

return SUCCESS;
}

void SetFlagOption(NodePassOption option, NamesToPass names_to_pass) {
for (auto &name_to_pass : names_to_pass) {
name_to_pass.second->SetOption(option, "");
}
}

void ClearOption(NamesToPass names_to_pass) {
for (auto &name_to_pass : names_to_pass) {
name_to_pass.second->ClearOptions();
}
}

// Decide whether a node popped from the queue should actually be processed.
// Rejects null nodes, nodes already deleted by an earlier pass, and nodes
// whose iteration has been suspended.
bool CheckNode(const NodePtr &node, const DuringPassNodeSets &during_pass_node_set) {
  if (node == nullptr) {
    GELOGW("node is null");
    return false;
  }
  const bool was_deleted = during_pass_node_set.nodes_deleted.count(node) > 0;
  if (was_deleted) {
    GELOGD("The node %s was deleted before, skip it.", node->GetName().c_str());
    return false;
  }
  const bool is_suspended = during_pass_node_set.nodes_suspend.count(node) > 0;
  if (is_suspended) {
    GELOGD("The node %s has been added to suspend-iteration nodes list, the iteration of it will be suspend.",
           node->GetName().c_str());
    return false;
  }
  return true;
}
} // namespace

// Remove `node` from its owner graph while keeping the graph connected:
// the node's in/out neighbours are first scheduled for re-pass, then the node
// is isolated (its edges relinked as described by `io_map`, which is passed
// through to GraphUtils::IsolateNode) and removed, and finally recorded as
// deleted so the pass framework skips it from now on.
Status BaseNodePass::IsolateAndDeleteNode(NodePtr &node, const std::vector<int> &io_map) {
if (node == nullptr) {
REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid.");
GELOGE(FAILED, "[Check][Param] parameter node is nullptr.");
return FAILED;
}
GELOGI("Prepare to isolate and delete node, name:%s, type:%s.", node->GetName().c_str(),
node->GetType().c_str());
ComputeGraphPtr graph = node->GetOwnerComputeGraph();
if (graph == nullptr) {
REPORT_INNER_ERROR("E19999", "The owner graph of node:%s must not be null.", node->GetName().c_str());
GELOGE(FAILED, "[Get][OwnerComputeGraph] failed, The owner graph of node:%s must not be null.",
node->GetName().c_str());
return FAILED;
}

// Must happen before isolation: afterwards the node has no neighbours left
// to collect.
AddRePassNodesWithInOut(node);

if (GraphUtils::IsolateNode(node, io_map) != GRAPH_SUCCESS) {
REPORT_CALL_ERROR("E19999", "Isolate Node:%s failed", node->GetName().c_str());
GELOGE(FAILED, "[Isolate][Node] %s failed.", node->GetName().c_str());
return FAILED;
}

if (GraphUtils::RemoveNodeWithoutRelink(graph, node) != SUCCESS) {
REPORT_CALL_ERROR("E19999", "call RemoveNodeWithoutRelink for node:%s failed.", node->GetName().c_str());
GELOGE(FAILED, "[Call][RemoveNodeWithoutRelink] for node:%s failed.", node->GetName().c_str());
return FAILED;
}

AddNodeDeleted(node);
return SUCCESS;
}

// Entry point: validate inputs then run all passes on graph_.
// Fails fast when the graph is null, when no passes were given, or when the
// subgraph nesting depth exceeds kMaxRecursiveDepth (each nesting level
// recurses through RunPassesOnSubGraph, which costs stack space).
Status GEPass::Run(const NamesToPass &names_to_passes) {
if (graph_ == nullptr) {
REPORT_INNER_ERROR("E19999", "graph_ is nullptr, check invalid.");
GELOGE(INTERNAL_ERROR, "[Check][Param] The graph is nullptr");
return INTERNAL_ERROR;
}
if (names_to_passes.empty()) {
GELOGW("No passes input, the GEPass will do nothing");
return INTERNAL_ERROR;
}

if (depth_ > kMaxRecursiveDepth) {
GELOGE(PARAM_INVALID,
"[Check][Param] The pass for root graph %s will be terminated because too many nesting"
" levels(%d) of subgraphs, last subgraph is %s",
root_graph_->GetName().c_str(), depth_, graph_->GetName().c_str());
return PARAM_INVALID;
}

return RunPassesOneGraph(names_to_passes);
}

// Drive the pass traversal over one graph.
// Outer do-while: re-queues re-pass nodes collected during the previous round
// and repeats until no re-pass nodes and no queued nodes remain (bounded by
// kMaxRePassTimes). Inner while: pops nodes, enqueues their ready successors,
// runs all passes on the node (twice when it owns subgraphs), then applies
// immediate-re-pass and resume requests. Deferred high fan-in nodes
// (nodes_last) are flushed once per outer round.
Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) {
GELOGD("Begin to run pass on graph, passes count %zu", names_to_passes.size());
std::deque<NodePtr> nodes;
DuringPassNodeSets during_pass_node_set;
GetAllNodesNoInputEdge(graph_, nodes, during_pass_node_set.nodes_seen, during_pass_node_set.nodes_last);
GELOGD("Start points count %zu", nodes.size());
int re_pass_times = 0;

do {
// Re-queue the nodes marked for re-pass in the previous round; they are
// marked seen so successors can become ready.
for (auto &node : during_pass_node_set.nodes_re_pass) {
nodes.push_back(node);
during_pass_node_set.nodes_seen.insert(node.get());
}
during_pass_node_set.nodes_re_pass.clear();

while (!nodes.empty()) {
NodePtr node = nodes.front();
nodes.pop_front();

// If this node was scheduled for re-pass, the pending request is now
// satisfied by this visit.
(void)during_pass_node_set.nodes_re_pass.erase(node);
if (!CheckNode(node, during_pass_node_set)) {
continue;
}
// Successors are enqueued BEFORE the passes run on this node.
AddNextIterNodes(node->GetOutNodes(), nodes, during_pass_node_set);

auto ret = RunPasses(node, names_to_passes, during_pass_node_set);
if (ret != SUCCESS) {
GELOGE(ret, "[Process][Passes] on node %s type %s failed, error code:%u",
node->GetName().c_str(), node->GetType().c_str(), ret);
return ret;
}

// Recurse into any subgraphs owned by this node, then run the passes a
// second time with the kOptimizeAfterSubGraph option set.
bool has_sub_graph = false;
ret = RunPassesOnSubGraph(node, names_to_passes, has_sub_graph);
if (ret != SUCCESS) {
GELOGE(ret, "[Run][Passes] on the sub graph of node %s failed", node->GetName().c_str());
return ret;
}

if (has_sub_graph) {
GELOGD("There are subgraphs on node %s, run passes for for the second time", node->GetName().c_str());
SetFlagOption(kOptimizeAfterSubGraph, names_to_passes);
ret = RunPasses(node, names_to_passes, during_pass_node_set);
if (ret != SUCCESS) {
GELOGE(ret, "[Process][Passes] on node %s type %s failed, error code: %u",
node->GetName().c_str(), node->GetType().c_str(), ret);
return ret;
}

// There is only one option scene, so set and clear options around the `RunPasses` func.
// if there are more than one scene to set options, the `ClearOption` function
// should be called each time at the begin of the iteration
ClearOption(names_to_passes);
}

AddRepassNodes(during_pass_node_set, nodes);
AddResumeNodes(during_pass_node_set, nodes);
}

// Flush the deferred high fan-in nodes once their inputs have been seen.
for (auto &node : during_pass_node_set.nodes_last) {
bool all_in_nodes_seen = node->IsAllInNodesSeen(during_pass_node_set.nodes_seen);
if (all_in_nodes_seen && during_pass_node_set.nodes_seen.insert(node.get()).second) {
nodes.push_back(node);
}
}
during_pass_node_set.nodes_last.clear();
} while ((!during_pass_node_set.nodes_re_pass.empty() || !nodes.empty()) && ++re_pass_times < kMaxRePassTimes);

if (re_pass_times == kMaxRePassTimes) {
GELOGW("re_pass_times should not come to %d", kMaxRePassTimes);
}
GELOGD("All passes runs end");

return SUCCESS;
}
// Run the same pass list on every subgraph instance owned by `node`,
// recursing with depth_ + 1 so GEPass::Run can bound the nesting.
// `has_sub_graph` is set to true iff at least one subgraph was found (missing
// subgraph names are logged and skipped, and do not set the flag).
Status GEPass::RunPassesOnSubGraph(const NodePtr &node, const NamesToPass &names_to_passes, bool &has_sub_graph) {
auto sub_graph_names = node->GetOpDesc()->GetSubgraphInstanceNames();
has_sub_graph = false;
for (const auto &name : sub_graph_names) {
auto graph = root_graph_->GetSubgraph(name);
if (graph == nullptr) {
GELOGW("Can not find the sub graph %s from node %s, the pass-process will skip it",
name.c_str(), node->GetName().c_str());
continue;
}
has_sub_graph = true;
GELOGI("Begin to run passes on the sub graph %s of node %s", name.c_str(), node->GetName().c_str());
GEPass pass(graph, root_graph_, depth_ + 1);
auto ret = pass.Run(names_to_passes);
if (ret != SUCCESS) {
GELOGE(ret, "[Run][Passes] for sub graph:%s from node:%s failed", name.c_str(), node->GetName().c_str());
return ret;
}
}
return SUCCESS;
}
} // namespace ge
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "graph/passes/base_pass.h"
#include <queue>
#include <unordered_set>
#include "common/debug/log.h"
#include "graph/utils/graph_utils.h"
namespace ge {
namespace {
constexpr int kMaxRePassTimes = 10000;
constexpr size_t kMaxOneInNodes = 1000;
// Each iteration, we take about 0.3k memory on the stack, we should change the recursion to loop later
constexpr int kMaxRecursiveDepth = 20;
void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph,
GEPass::GraphLevelState &g_state) {
for (auto &node : graph->GetDirectNode()) {
if (node == nullptr) {
continue;
}
size_t in_nums = node->GetInNodes().size();
if (in_nums == 0) {
g_state.AddNodeToQueueIfNotSeen(node);
} else if (in_nums > kMaxOneInNodes) {
g_state.nodes_last.insert(node);
}
}
}
// True iff at least one of `nodes` is contained in `nodes_set`.
bool AnyNodesIn(const Node::Vistor<NodePtr> &nodes, const std::unordered_set<NodePtr> &nodes_set) {
  for (const auto &candidate : nodes) {
    if (nodes_set.count(candidate) > 0) {
      return true;
    }
  }
  return false;
}
// Whether `node` may be pushed onto the pass queue right now.
// Rejected when it is: null, already deleted, deferred to nodes_last, not yet
// reachable (not all inputs seen), suspended, or fed by a suspended node.
bool IsNodeReadyToQueue(const NodePtr &node, GEPass::GraphLevelState &g_state) {
if (node == nullptr) {
GELOGW("node is null");
return false;
}
if (g_state.nodes_deleted.count(node) > 0) {
GELOGD("The node %s was deleted before, skip it.", node->GetName().c_str());
return false;
}
// High fan-in nodes are deliberately deferred; AddLastNodesToQueue flushes
// them after the main traversal.
if (g_state.nodes_last.count(node) != 0) {
return false;
}
// all in_node seen && all in_node not suspend
if (!node->IsAllInNodesSeen(g_state.nodes_seen)) {
return false;
}
if (g_state.nodes_suspend.count(node) > 0) {
GELOGD("The node %s has been added to suspend-iteration nodes list, the iteration of it will be suspend.",
node->GetName().c_str());
return false;
}
// A suspended input suspends this node's iteration as well.
if (AnyNodesIn(node->GetInAllNodes(), g_state.nodes_suspend)) {
GELOGD("The node %s has been added to suspend-iteration nodes list, the iteration of it will be suspend.",
node->GetName().c_str());
return false;
}
return true;
}
// After `cur_node` was passed, enqueue its ready successors.
// `out_nodes_before_pass` is a snapshot of cur_node's outputs taken before
// the passes ran; consuming it here distinguishes three situations:
//   1. outputs that survived the passes       -> enqueue when ready;
//   2. outputs newly created by the passes    -> logged, enqueued when ready;
//   3. former outputs whose edge from cur_node was unlinked by a pass
//      (e.g. A-->B-->C where A-->B is removed) -> re-queue them when they
//      have no inputs left, otherwise they would never be visited.
void AddNextIterNodes(const NodePtr &cur_node,
                      std::unordered_set<NodePtr> &out_nodes_before_pass,
                      GEPass::GraphLevelState &g_state) {
  for (const auto &out_node : cur_node->GetOutNodes()) {
    if (out_node == nullptr) {
      continue;
    }
    if (out_nodes_before_pass.erase(out_node) == 0) {
      // Not in the pre-pass snapshot: a pass created this output node.
      GELOGI("New output node %s come up after pass %s.",
             out_node->GetName().c_str(), cur_node->GetName().c_str());
    }
    // all in_node seen && all in_node not suspend
    if (IsNodeReadyToQueue(out_node, g_state)) {
      g_state.AddNodeToQueueIfNotSeen(out_node);
    }
  }
  // Whatever remains in the snapshot is no longer an output of cur_node.
  for (const auto &lost_node : out_nodes_before_pass) {
    if (lost_node->GetInAllNodes().empty() && IsNodeReadyToQueue(lost_node, g_state)) {
      // Fix: the original format string had a single %s but was given two
      // arguments; printf-style argument mismatch is undefined behavior.
      GELOGI("Node %s may lost from cur node %s, add to queue if not seen.",
             lost_node->GetName().c_str(), cur_node->GetName().c_str());
      g_state.AddNodeToQueueIfNotSeen(lost_node);
    }
  }
}
void AddImmediateRepassNodesToQueue(NodePtr &cur_node,
std::unordered_map<NodePtr, std::string> re_pass_imm_nodes_to_pass_names,
GEPass::GraphLevelState &g_state) {
for (const auto &node_2_pass_names : re_pass_imm_nodes_to_pass_names) {
auto imme_repass_node = node_2_pass_names.first;
if (imme_repass_node == nullptr) {
GELOGW("Found null immediately re-pass node when executing pass %s on node %s type %s",
node_2_pass_names.second.c_str(),
cur_node->GetName().c_str(), cur_node->GetType().c_str());
continue;
}
if (g_state.nodes_passed.count(imme_repass_node) > 0) {
GELOGD("The node %s specified by pass %s has been passed, it will repass immediately",
imme_repass_node->GetName().c_str(), node_2_pass_names.second.c_str());
g_state.AddNodeToQueueFront(imme_repass_node);
continue;
}
GELOGW("The node %s specified by pass %s has un-passed, it will not repass immediately",
node_2_pass_names.first->GetName().c_str(), node_2_pass_names.second.c_str());
}
}
// Flush the deferred high fan-in nodes: queue each one whose inputs have all
// been seen, then clear the deferred set.
void AddLastNodesToQueue(GEPass::GraphLevelState &g_state) {
  for (const auto &deferred_node : g_state.nodes_last) {
    if (deferred_node->IsAllInNodesSeen(g_state.nodes_seen)) {
      g_state.AddNodeToQueueIfNotSeen(deferred_node);
    }
  }
  g_state.nodes_last.clear();
}
// Re-queue nodes that passes resumed. A node is re-queued only if it was
// actually suspended AND it is reachable: either it was already seen, or all
// of its input nodes were seen.
// Note: base pass does not record the order of suspend & resume within one
// node pass; if a single pass both suspends and resumes a node, it is treated
// as resumed. Recording the order would be the better long-term fix.
// Fix: the map is now taken by const reference; the original signature was a
// by-value `const` parameter, which forced a full copy of the unordered_map
// on every call.
void AddResumeNodesToQueue(const std::unordered_map<NodePtr, std::string> &resume_node_2_pass_names,
                           GEPass::GraphLevelState &g_state) {
  for (const auto &node_2_pass_names : resume_node_2_pass_names) {
    const auto &node = node_2_pass_names.first;
    // erase() > 0 means the node really was suspended.
    if (g_state.nodes_suspend.erase(node) > 0) {
      if (g_state.nodes_seen.count(node.get()) > 0 || node->IsAllInNodesSeen(g_state.nodes_seen)) {
        g_state.nodes.push_back(node);
        GELOGD("Node %s has been resumed by pass %s, and add to pass queue",
               node->GetName().c_str(), node_2_pass_names.second.c_str());
      }
    }
  }
}
// Filter the re-pass candidates reported by one pass and record the eligible
// ones into the repass-level state. A candidate is eligible when it has
// already been seen, or when all of its input nodes have been seen.
// `node` and `name_to_pass` are used for diagnostics only.
void PushToRePassIfSeen(NodePtr &node, const std::pair<std::string, BaseNodePass *> &name_to_pass,
std::unordered_set<Node *> &nodes_seen, const std::vector<NodePtr> &nodes_to_re_pass,
GEPass::RepassLevelState &rp_state) {
for (const auto &node_to_re_pass : nodes_to_re_pass) {
if (node_to_re_pass == nullptr) {
GELOGW("Found null re-pass node when executing %s on node %s type %s", name_to_pass.first.c_str(),
node->GetName().c_str(), node->GetType().c_str());
continue;
}
if (nodes_seen.count(node_to_re_pass.get()) > 0 || node_to_re_pass->IsAllInNodesSeen(nodes_seen)) {
// AddNodeToRepass returns false when the node is already queued, so each
// node appears in the repass queue at most once.
if (rp_state.AddNodeToRepass(node_to_re_pass)) {
GELOGD("The node %s will be re-pass.", node_to_re_pass->GetName().c_str());
continue;
}
GELOGD("Node %s has been added to repass queue, no need to add again.", node_to_re_pass->GetName().c_str());
} else {
// Not all inputs seen yet: the normal traversal will reach the node later.
GELOGD("The node %s are not all seen, don't set repass this time", node_to_re_pass->GetName().c_str());
}
}
}
void SetFlagOption(NodePassOption option, NamesToPass names_to_pass) {
for (auto &name_to_pass : names_to_pass) {
name_to_pass.second->SetOption(option, "");
}
}
void ClearOption(NamesToPass names_to_pass) {
for (auto &name_to_pass : names_to_pass) {
name_to_pass.second->ClearOptions();
}
}
} // namespace
// Remove `node` from its owner graph while keeping the graph connected:
// the node's in/out neighbours are first scheduled for re-pass (immediately
// when `is_repass_io_immediately` is set, otherwise in a later round), then
// the node is isolated (edges relinked as described by `io_map`, passed
// through to GraphUtils::IsolateNode) and removed, and finally recorded as
// deleted so the pass framework skips it from now on.
Status BaseNodePass::IsolateAndDeleteNode(NodePtr &node, const std::vector<int> &io_map,
bool is_repass_io_immediately) {
if (node == nullptr) {
REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid.");
GELOGE(FAILED, "[Check][Param] parameter node is nullptr.");
return FAILED;
}
GELOGI("Prepare to isolate and delete node, name:%s, type:%s.", node->GetName().c_str(),
node->GetType().c_str());
ComputeGraphPtr graph = node->GetOwnerComputeGraph();
if (graph == nullptr) {
REPORT_INNER_ERROR("E19999", "The owner graph of node:%s must not be null.", node->GetName().c_str());
GELOGE(FAILED, "[Get][OwnerComputeGraph] failed, The owner graph of node:%s must not be null.",
node->GetName().c_str());
return FAILED;
}
// Must happen before isolation: afterwards the node has no neighbours left.
is_repass_io_immediately ? AddImmediateRePassNodesWithInOut(node) : AddRePassNodesWithInOut(node);
if (GraphUtils::IsolateNode(node, io_map) != GRAPH_SUCCESS) {
REPORT_CALL_ERROR("E19999", "Isolate Node:%s failed", node->GetName().c_str());
GELOGE(FAILED, "[Isolate][Node] %s failed.", node->GetName().c_str());
return FAILED;
}
if (GraphUtils::RemoveNodeWithoutRelink(graph, node) != SUCCESS) {
REPORT_CALL_ERROR("E19999", "call RemoveNodeWithoutRelink for node:%s failed.", node->GetName().c_str());
GELOGE(FAILED, "[Call][RemoveNodeWithoutRelink] for node:%s failed.", node->GetName().c_str());
return FAILED;
}
AddNodeDeleted(node);
return SUCCESS;
}
// Entry point: validate inputs then run all passes on graph_.
// Fails fast when the graph is null, when no passes were given, when any pass
// pointer is null, or when the subgraph nesting depth exceeds
// kMaxRecursiveDepth (each nesting level recurses via RunPassesOnSubGraph).
Status GEPass::Run(const NamesToPass &names_to_passes) {
if (graph_ == nullptr) {
REPORT_INNER_ERROR("E19999", "graph_ is nullptr, check invalid.");
GELOGE(INTERNAL_ERROR, "[Check][Param] The graph is nullptr");
return INTERNAL_ERROR;
}
if (names_to_passes.empty()) {
GELOGW("No passes input, the GEPass will do nothing");
return INTERNAL_ERROR;
}
// Unlike the older implementation, null passes are rejected up front instead
// of being skipped per node.
for (const auto &name_to_pass : names_to_passes) {
if (name_to_pass.second == nullptr) {
GELOGE(INTERNAL_ERROR, "[Check][Param] There is null pointer in passes(%s)", name_to_pass.first.c_str());
return INTERNAL_ERROR;
}
}
if (depth_ > kMaxRecursiveDepth) {
GELOGE(PARAM_INVALID,
"[Check][Param] The pass for root graph %s will be terminated because too many nesting"
" levels(%d) of subgraphs, last subgraph is %s",
root_graph_->GetName().c_str(), depth_, graph_->GetName().c_str());
return PARAM_INVALID;
}
return RunPassesOneGraph(names_to_passes);
// TODO: when debug mode is on, find the first node in topological order that
// was not passed and emit a warning.
}
void NotifyPassGraphStart(const ComputeGraphPtr &graph, const NamesToPass &names_to_pass) {
for (auto &name_to_pass : names_to_pass) {
name_to_pass.second->OnStartPassGraph(graph);
}
}
// Called when the main traversal finished but some nodes are still suspended
// ("leaked"). Gives every pass a chance (OnSuspendNodesLeaked) to resume
// them; the resumed nodes are collected per node with the comma-joined names
// of the resuming passes, then re-queued into g_state.
Status GEPass::HandleLeakedSuspendNodes(const NamesToPass &names_to_passes, GraphLevelState &g_state) {
std::unordered_map<NodePtr, std::string> resume_nodes_to_pass_names;
for (auto &name_to_pass : names_to_passes) {
// init() clears the pass's result sets so only resumes produced by this
// callback are observed below.
name_to_pass.second->init();
auto ret = name_to_pass.second->OnSuspendNodesLeaked();
if (ret != SUCCESS) {
GELOGE(ret, "Internal error with OnSuspendNodesLeaked on pass %s.", name_to_pass.first.c_str());
return ret;
}
for (const auto &resume_node : name_to_pass.second->GetNodesResume()){
resume_nodes_to_pass_names[resume_node].append(name_to_pass.first + ",");
}
}
AddResumeNodesToQueue(resume_nodes_to_pass_names, g_state);
return SUCCESS;
}
// Drive the pass traversal over one graph.
// Each round runs the full repass loop (RunPassesGraphRepass); if suspended
// nodes are still pending afterwards, the passes are asked to resume them
// (HandleLeakedSuspendNodes) and another round starts. Suspended nodes that
// no pass resumes are an internal error.
Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) {
GELOGD("Begin to run pass on graph, passes count %zu", names_to_passes.size());
NotifyPassGraphStart(graph_, names_to_passes);
GraphLevelState g_state;
g_state.re_pass_times = 0;
GetAllNodesNoInputEdge(graph_, g_state);
GELOGD("Start points count %zu", g_state.nodes.size());
do {
if (!g_state.nodes_suspend.empty()) {
auto ret = HandleLeakedSuspendNodes(names_to_passes, g_state);
if (ret != SUCCESS) {
// log inside upper function
return ret;
}
// Nothing was re-queued: the suspended nodes can never be processed.
if (g_state.nodes.empty()) {
GELOGE(INTERNAL_ERROR, "There are some suspended nodes leaked and no pass resume them.");
return INTERNAL_ERROR;
}
}
auto ret = RunPassesGraphRepass(names_to_passes, g_state);
if (ret != SUCCESS) {
return ret;
}
} while (!g_state.nodes_suspend.empty());
return SUCCESS;
}
// Inner traversal loop: process the node queue, running all passes on each
// node and enqueuing ready successors, then repeat for any re-pass requests
// collected during the round (bounded by kMaxRePassTimes). Deferred high
// fan-in nodes (nodes_last) are flushed once per round.
Status GEPass::RunPassesGraphRepass(const NamesToPass &names_to_passes, GraphLevelState &g_state) {
RepassLevelState rp_state;
do {
// Re-queue the nodes marked for re-pass in the previous round. The set
// membership check guards against nodes whose request was satisfied (or
// cancelled) since being recorded in the ordered list.
for (auto &node : rp_state.nodes_re_pass) {
if (rp_state.nodes_re_pass_set.count(node) > 0) {
GELOGD("Add node %s to queue for re-pass", node->GetName().c_str());
g_state.AddNodeToQueue(node);
}
}
rp_state.ClearRepass();
while (!g_state.nodes.empty()) {
auto node = g_state.PopFront();
if (g_state.nodes_deleted.count(node) > 0) {
GELOGD("The node %s was deleted before, skip it.", node->GetName().c_str());
continue;
}
// Visiting the node satisfies any pending re-pass request for it.
rp_state.EraseNodeFromRepass(node);
g_state.nodes_seen.insert(node.get());
// collect out nodes before pass
std::unordered_set<NodePtr> out_nodes_before_pass;
for (const auto &out_node : node->GetOutNodes()) {
out_nodes_before_pass.insert(out_node);
}
auto ret = RunPassesNodeOnce(node, names_to_passes, g_state, rp_state);
if (ret != SUCCESS) {
GELOGE(ret, "[Process][Passes] on node %s type %s failed, error code:%u", node->GetName().c_str(),
node->GetType().c_str(), ret);
return ret;
}
// Successors are enqueued AFTER the passes run, comparing the pre-pass
// output snapshot against the current outputs to catch new/unlinked nodes.
AddNextIterNodes(node, out_nodes_before_pass, g_state);
}
AddLastNodesToQueue(g_state);
} while ((!rp_state.nodes_re_pass.empty() || !g_state.nodes.empty()) && ++g_state.re_pass_times < kMaxRePassTimes);
if (g_state.re_pass_times == kMaxRePassTimes) {
GELOGW("re_pass_times should not come to %d", kMaxRePassTimes);
}
GELOGD("All passes runs end");
return SUCCESS;
}
// Run the same pass list on every subgraph instance owned by `node`,
// recursing with depth_ + 1 so GEPass::Run can bound the nesting.
// `has_sub_graph` is set to true iff at least one subgraph was found (missing
// subgraph names are logged and skipped, and do not set the flag).
Status GEPass::RunPassesOnSubGraph(const NodePtr &node, const NamesToPass &names_to_passes, bool &has_sub_graph) {
auto sub_graph_names = node->GetOpDesc()->GetSubgraphInstanceNames();
has_sub_graph = false;
for (const auto &name : sub_graph_names) {
auto graph = root_graph_->GetSubgraph(name);
if (graph == nullptr) {
GELOGW("Can not find the sub graph %s from node %s, the pass-process will skip it",
name.c_str(), node->GetName().c_str());
continue;
}
has_sub_graph = true;
GELOGI("Begin to run passes on the sub graph %s of node %s", name.c_str(), node->GetName().c_str());
GEPass pass(graph, root_graph_, depth_ + 1);
auto ret = pass.Run(names_to_passes);
if (ret != SUCCESS) {
GELOGE(ret, "[Run][Passes] for sub graph:%s from node:%s failed", name.c_str(), node->GetName().c_str());
return ret;
}
}
return SUCCESS;
}
// Run one full pass cycle on a single node: all passes once, then recurse
// into any subgraphs the node owns, then (if subgraphs existed) run all
// passes a second time with the kOptimizeAfterSubGraph option set.
Status GEPass::RunPassesNodeOnce(NodePtr &node, const NamesToPass &names_to_passes,
GraphLevelState &g_state, RepassLevelState &rp_state) {
auto ret = RunPassesOnNode(node, names_to_passes, g_state, rp_state);
if (ret != SUCCESS) {
GELOGE(ret, "[Process][Passes] on node %s type %s failed, error code:%u", node->GetName().c_str(),
node->GetType().c_str(), ret);
return ret;
}
bool has_sub_graph = false;
ret = RunPassesOnSubGraph(node, names_to_passes, has_sub_graph);
if (ret != SUCCESS) {
GELOGE(ret, "[Run][Passes] on the sub graph of node %s failed", node->GetName().c_str());
return ret;
}
if (has_sub_graph) {
GELOGD("There are subgraphs on node %s, run passes for for the second time", node->GetName().c_str());
SetFlagOption(kOptimizeAfterSubGraph, names_to_passes);
ret = RunPassesOnNode(node, names_to_passes, g_state, rp_state);
if (ret != SUCCESS) {
GELOGE(ret, "[Process][Passes] on node %s type %s failed, error code: %u", node->GetName().c_str(),
node->GetType().c_str(), ret);
return ret;
}
// There is only one option scene, so set and clear options around the `RunPasses` func.
// if there are more than one scene to set options, the `ClearOption` function
// should be called each time at the begin of the iteration
ClearOption(names_to_passes);
}
return SUCCESS;
}
// Run every pass on `node`, then merge all the passes' bookkeeping output
// (re-pass, immediate re-pass, suspend, resume, delete requests) into the
// graph-level and repass-level state. If a pass deletes the node itself, the
// remaining passes are skipped.
Status GEPass::RunPassesOnNode(NodePtr &node, const NamesToPass &names_to_passes, GraphLevelState &g_state,
RepassLevelState &rp_state) {
if (node == nullptr) {
REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid.");
GELOGE(FAILED, "[Check][Param] parameter node is nullptr.");
return FAILED;
}
GELOGD("Begin to run pass for node %s", node->GetName().c_str());
for (const auto &name_to_pass : names_to_passes) {
GELOGD("Begin to run pass %s for node %s", name_to_pass.first.c_str(), node->GetName().c_str());
// init() clears the pass's per-run result sets before Run().
name_to_pass.second->init();
auto result = name_to_pass.second->Run(node);
if (result != SUCCESS) {
REPORT_CALL_ERROR("E19999", "process pass %s on node:%s failed, ret:%u",
name_to_pass.first.c_str(), node->GetName().c_str(), result);
GELOGE(INTERNAL_ERROR, "[Process][Pass] %s on node %s failed, result "
"%u, the passes will be terminated immediately.",
name_to_pass.first.c_str(), node->GetName().c_str(), result);
return result;
}
if (name_to_pass.second->GetNodesDeleted().count(node) > 0) {
GELOGD("The node %s was deleted by pass %s, stop the remain passes", node->GetName().c_str(),
name_to_pass.first.c_str());
break;
}
}
g_state.nodes_passed.insert(node);
std::unordered_map<NodePtr, std::string> re_pass_imm_nodes_to_pass_names;
std::unordered_map<NodePtr, std::string> resume_nodes_to_pass_names;
// If multiple passes request a re-pass of the same node it would be queued
// many times, so collect the requests first and de-duplicate per node.
for (const auto &name_to_pass : names_to_passes) {
PushToRePassIfSeen(node, name_to_pass, g_state.nodes_seen,
name_to_pass.second->GetNodesNeedRePass(),
rp_state);
// collect imm_node && resume_node among these passes
for (const auto &imm_node : name_to_pass.second->GetNodesNeedRePassImmediately()){
re_pass_imm_nodes_to_pass_names[imm_node].append(name_to_pass.first + ",");
}
for (const auto &resume_node : name_to_pass.second->GetNodesResume()){
resume_nodes_to_pass_names[resume_node].append(name_to_pass.first + ",");
}
for (const auto &suspend_node : name_to_pass.second->GetNodesSuspend()) {
GELOGD("The iteration suspend of node %s has been set by pass %s", suspend_node->GetName().c_str(),
name_to_pass.first.c_str());
g_state.nodes_suspend.insert(suspend_node);
}
const auto &nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted();
g_state.nodes_deleted.insert(nodes_deleted_by_pass.begin(), nodes_deleted_by_pass.end());
}
AddImmediateRepassNodesToQueue(node, re_pass_imm_nodes_to_pass_names, g_state);
AddResumeNodesToQueue(resume_nodes_to_pass_names, g_state);
return SUCCESS;
}
} // namespace ge

+ 101
- 20
ge/graph/passes/base_pass.h View File

@@ -22,7 +22,6 @@
#include <unordered_set> #include <unordered_set>
#include <utility> #include <utility>
#include <vector> #include <vector>

#include "framework/common/ge_inner_error_codes.h" #include "framework/common/ge_inner_error_codes.h"
#include "framework/common/types.h" #include "framework/common/types.h"
#include "graph/compute_graph.h" #include "graph/compute_graph.h"
@@ -40,6 +39,7 @@ enum NodePassOption {
}; };


class BaseNodePass { class BaseNodePass {
// todo comments
public: public:
/// ///
/// Optimize on one node. the function can add nodes to the graph, change /// Optimize on one node. the function can add nodes to the graph, change
@@ -51,7 +51,7 @@ class BaseNodePass {


virtual ~BaseNodePass() = default; virtual ~BaseNodePass() = default;


const std::unordered_set<NodePtr> &GetNodesNeedRePass() { return nodes_need_re_pass_; }
const std::vector<NodePtr> &GetNodesNeedRePass() { return nodes_need_re_pass_; }


const std::unordered_set<NodePtr> &GetNodesNeedRePassImmediately() { return nodes_need_re_pass_immediately_; } const std::unordered_set<NodePtr> &GetNodesNeedRePassImmediately() { return nodes_need_re_pass_immediately_; }


@@ -61,23 +61,32 @@ class BaseNodePass {


const std::unordered_set<NodePtr> &GetNodesResume() { return nodes_resume_; } const std::unordered_set<NodePtr> &GetNodesResume() { return nodes_resume_; }


virtual Status OnSuspendNodesLeaked() { return SUCCESS; }

void SetOption(NodePassOption option, const std::string &value) { options_[option] = value; } void SetOption(NodePassOption option, const std::string &value) { options_[option] = value; }


void ClearOptions() { options_.clear(); } void ClearOptions() { options_.clear(); }


void init() { void init() {
nodes_need_re_pass_.clear(); nodes_need_re_pass_.clear();
nodes_deleted_.clear();
nodes_need_re_pass_immediately_.clear(); nodes_need_re_pass_immediately_.clear();
nodes_deleted_.clear();
nodes_suspend_.clear(); nodes_suspend_.clear();
nodes_resume_.clear(); nodes_resume_.clear();
} }


virtual void OnStartPassGraph(const ComputeGraphPtr &graph) {
current_graph_name_ = graph->GetName();
}

protected: protected:
Status IsolateAndDeleteNode(NodePtr &node, const std::vector<int> &io_map);
const string &GetCurrentGraphName() const {
return current_graph_name_;
}
Status IsolateAndDeleteNode(NodePtr &node, const std::vector<int> &io_map, bool is_repass_io_immediately = false);


Status IsolateAndDeleteNode(NodePtr &node, const std::initializer_list<int> &io_map) {
return IsolateAndDeleteNode(node, std::vector<int>(io_map));
Status IsolateAndDeleteNode(NodePtr &node, const std::initializer_list<int> &io_map, bool is_repass_io_immediately = false) {
return IsolateAndDeleteNode(node, std::vector<int>(io_map), is_repass_io_immediately);
} }


/// ///
@@ -86,7 +95,7 @@ class BaseNodePass {
/// optimized by other passes, call this function. /// optimized by other passes, call this function.
/// @param node /// @param node
/// ///
void AddRePassNode(const NodePtr &node) { nodes_need_re_pass_.insert(node); }
void AddRePassNode(const NodePtr &node) { nodes_need_re_pass_.emplace_back(node); }


/// ///
/// Add a node to be optimized immediately again. If you add a new node to the graph, or /// Add a node to be optimized immediately again. If you add a new node to the graph, or
@@ -101,14 +110,30 @@ class BaseNodePass {
/// @param node /// @param node
/// ///
void AddRePassNodesWithInOut(const NodePtr &node) { void AddRePassNodesWithInOut(const NodePtr &node) {
auto in_nodes = node->GetInNodes();
for (auto &in_node : in_nodes) {
AddRePassNode(in_node);
}
AddRePassNode(node); AddRePassNode(node);
auto out_nodes = node->GetOutNodes(); auto out_nodes = node->GetOutNodes();
for (auto &out_node : out_nodes) { for (auto &out_node : out_nodes) {
AddRePassNode(out_node); AddRePassNode(out_node);
} }
}

///
/// Add a node and it's input/output data nodes to be optimized immediately again.
/// @param node
///
void AddImmediateRePassNodesWithInOut(const NodePtr &node) {
auto in_nodes = node->GetInNodes(); auto in_nodes = node->GetInNodes();
for (auto &in_node : in_nodes) { for (auto &in_node : in_nodes) {
AddRePassNode(in_node);
AddImmediateRePassNode(in_node);
}
AddImmediateRePassNode(node);
auto out_nodes = node->GetOutNodes();
for (auto &out_node : out_nodes) {
AddImmediateRePassNode(out_node);
} }
} }


@@ -123,34 +148,27 @@ class BaseNodePass {
void AddNodeDeleted(const NodePtr &node) { nodes_deleted_.insert(node); } void AddNodeDeleted(const NodePtr &node) { nodes_deleted_.insert(node); }


/// ///
/// If you suspend a node from the graph, especially following node. The remain
/// iterate passes will stop process on the suspend node(if it can be
/// If you postpone a node from the graph, especially following node. The remain
/// iterate passes will stop process on the postpone node(if it can be
/// reached by edge connections) till the last one. Obviously it is a waste of /// reached by edge connections) till the last one. Obviously it is a waste of
/// time. You can add the suspend nodes by calling this function, to stop the
/// time. You can add the postpone nodes by calling this function, to stop the
/// next iterations. /// next iterations.
/// @param node /// @param node
/// ///
void AddNodeSuspend(const NodePtr &node) { nodes_suspend_.insert(node); } void AddNodeSuspend(const NodePtr &node) { nodes_suspend_.insert(node); }


///
/// If you resume a node from the graph, especially following node. The remain
/// iterate passes will continue process on the resume node(if it can be
/// reached by edge connections) till the last one.
/// You can add the resume nodes by calling this function, to resume the
/// next iterations.
/// @param node
///
void AddNodeResume(const NodePtr &node) { nodes_resume_.insert(node); } void AddNodeResume(const NodePtr &node) { nodes_resume_.insert(node); }


bool OptionExists(NodePassOption option) { return options_.count(option) > 0; } bool OptionExists(NodePassOption option) { return options_.count(option) > 0; }


private: private:
std::unordered_set<NodePtr> nodes_need_re_pass_;
std::vector<NodePtr> nodes_need_re_pass_;
std::unordered_set<NodePtr> nodes_need_re_pass_immediately_; std::unordered_set<NodePtr> nodes_need_re_pass_immediately_;
std::unordered_set<NodePtr> nodes_deleted_; std::unordered_set<NodePtr> nodes_deleted_;
std::unordered_set<NodePtr> nodes_suspend_; std::unordered_set<NodePtr> nodes_suspend_;
std::unordered_set<NodePtr> nodes_resume_; std::unordered_set<NodePtr> nodes_resume_;
std::map<NodePassOption, std::string> options_; std::map<NodePassOption, std::string> options_;
std::string current_graph_name_;
}; };


using NamesToPass = std::vector<std::pair<std::string, BaseNodePass *>>; using NamesToPass = std::vector<std::pair<std::string, BaseNodePass *>>;
@@ -160,12 +178,75 @@ class GEPass {
explicit GEPass(ComputeGraphPtr &graph) : graph_(graph), root_graph_(graph), depth_(1) {} explicit GEPass(ComputeGraphPtr &graph) : graph_(graph), root_graph_(graph), depth_(1) {}
virtual ~GEPass() = default; virtual ~GEPass() = default;
Status Run(const NamesToPass &names_to_passes); Status Run(const NamesToPass &names_to_passes);
/*
* todo
* OneGraph: nodes_deleted, nodes_seen, nodes_passed, nodes_suspended
* RePass: nodes_re_pass
* GraphOneTime: nodes_last
* NodeOneTime: nodes_re_pass_immediately, nodes_resume
*/
struct GraphLevelState {
std::unordered_set<NodePtr> nodes_deleted;
std::unordered_set<Node *> nodes_seen;
std::unordered_set<NodePtr> nodes_passed;
std::unordered_set<NodePtr> nodes_suspend;
std::unordered_set<NodePtr> nodes_last;
std::deque<NodePtr> nodes;
int re_pass_times;

void AddNodeToQueueFront(NodePtr node) {
nodes_seen.insert(node.get());
nodes.emplace_front(std::move(node));
}

void AddNodeToQueue(NodePtr node) {
nodes_seen.insert(node.get());
nodes.emplace_back(std::move(node));
}
void AddNodeToQueueIfNotSeen(NodePtr node) {
if (nodes_seen.insert(node.get()).second) {
nodes.emplace_back(std::move(node));
}
}
NodePtr PopFront() {
NodePtr node = nodes.front();
nodes.pop_front();
return node;
}
};
struct RepassLevelState {
std::vector<NodePtr> nodes_re_pass;
std::unordered_set<NodePtr> nodes_re_pass_set;
bool AddNodeToRepass(NodePtr node) {
if (!nodes_re_pass_set.insert(node).second) {
return false;
}
nodes_re_pass.emplace_back(node);
return true;
}
void EraseNodeFromRepass(NodePtr node) {
nodes_re_pass_set.erase(node);
}
void ClearRepass() {
nodes_re_pass_set.clear();
nodes_re_pass.clear();
}
};
struct GraphOneTimeLevelState {
std::unordered_set<NodePtr> nodes_last;
};


private: private:
GEPass(ComputeGraphPtr &graph, ComputeGraphPtr &root_graph, int depth) GEPass(ComputeGraphPtr &graph, ComputeGraphPtr &root_graph, int depth)
: graph_(graph), root_graph_(root_graph), depth_(depth) {} : graph_(graph), root_graph_(root_graph), depth_(depth) {}
Status RunPassesNodeOnce(NodePtr &node, const NamesToPass &names_to_passes,
GraphLevelState &g_state, RepassLevelState &rp_state);
Status RunPassesGraphRepass(const NamesToPass &names_to_passes, GraphLevelState &g_state);
Status RunPassesOneGraph(const NamesToPass &names_to_passes); Status RunPassesOneGraph(const NamesToPass &names_to_passes);
Status RunPassesOnSubGraph(const NodePtr &node, const NamesToPass &names_to_passes, bool &has_sub_graph); Status RunPassesOnSubGraph(const NodePtr &node, const NamesToPass &names_to_passes, bool &has_sub_graph);
Status RunPassesOnNode(NodePtr &node, const NamesToPass &names_to_passes, GraphLevelState &g_state,
RepassLevelState &rp_state);
Status HandleLeakedSuspendNodes(const NamesToPass &names_to_passes, GraphLevelState &g_state);
ComputeGraphPtr graph_; ComputeGraphPtr graph_;
ComputeGraphPtr root_graph_; ComputeGraphPtr root_graph_;
int depth_; int depth_;


+ 3
- 0
ge/graph/passes/infer_base_pass.cc View File

@@ -86,6 +86,9 @@ bool InferBasePass::NeedInfer(const NodePtr &node) const { return true; }
void InferBasePass::AddChangedNodesImmediateRepass(const std::set<NodePtr> &changed_nodes) { void InferBasePass::AddChangedNodesImmediateRepass(const std::set<NodePtr> &changed_nodes) {
// need passed_nodes set to solve the problem that multi-input operators do repass in advance. // need passed_nodes set to solve the problem that multi-input operators do repass in advance.
// when there is passed_nodes set, wo should call AddImmediateRePassNode for all nodes in changed_nodes. // when there is passed_nodes set, wo should call AddImmediateRePassNode for all nodes in changed_nodes.
for (const auto &node_ele : changed_nodes) {
AddImmediateRePassNode(node_ele);
}
} }


graphStatus InferBasePass::InferAndUpdate(NodePtr &node, bool before_subgraph, std::set<NodePtr> &changed_nodes) { graphStatus InferBasePass::InferAndUpdate(NodePtr &node, bool before_subgraph, std::set<NodePtr> &changed_nodes) {


+ 370
- 175
ge/graph/passes/infershape_pass.cc View File

@@ -1,175 +1,370 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "graph/passes/infershape_pass.h"
#include "common/util/error_manager/error_manager.h"
#include "framework/common/debug/ge_log.h"
#include "analyzer/analyzer.h"
#include "framework/common/util.h"
#include "graph/shape_refiner.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/node_utils.h"
#include "common/omg_util.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/type_utils.h"

namespace ge {

// Appends a textual rendering of |desc|'s shape range to |desc_str|:
// "[{min,max},...]" for the shape range, then ",{min,max},..." for the
// origin shape range. Used only to build log / error messages.
void SerialShapeRange(const GeTensorDescPtr &desc, std::string &desc_str) {
  desc_str += "[";
  std::vector<std::pair<int64_t, int64_t>> shape_range;
  (void)desc->GetShapeRange(shape_range);
  for (const auto &pair : shape_range) {
    desc_str += "{";
    desc_str += std::to_string(pair.first) + "," + std::to_string(pair.second);
    desc_str += "},";
  }
  desc_str += "]";
  // Reuse the same buffer for the origin shape range entries.
  shape_range.clear();
  (void)desc->GetOriginShapeRange(shape_range);
  for (const auto &pair : shape_range) {
    desc_str += ",{";
    desc_str += std::to_string(pair.first) + "," + std::to_string(pair.second);
    desc_str += "},";
  }
}

// Builds a human-readable summary of every input tensor desc of |node|
// (shape, format, dtype, their origin counterparts, and the shape range).
// Only used in infershape failure reports. Null input descs are skipped but
// still consume an index, so the printed indices match anchor positions.
// NOTE(review): the string opens with "{" but is never closed with "}" —
// presumably cosmetic only, since this feeds log output.
std::string GetInTensorInfoWithString(const ge::NodePtr &node) {
  ge::OpDescPtr op_desc = node->GetOpDesc();
  std::stringstream ss;
  ss << "{";
  int32_t in_idx = 0;
  for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) {
    if (input_desc == nullptr) {
      in_idx++;
      continue;
    }
    if (in_idx > 0) {
      ss << " ";
    }
    ss << "input_" << in_idx << " " << "tensor: [";
    ss << "(shape:[" << input_desc->MutableShape().ToString() << "]),";
    ss << "(format:" << TypeUtils::FormatToSerialString(input_desc->GetFormat()) << "),";
    ss << "(dtype:" << TypeUtils::DataTypeToSerialString(input_desc->GetDataType()) << "),";
    ss << "(origin_shape:" << input_desc->GetOriginShape().ToString() << "),";
    ss << "(origin_format:" << TypeUtils::FormatToSerialString(input_desc->GetOriginFormat()) << "),";
    ss << "(origin_dtype:" << TypeUtils::DataTypeToSerialString(input_desc->GetOriginDataType()) << "),";
    string range_str;
    SerialShapeRange(input_desc, range_str);
    ss << "(shape_range:" << range_str << ")]";
    in_idx++;
  }
  return ss.str();
}

// Runs shape/type inference on one node.
// On failure: dumps analyzer diagnostics for the root graph, then reports
// GE_GRAPH_INFERSHAPE_FAILED with the node's input tensor info.
// On success: re-pass bookkeeping for v1 control-flow loops is applied
// (RePassLoopNode), and the while-loop "need infer again" attribute is
// honored after subgraph optimization.
Status InferShapePass::Run(NodePtr &node) {
  // kOptimizeAfterSubGraph exist means after subgraph
  auto ret = ShapeRefiner::InferShapeAndType(node, !OptionExists(kOptimizeAfterSubGraph));
  if (ret != GRAPH_SUCCESS) {
    // select INFERSHAPE failed info
    auto graph = node->GetOwnerComputeGraph();
    GE_CHECK_NOTNULL(graph);
    auto root_graph = ge::GraphUtils::FindRootGraph(graph);
    GE_CHECK_NOTNULL(root_graph);
    analyzer::DataInfo analyze_info{root_graph->GetSessionID(), root_graph->GetGraphID(),
                                    analyzer::INFER_SHAPE, node, "InferShapeFailed!"};
    (void)Analyzer::GetInstance()->DoAnalyze(analyze_info);
    (void)Analyzer::GetInstance()->SaveAnalyzerDataToFile(root_graph->GetSessionID(),
                                                          root_graph->GetGraphID());

    REPORT_CALL_ERROR("E19999", "Call InferShapeAndType for node:%s(%s) failed, input_tensor:%s",
                      node->GetName().c_str(), node->GetType().c_str(), GetInTensorInfoWithString(node).c_str());
    GELOGE(GE_GRAPH_INFERSHAPE_FAILED, "[Call][InferShapeAndType] for node:%s(%s) failed, input_tensor:%s",
           node->GetName().c_str(), node->GetType().c_str(), GetInTensorInfoWithString(node).c_str());
    return GE_GRAPH_INFERSHAPE_FAILED;
  }

  GE_CHK_STATUS_RET_NOLOG(RePassLoopNode(node));
  bool need_repass = false;
  auto has_attr = AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_NEED_INFER_AGAIN, need_repass);
  if (has_attr) {
    // Before subgraph optimization the attribute is left untouched; it only
    // drives immediate re-pass after subgraphs have been processed.
    if (!OptionExists(kOptimizeAfterSubGraph)) {
      return SUCCESS;
    }
    if (need_repass) {
      AddImmediateRePassNode(node);
      GELOGD("Node %s need repass immediately.", node->GetName().c_str());
    } else {
      // clear attr on while
      node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN);
    }
  }
  return SUCCESS;
}

// Applies v1 control-flow loop scheduling after a node was inferred:
//   NextIteration -> immediately re-pass its Merge successors;
//   Merge         -> if flagged "need infer again", re-pass its Switches;
//   Switch        -> resume or suspend its Exit successors depending on the
//                    "need infer again" flag.
// This keeps the iterative infershape converging over loop bodies instead of
// processing Exit nodes before the loop has stabilized.
Status InferShapePass::RePassLoopNode(const NodePtr &node) {
  // Re-pass every direct data successor whose (original) type is listed in
  // |re_pass_types|, clearing its repass attribute first.
  const auto RePassNode = [&](const std::set<std::string> &re_pass_types) {
    for (auto &n : node->GetOutDataNodes()) {
      GE_CHECK_NOTNULL(n);
      std::string node_type;
      GE_CHK_STATUS_RET(GetOriginalType(n, node_type), "[Get][OriginalType] of node:%s failed.", n->GetName().c_str());
      if (re_pass_types.count(node_type) > 0) {
        AddImmediateRePassNode(n);
        (void)AttrUtils::SetBool(n->GetOpDesc(), ATTR_NAME_NEED_INFER_AGAIN, false);
        GELOGD("Node %s need repass immediately after %s.", n->GetName().c_str(), node->GetName().c_str());
      }
    }
    return SUCCESS;
  };

  // Invoke |proc_func| (suspend/resume) on every data successor whose type is
  // in |proc_types|; |info| is only used for the debug log line.
  const auto ExProcNode = [&](const std::set<std::string> &proc_types,
                              const std::function<void(InferShapePass *, NodePtr)> &proc_func,
                              const std::string &info) {
    for (auto &n : node->GetOutDataNodes()) {
      GE_CHECK_NOTNULL(n);
      std::string node_type;
      GE_CHK_STATUS_RET(GetOriginalType(n, node_type), "[Get][OriginalType] of node:%s failed.", n->GetName().c_str());
      if (proc_types.count(node_type) > 0) {
        proc_func(this, n);
        GELOGD("Node %s %s after %s.", n->GetName().c_str(), info.c_str(), node->GetName().c_str());
      }
    }
    return SUCCESS;
  };

  std::string node_type;
  GE_CHK_STATUS_RET(GetOriginalType(node, node_type),
                    "[Get][OriginalType] of node:%s failed.", node->GetName().c_str());
  if (kNextIterationOpTypes.count(node_type) > 0) {
    return RePassNode(kMergeOpTypes);  // Re-Pass Merge
  }

  if (kMergeOpTypes.count(node_type) > 0) {
    if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) {
      node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN);
      return RePassNode(kSwitchOpTypes);  // Re-Pass Switch
    }
    return SUCCESS;
  }

  if (kSwitchOpTypes.count(node_type) > 0) {
    if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) {
      node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN);
      return ExProcNode(kExitOpTypes, &InferShapePass::AddNodeResume, "need resume");  // Resume Exit
    } else {
      return ExProcNode(kExitOpTypes, &InferShapePass::AddNodeSuspend, "need suspend");  // Suspend Exit
    }
  }

  return SUCCESS;
}
} // namespace ge
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*+
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "graph/passes/infershape_pass.h"
#include "common/util/error_manager/error_manager.h"
#include "framework/common/debug/ge_log.h"
#include "analyzer/analyzer.h"
#include "framework/common/util.h"
#include "graph/shape_refiner.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/node_utils.h"
#include "common/omg_util.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/type_utils.h"
#include "external/graph/operator_factory.h"
namespace ge {
namespace {
constexpr int kSwitchExitAnchorIndex = 0;
constexpr int kSwitchPredAnchorIndex = 1;
// Serializes |desc|'s shape range into |desc_str| as "[{min,max},...]",
// then appends the origin shape range entries as ",{min,max},...".
// Log/error formatting only.
void SerialShapeRange(const GeTensorDescPtr &desc, std::string &desc_str) {
  std::vector<std::pair<int64_t, int64_t>> range;
  (void)desc->GetShapeRange(range);
  desc_str += "[";
  for (const auto &bound : range) {
    desc_str += "{" + std::to_string(bound.first) + "," + std::to_string(bound.second) + "},";
  }
  desc_str += "]";
  // Second pass: origin shape range, appended with a leading comma per entry.
  range.clear();
  (void)desc->GetOriginShapeRange(range);
  for (const auto &bound : range) {
    desc_str += ",{" + std::to_string(bound.first) + "," + std::to_string(bound.second) + "},";
  }
}
// Copies the shape-related fields infershape propagates from |src| to |dst|:
// shape, origin shape, dtype, origin dtype, and shape range (also written as
// dst's origin shape range). Real dim count on dst is refreshed from src's
// shape rank. Format fields are intentionally not touched here.
void UpdateShapeAndDType(const GeTensorDescPtr &src, GeTensorDescPtr &dst) {
  dst->SetOriginShape(src->GetOriginShape());
  dst->SetShape(src->GetShape());
  dst->SetDataType(src->GetDataType());
  dst->SetOriginDataType(src->GetOriginDataType());
  vector<pair<int64_t, int64_t>> src_shape_range;
  src->GetShapeRange(src_shape_range);
  dst->SetShapeRange(src_shape_range);
  dst->SetOriginShapeRange(src_shape_range);
  ge::TensorUtils::SetRealDimCnt(*dst, static_cast<uint32_t>(src->GetShape().GetDims().size()));
}
} // namespace
// One-line summary of a tensor desc for logging: shape/format/dtype, their
// origin counterparts, and the serialized shape range.
std::string InferShapePass::SerialTensorInfo(const GeTensorDescPtr &tensor_desc) const {
  std::string info;
  info += "(shape:[" + tensor_desc->MutableShape().ToString() + "]),";
  info += "(format:" + TypeUtils::FormatToSerialString(tensor_desc->GetFormat()) + "),";
  info += "(dtype:" + TypeUtils::DataTypeToSerialString(tensor_desc->GetDataType()) + "),";
  info += "(origin_shape:" + tensor_desc->GetOriginShape().ToString() + "),";
  info += "(origin_format:" + TypeUtils::FormatToSerialString(tensor_desc->GetOriginFormat()) + "),";
  info += "(origin_dtype:" + TypeUtils::DataTypeToSerialString(tensor_desc->GetOriginDataType()) + "),";
  std::string range_str;
  SerialShapeRange(tensor_desc, range_str);
  info += "(shape_range:" + range_str + ")";
  return info;
}
// Detects a v1 control-flow loop: a Switch node whose predicate input
// (anchor 1) is produced by LoopCond. The Exit successors on the false
// branch (anchor 0) are suspended so they are not inferred before the loop
// converges. Each suspended node is also recorded per graph in
// graphs_2_suspend_nodes_ so OnSuspendNodesLeaked can resume one later.
Status InferShapePass::SuspendV1LoopExitNodes(const NodePtr &node) {
  if (node->GetType() != SWITCH) {
    return SUCCESS;
  }
  // Anchor kSwitchPredAnchorIndex (1) carries the boolean predicate.
  auto pred_node = NodeUtils::GetInDataNodeByIndex(*node, kSwitchPredAnchorIndex);
  GE_CHECK_NOTNULL(pred_node);
  if (pred_node->GetType() != LOOPCOND) {
    return SUCCESS;
  }
  for (const auto &anchor_2_node : NodeUtils::GetOutDataNodesWithAnchorByIndex(*node, kSwitchExitAnchorIndex)) {
    GELOGI("Found v1 loop when infershape, suspend Exit node %s, type %s.", anchor_2_node.second->GetName().c_str(),
           anchor_2_node.second->GetType().c_str());
    auto &suspend_nodes = graphs_2_suspend_nodes_[GetCurrentGraphName()];
    // nodes_set deduplicates; only newly seen nodes are pushed and suspended.
    if (suspend_nodes.nodes_set.insert(anchor_2_node.second).second) {
      suspend_nodes.nodes.push(anchor_2_node.second);
      AddNodeSuspend(anchor_2_node.second);
    }
  }
  return SUCCESS;
}
// InferBasePass hook: runs InferShapeAndType on |node|. On failure the
// analyzer dumps diagnostic data for the root graph before reporting
// GE_GRAPH_INFERSHAPE_FAILED.
Status InferShapePass::Infer(NodePtr &node) {
  auto ret = InferShapeAndType(node);
  if (ret != GRAPH_SUCCESS) {
    auto graph = node->GetOwnerComputeGraph();
    GE_CHECK_NOTNULL(graph);
    auto root_graph = ge::GraphUtils::FindRootGraph(graph);
    GE_CHECK_NOTNULL(root_graph);
    // Persist analyzer info so the failure can be inspected offline.
    analyzer::DataInfo analyze_info{root_graph->GetSessionID(), root_graph->GetGraphID(),
                                    analyzer::INFER_SHAPE, node, "InferShapeFailed!"};
    (void)Analyzer::GetInstance()->DoAnalyze(analyze_info);
    (void)Analyzer::GetInstance()->SaveAnalyzerDataToFile(root_graph->GetSessionID(),
                                                          root_graph->GetGraphID());
    REPORT_CALL_ERROR("E19999", "Call InferShapeAndType for node:%s(%s) failed", node->GetName().c_str(),
                      node->GetType().c_str());
    GELOGE(GE_GRAPH_INFERSHAPE_FAILED, "[Call][InferShapeAndType] for node:%s(%s) failed", node->GetName().c_str(),
           node->GetType().c_str());
    return GE_GRAPH_INFERSHAPE_FAILED;
  }
  return SUCCESS;
}
// Full per-node infershape pipeline:
//  1. suspend v1-loop Exit successors if |node| is a loop Switch;
//  2. verify the node;
//  3. for known-shape graphs, attach an inference context;
//  4. call the infer func (with operator-factory fallback);
//  5. refresh the node's own output descs and push the post-infer context.
// Returns GRAPH_NODE_NEED_REPASS when the infer func asks for another pass.
graphStatus InferShapePass::InferShapeAndType(NodePtr &node) {
  auto ret = SuspendV1LoopExitNodes(node);
  if (ret != SUCCESS) {
    GELOGE(ret, "Suspend V1 loop exit nodes failed.");
    return ret;
  }
  bool is_unknown_graph = node->GetOwnerComputeGraph()->GetGraphUnknownFlag();
  auto opdesc = node->GetOpDesc();
  if (node->Verify() != GRAPH_SUCCESS) {
    REPORT_CALL_ERROR("E19999", "Verifying %s failed.", node->GetName().c_str());
    GELOGE(GRAPH_FAILED, "[Call][Verify] Verifying %s failed.", node->GetName().c_str());
    return GRAPH_FAILED;
  }
  Operator op = OpDescUtils::CreateOperatorFromNode(node);
  if (!is_unknown_graph) {
    // Inference context carries marks/handle shapes between producer and
    // consumer nodes; only maintained for known-shape graphs.
    auto inference_context = ShapeRefiner::CreateInferenceContext(node);
    GE_CHECK_NOTNULL(inference_context);
    GELOGD("create context for node:%s, marks %zu", node->GetName().c_str(), inference_context->GetMarks().size());
    op.SetInferenceContext(inference_context);
  }
  graphStatus status = CallInferShapeFunc(node, op);
  if (status != GRAPH_NODE_NEED_REPASS && status != GRAPH_PARAM_INVALID && status != GRAPH_SUCCESS) {
    // node like netoutput return param_invalid, but valid ?
    return GE_GRAPH_INFERSHAPE_FAILED;
  }
  UpdateCurNodeOutputDesc(node);
  if (!is_unknown_graph) {
    auto ctx_after_infer = op.GetInferenceContext();
    if (ctx_after_infer != nullptr) {
      GELOGD("[%s] after infershape. mark:%zu", node->GetName().c_str(), ctx_after_infer->GetMarks().size());
      if (!ctx_after_infer->GetOutputHandleShapesAndTypes().empty() || !ctx_after_infer->GetMarks().empty()) {
        GELOGD("[%s] set inference context after. mark:%zu", node->GetName().c_str(),
               ctx_after_infer->GetMarks().size());
        // Stash the context so downstream nodes can pick it up.
        ShapeRefiner::PushToContextMap(node, ctx_after_infer);
      }
    }
  }
  // GRAPH_PARAM_INVALID is normalized to success; only NEED_REPASS survives.
  return (status == GRAPH_NODE_NEED_REPASS) ? GRAPH_NODE_NEED_REPASS : GRAPH_SUCCESS;
}
// Normalizes each output desc of |node| after inference: origin shape is
// seeded from the inferred shape when the dims were previously empty, origin
// dtype mirrors the inferred dtype, the shape range is copied to the origin
// shape range, and real dim count is refreshed from the origin shape rank.
void InferShapePass::UpdateCurNodeOutputDesc(NodePtr &node) {
  auto op_desc = node->GetOpDesc();
  for (const auto &out_anchor : node->GetAllOutDataAnchors()) {
    auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx());
    GE_IF_BOOL_EXEC(output_tensor == nullptr, continue);
    // Only seed origin shape when it has never been set (empty dims).
    GE_IF_BOOL_EXEC(output_tensor->MutableShape().GetDims().empty(),
                    output_tensor->SetOriginShape(output_tensor->GetShape()));
    ge::TensorUtils::SetRealDimCnt(*output_tensor, static_cast<uint32_t>(output_tensor->GetOriginShape().GetDims()
        .size()));
    output_tensor->SetOriginDataType(output_tensor->GetDataType());
    // set output origin shape range
    std::vector<std::pair<int64_t, int64_t>> range;
    (void)output_tensor->GetShapeRange(range);
    output_tensor->SetOriginShapeRange(range);
    GELOGD("node name is %s, origin shape is %ld, origin format is %s, origin data type is %s",
           node->GetName().c_str(), output_tensor->GetOriginShape().GetShapeSize(),
           TypeUtils::FormatToSerialString(output_tensor->GetOriginFormat()).c_str(),
           TypeUtils::DataTypeToSerialString(output_tensor->GetOriginDataType()).c_str());
  }
}
// Returns true when the fields infershape propagates are identical between
// |src| and |dst|: shape range, shape, origin shape, data type and origin
// data type. Format is deliberately not compared here.
bool InferShapePass::SameTensorDesc(const GeTensorDescPtr &src, const GeTensorDescPtr &dst) {
  // check shape range
  vector<std::pair<int64_t, int64_t>> src_shape_range;
  vector<std::pair<int64_t, int64_t>> dst_shape_range;
  src->GetShapeRange(src_shape_range);
  dst->GetShapeRange(dst_shape_range);
  if (src_shape_range.size() != dst_shape_range.size()) {
    GELOGI("Src shape range size is %zu, dst shape range size is %zu, not same.", src_shape_range.size(),
           dst_shape_range.size());
    return false;
  }
  for (size_t i = 0; i < src_shape_range.size(); ++i) {
    if (src_shape_range[i].first != dst_shape_range[i].first ||
        src_shape_range[i].second != dst_shape_range[i].second) {
      // %ld matches the signed int64_t bounds; %lu would mis-render negative
      // values such as -1 (unknown dim).
      GELOGI("Current dim %zu. Src shape range is [%ld-%ld], dst shape range is [%ld-%ld], not same.",
             i, src_shape_range[i].first, src_shape_range[i].second, dst_shape_range[i].first,
             dst_shape_range[i].second);
      return false;
    }
  }
  // check shape
  const auto &src_shape = src->GetShape();
  const auto &dst_shape = dst->GetShape();
  if (src_shape.GetDims() != dst_shape.GetDims() || src->GetOriginShape().GetDims() != dst->GetOriginShape().GetDims() ||
      src->GetDataType() != dst->GetDataType() || src->GetOriginDataType() != dst->GetOriginDataType()) {
    GELOGD(
        "Src shape is %s, origin_shape is %s, data_type is %s, origin data_type is %s; "
        "Dst shape is %s, origin_shape is %s, data_type is %s, original data_type is %s, not same.",
        src_shape.ToString().c_str(), src->GetOriginShape().ToString().c_str(),
        TypeUtils::DataTypeToSerialString(src->GetDataType()).c_str(),
        TypeUtils::DataTypeToSerialString(src->GetOriginDataType()).c_str(), dst_shape.ToString().c_str(),
        dst->GetOriginShape().ToString().c_str(), TypeUtils::DataTypeToSerialString(dst->GetDataType()).c_str(),
        TypeUtils::DataTypeToSerialString(dst->GetOriginDataType()).c_str());
    return false;
  }
  return true;
}
// Propagates |src| (an output desc) to |dst| (the peer input desc).
// |changed| reports whether dst differed from src before the update, which
// drives downstream re-pass decisions. Note: src's own origin fields are
// refreshed from its current shape/dtype even when nothing changed.
graphStatus InferShapePass::UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) {
  changed = !SameTensorDesc(src, dst);
  // refresh src itself
  src->SetOriginShape(src->GetShape());
  src->SetOriginDataType(src->GetDataType());
  TensorUtils::SetRealDimCnt(*src, static_cast<uint32_t>(src->GetOriginShape().GetDims().size()));
  vector<pair<int64_t, int64_t>> src_shape_range;
  src->GetShapeRange(src_shape_range);
  src->SetOriginShapeRange(src_shape_range);
  if (!changed) {
    GELOGD("Peer dst tensor_desc is same as src tensor_desc. No need update.");
    return SUCCESS;
  }
  UpdateShapeAndDType(src, dst);
  GELOGD(
      "UpdatePeerInputDesc from src Node: shape: [%s], datatype: %s, original datatype is %s."
      "To dst Node: shape: [%s], datatype: %s, original datatype is %s.",
      src->GetShape().ToString().c_str(), TypeUtils::DataTypeToSerialString(src->GetDataType()).c_str(),
      TypeUtils::DataTypeToSerialString(src->GetOriginDataType()).c_str(), dst->GetShape().ToString().c_str(),
      TypeUtils::DataTypeToSerialString(dst->GetDataType()).c_str(),
      TypeUtils::DataTypeToSerialString(dst->GetOriginDataType()).c_str());
  return SUCCESS;
}
// Invokes the op's registered infer function. When the op desc has no infer
// func (CallInferFunc returns GRAPH_PARAM_INVALID), a throw-away operator is
// created via the operator factory to borrow its infer func and its
// input/output name maps, then inference is retried once.
graphStatus InferShapePass::CallInferShapeFunc(NodePtr &node, Operator &op) {
  auto op_desc = node->GetOpDesc();
  const auto &op_type = op_desc->GetType();
  auto ret = op_desc->CallInferFunc(op);
  if (ret == GRAPH_PARAM_INVALID) {
    // Op ir no infer func, try to get infer func from operator factory
    auto node_op = ge::OperatorFactory::CreateOperator("node_op", op_desc->GetType());
    if (node_op.IsEmpty()) {
      GELOGW("get op from OperatorFactory fail. opType: %s", op_type.c_str());
      return ret;
    }
    GELOGD("get op from OperatorFactory success. opType: %s", op_type.c_str());
    auto temp_op_desc = ge::OpDescUtils::GetOpDescFromOperator(node_op);
    node_op.BreakConnect();
    if (temp_op_desc == nullptr) {
      REPORT_CALL_ERROR("E19999", "GetOpDescFromOperator failed, return nullptr.");
      GELOGE(GRAPH_FAILED, "[Get][OpDesc] temp op desc is null");
      return GRAPH_FAILED;
    }
    if (!op_desc->UpdateInputName(temp_op_desc->GetAllInputName())) {
      GELOGW("InferShapeAndType UpdateInputName failed");
      // If every existing output already carries a shape there is nothing
      // left to infer, so keep the current descs and report success.
      // BUGFIX: the previous code returned GRAPH_SUCCESS from inside the
      // loop after inspecting only the FIRST output desc, which could skip
      // inference for a later output that still had an empty shape.
      const auto &output_descs = op_desc->GetAllOutputsDescPtr();
      bool all_outputs_shaped = !output_descs.empty();
      for (const auto &out_desc : output_descs) {
        if (out_desc != nullptr && out_desc->GetShape().GetDims().empty()) {
          all_outputs_shaped = false;
          break;
        }
      }
      if (all_outputs_shaped) {
        return GRAPH_SUCCESS;
      }
    }
    if (!op_desc->UpdateOutputName(temp_op_desc->GetAllOutputName())) {
      GELOGW("InferShapeAndType UpdateOutputName failed");
    }
    op_desc->AddInferFunc(temp_op_desc->GetInferFunc());
    ret = op_desc->CallInferFunc(op);
    GELOGI("op CallInferFunc second. ret: %u", ret);
  }
  return ret;
}
// Merges the output descs reported by the branches of a control-flow op
// (If/Case style) into the parent node's output desc |dst|.
// Rule: differing rank across branches -> unknown rank; same rank but
// differing dims -> that dim becomes unknown (-1). All branches must agree
// on dtype. The merged shape is written into src[0] and then copied to dst.
graphStatus InferShapePass::UpdateOutputFromSubgraphs(const std::vector<GeTensorDescPtr> &src, GeTensorDescPtr &dst) {
  GELOGD("Enter update parent node shape for class branch op process");
  if (src.empty()) {
    // Guard against out-of-range at(0); consistent with the MultiDims variant.
    GELOGI("Src subgraph shape is empty.");
    return GRAPH_SUCCESS;
  }
  // check sub_graph shape. If not same, do unknown shape process
  auto ref_out_tensor = src.at(0);
  ge::GeShape &ref_out_tensor_shape = ref_out_tensor->MutableShape();
  for (auto &tensor : src) {
    if (ref_out_tensor->GetDataType() != tensor->GetDataType()) {
      REPORT_INNER_ERROR("E19999", "Does not support diff dtype among all ref output, shape:%s",
                         ref_out_tensor_shape.ToString().c_str());
      GELOGE(GRAPH_FAILED, "[Check][Param] node does not support diff dtype output");
      return GRAPH_FAILED;
    }
    auto shape = tensor->MutableShape();
    if (shape.GetDims().size() != ref_out_tensor_shape.GetDims().size()) {
      // Rank differs among branches -> parent output becomes unknown rank.
      // %ld matches the int64_t returned by GetShapeSize().
      GELOGD("Shape from subgraph size: %ld, ref_out_tensor_shape size: %ld", shape.GetShapeSize(),
             ref_out_tensor_shape.GetShapeSize());
      ref_out_tensor_shape = GeShape(UNKNOWN_RANK);
      break;
    }
    for (size_t j = 0; j < ref_out_tensor_shape.GetDims().size(); j++) {
      if (ref_out_tensor_shape.GetDim(j) == shape.GetDim(j)) {
        continue;
      }
      GELOGD("j: %zu ,shape from subgraph size: %ld, ref_out_tensor_shape size: %ld", j, shape.GetShapeSize(),
             ref_out_tensor_shape.GetShapeSize());
      // Same rank, different dim -> mark only that dim unknown.
      (void)ref_out_tensor_shape.SetDim(j, UNKNOWN_DIM);
    }
  }
  UpdateShapeAndDType(ref_out_tensor, dst);
  return GRAPH_SUCCESS;
}
// Multi-dims variant of branch-output merging: instead of degrading to
// unknown shape, picks the branch output with the LARGEST element count and
// copies it to |dst|. All branches must agree on dtype.
// NOTE(review): the overflow guard assumes non-negative dims — presumably
// multi-dims shapes are fully known here; a -1 (unknown) dim would make the
// product negative. Verify against callers.
graphStatus InferShapePass::UpdateOutputFromSubgraphsForMultiDims(const std::vector<GeTensorDescPtr> &src,
                                                                  GeTensorDescPtr &dst) {
  // check sub_graph shape. Get max for update.
  if (src.empty()) {
    GELOGI("Src subgraph shape is empty.");
    return SUCCESS;
  }
  int64_t max_size = 0;
  size_t max_shape_index = 0;
  auto &ref_out_tensor = src.at(0);
  for (size_t j = 0; j < src.size(); ++j) {
    auto &tensor = src.at(j);
    if (ref_out_tensor->GetDataType() != tensor->GetDataType()) {
      REPORT_INNER_ERROR("E19999", "node does not support diff dtype among all ref output");
      GELOGE(GRAPH_FAILED, "[Check][Param] node does not support diff dtype among all ref output");
      return GRAPH_FAILED;
    }
    auto shape = tensor->MutableShape();
    int64_t size = 1;
    for (auto dim : shape.GetDims()) {
      // Guard the running product against int64 overflow before multiplying.
      if (dim != 0 && INT64_MAX / dim < size) {
        REPORT_INNER_ERROR("E19999", "The shape:%s size overflow", shape.ToString().c_str());
        GELOGE(PARAM_INVALID, "[Check][Overflow] The shape size overflow");
        return PARAM_INVALID;
      }
      size *= dim;
    }
    if (size > max_size) {
      max_size = size;
      max_shape_index = j;
    }
  }
  UpdateShapeAndDType(src.at(max_shape_index), dst);
  return GRAPH_SUCCESS;
}
// Base-pass hook fired when suspended nodes were never resumed: wake up one
// suspended Exit node of the current graph (LIFO order) so iteration can
// make progress.
Status InferShapePass::OnSuspendNodesLeaked() {
  const auto it = graphs_2_suspend_nodes_.find(GetCurrentGraphName());
  if (it == graphs_2_suspend_nodes_.end()) {
    GELOGI("Current graph %s no suspend node.", GetCurrentGraphName().c_str());
    return SUCCESS;
  }
  if (it->second.nodes.empty()) {
    return SUCCESS;
  }
  AddNodeResume(it->second.PopSuspendedNode());
  return SUCCESS;
}
} // namespace ge

+ 56
- 38
ge/graph/passes/infershape_pass.h View File

@@ -1,38 +1,56 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_GRAPH_PASSES_INFERSHAPE_PASS_H_
#define GE_GRAPH_PASSES_INFERSHAPE_PASS_H_

#include "graph/passes/base_pass.h"

namespace ge {
// Node-level shape inference pass (BaseNodePass-derived variant).
class InferShapePass : public BaseNodePass {
 public:
  ///
  /// Entry of the InferShapePass optimizer
  /// @param [in] node: the node to run shape/dtype inference on
  /// @return SUCCESS: Execution succeed
  /// @return OTHERS: Execution failed
  /// @author
  ///
  Status Run(ge::NodePtr &node) override;

 private:
  // Re-queues loop-related nodes for another inference pass.
  // NOTE(review): exact node types handled (e.g. v1 control flow) are defined
  // in the implementation file - confirm there.
  Status RePassLoopNode(const NodePtr &node);
};
} // namespace ge
#endif // GE_GRAPH_PASSES_INFERSHAPE_PASS_H_
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef GE_GRAPH_PASSES_INFERSHAPE_PASS_H_
#define GE_GRAPH_PASSES_INFERSHAPE_PASS_H_
#include "graph/passes/infer_base_pass.h"
#include <stack>
namespace ge {
// Shape/dtype inference pass built on InferBasePass. Supports inferring a node
// multiple times, which v1 control-flow loops need: Exit nodes can be
// suspended while the loop body is still being inferred and resumed later.
class InferShapePass : public InferBasePass {
 public:
  // Serializes tensor_desc into a human-readable string for logging.
  std::string SerialTensorInfo(const GeTensorDescPtr &tensor_desc) const override;
  // Runs shape inference on node via its registered infer function.
  graphStatus Infer(NodePtr &node) override;
  // Copies src shape/dtype into dst; sets `changed` when dst was modified.
  graphStatus UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) override;
  // Merges subgraph output tensors `src` into the parent output `dst`.
  // Fails on dtype mismatch; differing ranks merge to unknown rank and
  // differing dims merge to dynamic (-1) dims.
  graphStatus UpdateOutputFromSubgraphs(const std::vector<GeTensorDescPtr> &src, GeTensorDescPtr &dst) override;
  // Multi-dims variant: selects the src tensor with the largest element count.
  // Fails on dtype mismatch (GRAPH_FAILED) or shape-size overflow (PARAM_INVALID).
  graphStatus UpdateOutputFromSubgraphsForMultiDims(const std::vector<GeTensorDescPtr> &src,
                                                    GeTensorDescPtr &dst) override;
  // Resumes one suspended node of the current graph (if any) when suspended
  // nodes are left over after a pass iteration.
  Status OnSuspendNodesLeaked() override;

 private:
  graphStatus InferShapeAndType(NodePtr &node);
  graphStatus CallInferShapeFunc(NodePtr &node, Operator &op);
  bool SameTensorDesc(const GeTensorDescPtr &src, const GeTensorDescPtr &dst);
  void UpdateCurNodeOutputDesc(NodePtr &node);
  // Suspends Exit nodes of a v1 while loop reachable from `node`
  // (presumably to delay them until the loop body converges) - TODO confirm.
  Status SuspendV1LoopExitNodes(const NodePtr &node);

  // LIFO stack of suspended nodes plus a membership set for O(1) lookups.
  struct SuspendNodes {
    std::stack<NodePtr> nodes;
    std::unordered_set<NodePtr> nodes_set;

    // Pops the most recently suspended node.
    // Precondition: nodes is not empty (callers must check nodes.empty()).
    NodePtr PopSuspendedNode() {
      auto top_node = nodes.top();
      nodes.pop();
      nodes_set.erase(top_node);
      return top_node;
    }
  };
  // Suspended nodes per graph, keyed by graph name.
  std::map<std::string, SuspendNodes> graphs_2_suspend_nodes_;
};
} // namespace ge
#endif // GE_GRAPH_PASSES_INFERSHAPE_PASS_H_

+ 16
- 0
ge/graph/preprocess/graph_preprocess.cc View File

@@ -1999,6 +1999,22 @@ Status GraphPrepare::CheckUserInput(const std::vector<GeTensor> &user_input) {


Status GraphPrepare::InferShapeForPreprocess() { Status GraphPrepare::InferShapeForPreprocess() {
GELOGI("Start infershape for preprocess."); GELOGI("Start infershape for preprocess.");
// Prepare dummy_shape for v1 control_flow op before infershape
for (const auto &node : compute_graph_->GetAllNodes()) {
string type;
GetOriginalType(node, type);
if (type == MERGE || type == REFMERGE) {
for (size_t i = 0; i < node->GetAllInDataAnchorsSize(); ++i) {
GELOGD("Prepare for infershape: update %s input_shape as dummy.", node->GetName().c_str());
NodeUtils::UpdateInputShape(*node, i, GeShape(DUMMY_SHAPE));
}
} else if (type == WHILE) {
for (size_t i = 0; i < node->GetAllInDataAnchorsSize(); ++i) {
GELOGD("Prepare for infershape: update %s output_shape as dummy.", node->GetName().c_str());
NodeUtils::UpdateOutputShape(*node, i, GeShape(DUMMY_SHAPE));
}
}
}
GEPass ge_passes(compute_graph_); GEPass ge_passes(compute_graph_);
NamesToPass names_to_passes; NamesToPass names_to_passes;
AssertPass assert_pass; AssertPass assert_pass;


+ 1
- 1
tests/ut/ge/graph/passes/addn_pass_unittest.cc View File

@@ -72,7 +72,7 @@ TEST(UtestGraphPassesAddnPass, null_pass) {
AddNPass *addn_pass = nullptr; AddNPass *addn_pass = nullptr;
NamesToPass names_to_pass; NamesToPass names_to_pass;
names_to_pass.emplace_back("Test", addn_pass); names_to_pass.emplace_back("Test", addn_pass);
EXPECT_EQ(pass.Run(names_to_pass), SUCCESS);
EXPECT_EQ(pass.Run(names_to_pass), INTERNAL_ERROR);
} }


TEST(UtestGraphPassesAddnPass, null_graph) { TEST(UtestGraphPassesAddnPass, null_graph) {


+ 903
- 523
tests/ut/ge/graph/passes/base_pass_unittest.cc
File diff suppressed because it is too large
View File


+ 262
- 161
tests/ut/ge/graph/passes/infershape_pass_unittest.cc View File

@@ -1,161 +1,262 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <gtest/gtest.h>

#define protected public
#define private public
#include "graph/passes/infershape_pass.h"

#include "graph/utils/tensor_utils.h"
#include "graph/utils/graph_utils.h"
#include "graph/operator_factory.h"
#include "graph/operator_reg.h"
#include "graph_builder_utils.h"

using namespace std;
using namespace testing;
namespace ge {
// Test fixture for InferShapePass unit tests. No shared state is required,
// so SetUp/TearDown are intentionally empty.
class UtestGraphInfershapePass : public testing::Test {
 protected:
  void SetUp() {}
  void TearDown() {}
};

// Adds a node named `name` of `type` to `graph` with `in_num` inputs and
// `out_num` outputs. Every tensor is a scalar FLOAT/NCHW desc of size 512,
// every offset is 1024, and the infer/format/verify stubs always succeed.
static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num) {
  OpDescPtr op_desc = std::make_shared<OpDesc>(name, type);
  op_desc->SetStreamId(0);
  static int32_t index = 0;  // monotonically increasing id across all nodes
  op_desc->SetId(index++);

  GeTensorDesc tensor(GeShape(), FORMAT_NCHW, DT_FLOAT);
  TensorUtils::SetSize(tensor, 512);

  vector<int64_t> in_offsets;
  for (int i = 0; i < in_num; ++i) {
    op_desc->AddInputDesc(tensor);
    in_offsets.push_back(1024);
  }
  op_desc->SetInputOffset(in_offsets);

  vector<int64_t> out_offsets;
  for (int i = 0; i < out_num; ++i) {
    op_desc->AddOutputDesc(tensor);
    out_offsets.push_back(1024);
  }
  op_desc->SetOutputOffset(out_offsets);

  op_desc->SetWorkspace({});
  op_desc->SetWorkspaceBytes({});
  op_desc->SetOpKernelLibName("DNN_VM_RTS_OP_STORE");

  // Stub functions so inference on these nodes trivially succeeds.
  const auto stub_func = [](Operator &op) { return GRAPH_SUCCESS; };
  op_desc->AddInferFunc(stub_func);
  op_desc->AddInferFormatFunc(stub_func);
  op_desc->AddVerifierFunc(stub_func);

  return graph.AddNode(op_desc);
}

TEST_F(UtestGraphInfershapePass, infershape_pass_failed) {
  // An AddN whose tensors carry an unknown-rank shape ({-2, ...}) makes the
  // pass report GE_GRAPH_INFERSHAPE_FAILED.
  GeTensorDesc tensor_desc(GeShape({-2, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT16);
  auto op_desc = std::make_shared<OpDesc>("AddN", "AddN");
  op_desc->AddInputDesc(tensor_desc);
  op_desc->AddOutputDesc(tensor_desc);
  auto graph = std::make_shared<ComputeGraph>("test");
  auto node = std::make_shared<Node>(op_desc, graph);
  node->Init();

  InferShapePass pass;
  EXPECT_EQ(pass.Run(node), GE_GRAPH_INFERSHAPE_FAILED);
}

TEST_F(UtestGraphInfershapePass, delete_need_infer_again) {
  // With _need_infer_again cleared and post-subgraph optimization enabled,
  // the pass completes successfully on a NoOp node.
  auto graph = std::make_shared<ComputeGraph>("test");
  auto op_desc = std::make_shared<OpDesc>("No", "NoOp");
  auto node = graph->AddNode(op_desc);
  AttrUtils::SetBool(op_desc, "_need_infer_again", false);

  InferShapePass pass;
  pass.options_[kOptimizeAfterSubGraph] = "yes";
  EXPECT_EQ(pass.Run(node), SUCCESS);
}

TEST_F(UtestGraphInfershapePass, stop_node_for_while_loop) {
  // Builds a v1 while-loop graph and verifies that a full GEPass run with
  // InferShapePass terminates (inference stops at the loop boundary instead
  // of cycling forever through Merge/NextIteration).
  /*******************************************************************************
   *               Exit         Identity
   *                  \         /       \.
   *                   \       /         \.
   *                    Switch           Add
   *                   /     |            |
   *                  /      |            |
   *                 /       |            |
   *          LoopCond       |            |
   *                \        |            |
   *                 \       |            |
   *                  \      |            |
   *                  Less   |            |
   *                     \   |      NextIteration
   *                      \  |            |
   *                       \ |            |
   *                       Merge <--------|
   *                         |
   *                         |
   *                       Enter
   ******************************************************************************/
  auto graph = std::make_shared<ComputeGraph>("test_infer_shape");
  auto data1 = CreateNode(*graph, "data", DATA, 1, 1);
  auto enter1 = CreateNode(*graph, "enter", ENTER, 1, 1);
  auto merge1 = CreateNode(*graph, "merge", MERGE, 2, 2);
  auto less1 = CreateNode(*graph, "less", LESS, 2, 1);
  auto loop1 = CreateNode(*graph, "loopcond", LOOPCOND, 1, 1);
  auto switch1 = CreateNode(*graph, "switch", SWITCH, 2, 2);
  auto ident1 = CreateNode(*graph, "identity", IDENTITY, 1, 1);
  auto add1 = CreateNode(*graph, "add", ADD, 2, 1);
  auto next1 = CreateNode(*graph, "next", NEXTITERATION, 1, 1);
  auto exit1 = CreateNode(*graph, "exit", EXIT, 1, 1);
  auto value0 = CreateNode(*graph, "const", CONSTANT, 0, 1);  // added but left unconnected
  auto value1 = CreateNode(*graph, "const", CONSTANT, 0, 1);
  auto output1 = CreateNode(*graph, "net_output", NETOUTPUT, 1, 1);

  // Loop entry: Data -> Enter -> Merge, and the condition chain
  // Merge -> Less -> LoopCond.
  GraphUtils::AddEdge(data1->GetOutDataAnchor(0), enter1->GetInDataAnchor(0));
  GraphUtils::AddEdge(enter1->GetOutDataAnchor(0), merge1->GetInDataAnchor(0));
  GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), less1->GetInDataAnchor(0));
  GraphUtils::AddEdge(value1->GetOutDataAnchor(0), less1->GetInDataAnchor(1));
  GraphUtils::AddEdge(less1->GetOutDataAnchor(0), loop1->GetInDataAnchor(0));

  // Switch consumes the predicate (input 0) and the loop data (input 1).
  GraphUtils::AddEdge(loop1->GetOutDataAnchor(0), switch1->GetInDataAnchor(0));
  GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), switch1->GetInDataAnchor(1));

  // False branch leaves the loop via Exit; true branch runs the body.
  GraphUtils::AddEdge(switch1->GetOutDataAnchor(0), exit1->GetInDataAnchor(0));
  GraphUtils::AddEdge(switch1->GetOutDataAnchor(1), ident1->GetInDataAnchor(0));

  // Loop body: Identity -> Add -> NextIteration.
  GraphUtils::AddEdge(ident1->GetOutDataAnchor(0), add1->GetInDataAnchor(0));
  GraphUtils::AddEdge(value1->GetOutDataAnchor(0), add1->GetInDataAnchor(1));
  GraphUtils::AddEdge(add1->GetOutDataAnchor(0), next1->GetInDataAnchor(0));

  // Back-edge closing the loop, and the graph output after Exit.
  GraphUtils::AddEdge(next1->GetOutDataAnchor(0), merge1->GetInDataAnchor(1));
  GraphUtils::AddEdge(exit1->GetOutDataAnchor(0), output1->GetInDataAnchor(0));

  GEPass ge_passes(graph);
  NamesToPass names_to_passes;
  InferShapePass infer_shape_pass;
  names_to_passes.emplace_back("InferShapePass", &infer_shape_pass);

  // The whole pass run must terminate successfully despite the cycle.
  EXPECT_EQ(ge_passes.Run(names_to_passes), SUCCESS);
}
} // namespace ge
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <gtest/gtest.h>
#define protected public
#define private public
#include "graph/passes/infershape_pass.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/graph_utils.h"
#include "graph/operator_factory.h"
#include "graph/operator_reg.h"
#include "graph_builder_utils.h"
using namespace std;
using namespace testing;
namespace ge {
// Test fixture for InferShapePass unit tests. No shared state is required,
// so SetUp/TearDown are intentionally empty.
class UtestGraphInfershapePass : public testing::Test {
 protected:
  void SetUp() {}
  void TearDown() {}
};
// Helper: adds a node of `type` called `name` to `graph`, giving it `in_num`
// input and `out_num` output descs (scalar FLOAT/NCHW, size 512, offset 1024)
// and stub infer/format/verify functions that always return GRAPH_SUCCESS.
static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num) {
  OpDescPtr desc = std::make_shared<OpDesc>(name, type);
  desc->SetStreamId(0);
  static int32_t index = 0;  // unique id shared across all helper-created nodes
  desc->SetId(index++);

  GeTensorDesc tensor(GeShape(), FORMAT_NCHW, DT_FLOAT);
  TensorUtils::SetSize(tensor, 512);

  vector<int64_t> input_offsets;
  for (int idx = 0; idx < in_num; ++idx) {
    desc->AddInputDesc(tensor);
    input_offsets.push_back(1024);
  }
  desc->SetInputOffset(input_offsets);

  vector<int64_t> output_offsets;
  for (int idx = 0; idx < out_num; ++idx) {
    desc->AddOutputDesc(tensor);
    output_offsets.push_back(1024);
  }
  desc->SetOutputOffset(output_offsets);

  desc->SetWorkspace({});
  desc->SetWorkspaceBytes({});
  desc->SetOpKernelLibName("DNN_VM_RTS_OP_STORE");

  const auto stub_func = [](Operator &op) { return GRAPH_SUCCESS; };
  desc->AddInferFunc(stub_func);
  desc->AddInferFormatFunc(stub_func);
  desc->AddVerifierFunc(stub_func);

  return graph.AddNode(desc);
}
TEST_F(UtestGraphInfershapePass, infershape_pass_failed) {
  // Unknown-rank input ({-2, ...}) on an AddN makes the InferBasePass-derived
  // pass fail with GRAPH_FAILED.
  GeTensorDesc desc(GeShape({-2, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT16);
  auto addn_desc = std::make_shared<OpDesc>("AddN", "AddN");
  addn_desc->AddInputDesc(desc);
  addn_desc->AddOutputDesc(desc);
  auto graph = std::make_shared<ComputeGraph>("test");
  auto addn = std::make_shared<Node>(addn_desc, graph);
  addn->Init();

  InferShapePass pass;
  EXPECT_EQ(pass.Run(addn), GRAPH_FAILED);
}
TEST_F(UtestGraphInfershapePass, stop_node_for_while_loop) {
  // Builds a v1 while-loop graph and verifies the suspend/resume mechanism:
  // running infershape on the Switch suspends the downstream Exit node, and
  // OnSuspendNodesLeaked later resumes it.
  /*******************************************************************************
   *               Exit         Identity
   *                  \         /       \.
   *                   \       /         \.
   *                    Switch           Add
   *                   /     |            |
   *                  /      |            |
   *                 /       |            |
   *          LoopCond       |            |
   *                \        |            |
   *                 \       |            |
   *                  \      |            |
   *                  Less   |            |
   *                     \   |      NextIteration
   *                      \  |            |
   *                       \ |            |
   *                       Merge <--------|
   *                         |
   *                         |
   *                       Enter
   ******************************************************************************/
  auto graph = std::make_shared<ComputeGraph>("test_infer_shape");
  auto data1 = CreateNode(*graph, "data", DATA, 1, 1);
  auto enter1 = CreateNode(*graph, "enter", ENTER, 1, 1);
  auto merge1 = CreateNode(*graph, "merge", MERGE, 2, 2);
  auto less1 = CreateNode(*graph, "less", LESS, 2, 1);
  auto loop1 = CreateNode(*graph, "loopcond", LOOPCOND, 1, 1);
  auto switch1 = CreateNode(*graph, "switch", SWITCH, 2, 2);
  auto ident1 = CreateNode(*graph, "identity", IDENTITY, 1, 1);
  auto add1 = CreateNode(*graph, "add", ADD, 2, 1);
  auto next1 = CreateNode(*graph, "next", NEXTITERATION, 1, 1);
  auto exit1 = CreateNode(*graph, "exit", EXIT, 1, 1);
  auto value0 = CreateNode(*graph, "const", CONSTANT, 0, 1);  // added but left unconnected
  auto value1 = CreateNode(*graph, "const", CONSTANT, 0, 1);
  auto output1 = CreateNode(*graph, "net_output", NETOUTPUT, 1, 1);
  // Loop entry and condition chain: Data -> Enter -> Merge -> Less -> LoopCond.
  GraphUtils::AddEdge(data1->GetOutDataAnchor(0), enter1->GetInDataAnchor(0));
  GraphUtils::AddEdge(enter1->GetOutDataAnchor(0), merge1->GetInDataAnchor(0));
  GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), less1->GetInDataAnchor(0));
  GraphUtils::AddEdge(value1->GetOutDataAnchor(0), less1->GetInDataAnchor(1));
  GraphUtils::AddEdge(less1->GetOutDataAnchor(0), loop1->GetInDataAnchor(0));
  // Switch: loop data on input 0, predicate on input 1 (anchors swapped
  // relative to the legacy test).
  GraphUtils::AddEdge(loop1->GetOutDataAnchor(0), switch1->GetInDataAnchor(1));
  GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), switch1->GetInDataAnchor(0));
  // False branch exits the loop; true branch runs the body.
  GraphUtils::AddEdge(switch1->GetOutDataAnchor(0), exit1->GetInDataAnchor(0));
  GraphUtils::AddEdge(switch1->GetOutDataAnchor(1), ident1->GetInDataAnchor(0));
  // Loop body: Identity -> Add -> NextIteration, back-edge into Merge,
  // and the graph output after Exit.
  GraphUtils::AddEdge(ident1->GetOutDataAnchor(0), add1->GetInDataAnchor(0));
  GraphUtils::AddEdge(value1->GetOutDataAnchor(0), add1->GetInDataAnchor(1));
  GraphUtils::AddEdge(add1->GetOutDataAnchor(0), next1->GetInDataAnchor(0));
  GraphUtils::AddEdge(next1->GetOutDataAnchor(0), merge1->GetInDataAnchor(1));
  GraphUtils::AddEdge(exit1->GetOutDataAnchor(0), output1->GetInDataAnchor(0));
  GEPass ge_passes(graph);
  NamesToPass names_to_passes;
  InferShapePass infer_shape_pass;
  names_to_passes.emplace_back("InferShapePass", &infer_shape_pass);
  // Running the pass on Switch succeeds and suspends the Exit successor.
  EXPECT_EQ(infer_shape_pass.Run(switch1), SUCCESS);
  auto suspend_nodes = infer_shape_pass.GetNodesSuspend();
  auto exit_node = graph->FindNode("exit");
  EXPECT_EQ(suspend_nodes.count(exit_node), 1);
  // Leaked-suspend handling must move the Exit node into the resume set.
  infer_shape_pass.OnSuspendNodesLeaked();
  auto resume_nodes = infer_shape_pass.GetNodesResume();
  EXPECT_EQ(resume_nodes.count(exit_node), 1);
}
TEST_F(UtestGraphInfershapePass, update_tensordesc_when_changed) {
  // src and dst differ in dim 0, so the update must report a change and copy
  // the source shape into dst.
  auto src = std::make_shared<GeTensorDesc>(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT16);
  auto dst = std::make_shared<GeTensorDesc>(GeShape({2, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT16);

  InferShapePass pass;
  bool changed = false;
  pass.UpdateTensorDesc(src, dst, changed);
  EXPECT_EQ(changed, true);
  EXPECT_EQ(dst->GetShape().GetDims(), std::vector<int64_t>({1, 2, 3, 4}));
}
TEST_F(UtestGraphInfershapePass, update_tensordesc_when_not_changed) {
  // Identical src and dst descs: the update must not report a change.
  auto src = std::make_shared<GeTensorDesc>(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT16);
  auto dst = std::make_shared<GeTensorDesc>(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT16);

  InferShapePass pass;
  bool changed = false;
  pass.UpdateTensorDesc(src, dst, changed);
  EXPECT_EQ(changed, false);
}
TEST_F(UtestGraphInfershapePass, update_output_from_subgraphs_failed) {
  // Subgraph outputs with different dtypes (FLOAT16 vs FLOAT) cannot be
  // merged, so the update must fail.
  auto branch_out0 = std::make_shared<GeTensorDesc>(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT16);
  auto branch_out1 = std::make_shared<GeTensorDesc>(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT);
  auto parent_out = std::make_shared<GeTensorDesc>(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT);

  InferShapePass pass;
  auto ret = pass.UpdateOutputFromSubgraphs({branch_out0, branch_out1}, parent_out);
  EXPECT_EQ(ret, GRAPH_FAILED);
}
TEST_F(UtestGraphInfershapePass, update_output_from_subgraphs_get_unknown_rank) {
  // Subgraph outputs with different ranks (3 vs 4) merge to unknown rank.
  auto branch_out0 = std::make_shared<GeTensorDesc>(GeShape({1, 2, 3}), ge::FORMAT_NCHW, DT_FLOAT);
  auto branch_out1 = std::make_shared<GeTensorDesc>(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT);
  auto parent_out = std::make_shared<GeTensorDesc>(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT);

  InferShapePass pass;
  auto ret = pass.UpdateOutputFromSubgraphs({branch_out0, branch_out1}, parent_out);
  EXPECT_EQ(ret, SUCCESS);
  EXPECT_EQ(parent_out->GetShape().GetDims(), UNKNOWN_RANK);
}
TEST_F(UtestGraphInfershapePass, update_output_from_subgraphs_get_unknown_shape) {
  // Same rank but a mismatching dim (1 vs 2) merges into a dynamic (-1) dim.
  auto branch_out0 = std::make_shared<GeTensorDesc>(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT);
  auto branch_out1 = std::make_shared<GeTensorDesc>(GeShape({2, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT);
  auto parent_out = std::make_shared<GeTensorDesc>(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT);

  InferShapePass pass;
  auto ret = pass.UpdateOutputFromSubgraphs({branch_out0, branch_out1}, parent_out);
  EXPECT_EQ(ret, SUCCESS);
  EXPECT_EQ(parent_out->GetShape().GetDims(), std::vector<int64_t>({-1, 2, 3, 4}));
  // TODO(review): also verify the merged shape range once its contract is defined.
}
TEST_F(UtestGraphInfershapePass, update_output_from_subgraphs_for_multiDims_failed) {
  // Multi-dims merge rejects ref outputs with different dtypes.
  auto branch_out0 = std::make_shared<GeTensorDesc>(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT16);
  auto branch_out1 = std::make_shared<GeTensorDesc>(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT);
  auto parent_out = std::make_shared<GeTensorDesc>(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT);

  InferShapePass pass;
  auto ret = pass.UpdateOutputFromSubgraphsForMultiDims({branch_out0, branch_out1}, parent_out);
  EXPECT_EQ(ret, GRAPH_FAILED);
}
TEST_F(UtestGraphInfershapePass, update_output_from_subgraphs_for_multiDims_failed_shape_size_overflow) {
  // A shape whose element count overflows int64 ({INT64_MAX, 2, 3, 4}) must
  // be rejected with PARAM_INVALID.
  auto branch_out0 = std::make_shared<GeTensorDesc>(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT);
  auto branch_out1 = std::make_shared<GeTensorDesc>(GeShape({INT64_MAX, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT);
  auto parent_out = std::make_shared<GeTensorDesc>(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT);

  InferShapePass pass;
  auto ret = pass.UpdateOutputFromSubgraphsForMultiDims({branch_out0, branch_out1}, parent_out);
  EXPECT_EQ(ret, PARAM_INVALID);
}
TEST_F(UtestGraphInfershapePass, update_output_from_subgraphs_for_multiDims_success) {
  // Multi-dims merge keeps the source shape with the largest element count:
  // {2,2,3,4} (48 elems) wins over {1,2,3,4} (24 elems).
  auto branch_out0 = std::make_shared<GeTensorDesc>(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT);
  auto branch_out1 = std::make_shared<GeTensorDesc>(GeShape({2, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT);
  auto parent_out = std::make_shared<GeTensorDesc>(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT);

  InferShapePass pass;
  auto ret = pass.UpdateOutputFromSubgraphsForMultiDims({branch_out0, branch_out1}, parent_out);
  EXPECT_EQ(ret, SUCCESS);
  EXPECT_EQ(parent_out->GetShape().GetDims(), std::vector<int64_t>({2, 2, 3, 4}));
}
} // namespace ge

Loading…
Cancel
Save