Browse Source

Optimization for ApplyTransform function

tags/v0.6.0-beta
wuyongkang 5 years ago
parent
commit
02dd305bb0
2 changed files with 19 additions and 8 deletions
  1. +14
    -6
      mindspore/ccsrc/optimizer/opt.cc
  2. +5
    -2
      tests/ut/cpp/common/py_func_graph_fetcher.h

+ 14
- 6
mindspore/ccsrc/optimizer/opt.cc View File

@@ -96,16 +96,18 @@ AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNode
return result;
}

static bool isTraversable(const AnfNodePtr &node) {
if (node == nullptr) {
return false;
}
static bool inline isTraversable(const AnfNodePtr &node, const AnfNodeSet &all_nodes) {
if (node->isa<CNode>() || node->isa<Parameter>()) {
return true;
return false;
}

if (IsValueNode<FuncGraph>(node) || IsValueNode<RefKey>(node)) {
if (!all_nodes.contains(node)) {
return false;
}
return true;
}

return false;
}

@@ -128,9 +130,15 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNo
todo.pop_front();

// check whether this node has been matched.
if (node == nullptr || node->seen_ == seen || !isTraversable(node) || !all_nodes.contains(node)) {
if (node == nullptr || node->seen_ == seen) {
continue;
}

auto fg = node->func_graph();
if (!(fg != nullptr && fg->manager() != nullptr) && !isTraversable(node, all_nodes)) {
continue;
}

node->seen_ = seen;

// select nodes that this transform can be applied.


+ 5
- 2
tests/ut/cpp/common/py_func_graph_fetcher.h View File

@@ -22,6 +22,7 @@
#include "ir/primitive.h"
#include "ir/manager.h"
#include "ir/func_graph.h"
#include "ir/func_graph_cloner.h"
#include "pipeline/parse/parse_base.h"
#include "pipeline/parse/parse.h"
#include "./common.h"
@@ -47,9 +48,10 @@ class PyFuncGraphFetcher {
py::function fn = mindspore::parse::python_adapter::CallPyFn(model_path_.c_str(), func_name.c_str(), args...);
mindspore::FuncGraphPtr func_graph = mindspore::parse::ParsePythonCode(fn);
if (doResolve_) {
std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(func_graph, false);
std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(func_graph, true);
mindspore::parse::python_adapter::set_use_signature_in_resolve(false);
mindspore::parse::ResolveAll(manager);
func_graph = BasicClone(func_graph);
}
return func_graph;
} catch (py::error_already_set& e) {
@@ -71,8 +73,9 @@ class PyFuncGraphFetcher {
py::function fn = mindspore::parse::python_adapter::GetPyFn(path.c_str(), func_name.c_str());
mindspore::FuncGraphPtr func_graph = mindspore::parse::ParsePythonCode(fn);
if (doResolve_) {
std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(func_graph, false);
std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(func_graph, true);
mindspore::parse::ResolveAll(manager);
func_graph = BasicClone(func_graph);
}
return func_graph;
} catch (py::error_already_set& e) {


Loading…
Cancel
Save