Browse Source

!13268 add SetitemTupleEliminator to item_tuple_or_list_eliminate pass

From: @huangbingjian
Reviewed-by: @ginfung,@zh_qh
Signed-off-by: @zh_qh
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
a018390e40
1 changed files with 27 additions and 3 deletions
  1. +27
    -3
      mindspore/ccsrc/frontend/optimizer/irpass/item_tuple_or_list_eliminate.h

+ 27
- 3
mindspore/ccsrc/frontend/optimizer/irpass/item_tuple_or_list_eliminate.h View File

@@ -191,14 +191,18 @@ class GetitemConstEliminator : public AnfVisitor {

// setitem((a, b, c, ...), 0, z) => (z, b, c, ...)
// setitem((a, b, c, ...), 1, z) => (a, z, c, ...)
// {prim::kPrimTupleSetItem, {prim::kPrimMakeTuple, Xs}, C, Z}
// {prim::kPrimListSetItem, {prim::kPrimMakeList, Xs}, C, Z}
// {prim::kPrimTupleSetItem, {prim::kPrimMakeTuple, a, b, c, ...}, 0, z} => {prim::kPrimMakeTuple, z, b, c, ...}
// {prim::kPrimListSetItem, {prim::kPrimMakeList, a, b, c, ...}, 0, z} => {prim::kPrimMakeList, z, b, c, ...}
// {prim::kPrimTupleSetItem, (a, b, c, ...), 0, z} => {prim::kPrimMakeTuple, z, b, c, ...}
// {prim::kPrimListSetItem, [a, b, c, ...], 0, z} => {prim::kPrimMakeList, z, b, c, ...}
class SetitemEliminator : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset();
AnfVisitor::Match(prim::kPrimTupleSetItem, {IsCNode, IsVNode, IsNode})(node);
AnfVisitor::Match(prim::kPrimListSetItem, {IsCNode, IsVNode, IsNode})(node);
AnfVisitor::Match(prim::kPrimTupleSetItem, {IsVNode, IsVNode, IsNode})(node);
AnfVisitor::Match(prim::kPrimListSetItem, {IsVNode, IsVNode, IsNode})(node);

auto fg = node->func_graph();
if (fg != nullptr && z_ != nullptr) {
@@ -225,7 +229,27 @@ class SetitemEliminator : public AnfVisitor {
}

void Visit(const ValueNodePtr &vnode) override {
if (!args_.empty() && IsValueNode<Int64Imm>(vnode)) {
if (args_.empty() && IsValueNode<ValueTuple>(vnode)) {
auto tuple = GetValueNode<ValueTuplePtr>(vnode);
if (tuple != nullptr) {
args_.emplace_back(NewValueNode(prim::kPrimMakeTuple));
for (auto &val : tuple->value()) {
auto val_node = std::make_shared<ValueNode>(val);
val_node->set_abstract(val->ToAbstract());
args_.emplace_back(val_node);
}
}
} else if (args_.empty() && IsValueNode<ValueList>(vnode)) {
auto list = GetValueNode<ValueListPtr>(vnode);
if (list != nullptr) {
args_.emplace_back(NewValueNode(prim::kPrimMakeList));
for (auto &val : list->value()) {
auto val_node = std::make_shared<ValueNode>(val);
val_node->set_abstract(val->ToAbstract());
args_.emplace_back(val_node);
}
}
} else if (!args_.empty() && IsValueNode<Int64Imm>(vnode)) {
auto idx = GetValue<int64_t>(vnode->value());
if (idx < 0) {
idx = idx + args_.size() - 1;


Loading…
Cancel
Save