
!1206 Support params sharing in mindconverter

From: @liuchongming74
Reviewed-by: 
Signed-off-by:
tags/v1.2.0-rc1
mindspore-ci-bot 4 years ago
parent commit 6c4e1a67b2
9 changed files with 130 additions and 148 deletions
  1. +16 -55 mindinsight/mindconverter/graph_based_converter/common/code_fragment.py
  2. +3 -0 mindinsight/mindconverter/graph_based_converter/common/global_context.py
  3. +1 -0 mindinsight/mindconverter/graph_based_converter/constant.py
  4. +0 -23 mindinsight/mindconverter/graph_based_converter/generator/module_struct.py
  5. +25 -1 mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/built_in_pattern.py
  6. +38 -8 mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py
  7. +31 -1 mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py
  8. +1 -57 mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py
  9. +15 -3 mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py

+16 -55 mindinsight/mindconverter/graph_based_converter/common/code_fragment.py

@@ -163,61 +163,6 @@ class Fragment(abc.ABC):
return self._output_shape


class CodeFragment(Fragment):
"""
Manage the variables related with code generation.

For single operation type node, the variables in `CodeLine` stands for:
```python
class Module(nn.Cell):
def __init__ (self, ...):
super(Module, self).__init__()
self.<CodeLine.declared_variable_name> = <CodeLine.operation>(<CodeLine.scalar_args>,
<CodeLine.init_trainable_params>)
self.<CodeLine.trainable_params[k].param_name> = Tensor(<CodeLine.trainable_params[k].shape>,
dtype=<CodeLine._trainable_params[k].dtype>)

def construct(self, x, ...):
<CodeLine.output_var_name> = self.<CodeLine.declared_variable_name>(<CodeLine.operation_inputs>)
...
return output
```

Args:
operation (str): Operation name in MindSpore.
actual_args (dict): Actual arg values.
settings (namedTuple): Code generation setting.

"""

def __init__(self, operation, actual_args, settings, input_shape, output_shape,
trainable_params=None, trainable_weights=None):
super(CodeFragment, self).__init__(operation=operation, actual_args=actual_args,
input_shape=input_shape, output_shape=output_shape,
settings=settings)
self._trainable_params = dict() # External weights, like Matmul.
self._init_trainable_params = trainable_params # Can put into operation init method, like Conv2d.
self._trainable_weights = trainable_weights

@property
def trainable_params(self):
"""Return the trainable parameters."""
return self._trainable_params

@property
def trainable_weights(self):
return self._trainable_weights


class ModuleFragment(Fragment):
"""Manage module type code variables."""

def __init__(self, operation, actual_args, settings, input_shape, output_shape):
super(ModuleFragment, self).__init__(operation=operation, actual_args=actual_args,
input_shape=input_shape, output_shape=output_shape,
settings=settings)


class NewFragment:
"""
Fragment definition for MindSpore code generation.
@@ -310,6 +255,12 @@ class NewFragment:
return f"{opt}[{inner_idx}]"
return opt

@staticmethod
def create_parameter(weight_shape, weight_dtype):
"""Create a parameter code line."""
return f"Parameter(Tensor(np.random.uniform(0, 1, {weight_shape}).astype(np.{weight_dtype})), " \
f"name=None)"

def __call__(self) -> Tuple[List[str], List[str]]:
"""
Define parameter rewrite function.
@@ -334,6 +285,10 @@ class NewFragment:
ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value)
return init_stats, call_stats

def register_parameter(self, var, line):
"""Append a new parameter into template."""
self._code_template[var][TemplateKeywords.INIT.value].append(line)

@staticmethod
def _rewrite(var, data, template: str) -> str:
"""
@@ -353,6 +308,12 @@ class NewFragment:
data[ExchangeMessageKeywords.VariableScope.value.INPUTS.value])
if ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value in data:
rewrite_data.update(data[ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value])
if ExchangeMessageKeywords.VariableScope.value.PARAMETERS_DECLARED.value in data:
rewrite_params = {
f"{var}/{slot}": data[ExchangeMessageKeywords.VariableScope.value.PARAMETERS_DECLARED.value][slot]
for slot in data[ExchangeMessageKeywords.VariableScope.value.PARAMETERS_DECLARED.value]
}
rewrite_data.update(rewrite_params)
rewrite_data.update(data[ExchangeMessageKeywords.VariableScope.value.ARGS.value])
return template.format(**{
k: str(rewrite_data[k]) for k in rewrite_data


+3 -0 mindinsight/mindconverter/graph_based_converter/common/global_context.py

@@ -81,6 +81,9 @@ class GlobalContext(metaclass=Singleton):

self.outputs_storage = OutputStorage()

# Record weight names that are used multiple times.
self.repeated_weights = dict()

def get_onnx_node_from_identifier(self, identifier):
"""Return an OnnxUtils defined node by its identifier."""
onnx_node_name = self.node_struct_to_onnx_node_map.get(identifier)


+1 -0 mindinsight/mindconverter/graph_based_converter/constant.py

@@ -107,6 +107,7 @@ class ExchangeMessageKeywords(Enum):
ARGS = "args"
WEIGHTS = "weights"
TRAINABLE_PARAMS = "trainable_params"
PARAMETERS_DECLARED = "parameters"


BINARY_HEADER_PYTORCH_FILE = \


+0 -23 mindinsight/mindconverter/graph_based_converter/generator/module_struct.py

@@ -21,7 +21,6 @@ from mindinsight.mindconverter.graph_based_converter.generator.node_struct impor
from mindinsight.mindconverter.graph_based_converter.generator.scope_utils import Scope
from mindinsight.mindconverter.graph_based_converter.common.utils import get_dict_key_by_value
from mindinsight.mindconverter.graph_based_converter.generator.args_translator import ArgsTranslation
from mindinsight.mindconverter.graph_based_converter.common.code_fragment import ModuleFragment
from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext
from mindinsight.mindconverter.graph_based_converter.common.name_mgr import LocalVarNameMgr

@@ -184,28 +183,6 @@ class ModuleStruct:
"initialized": self.initialized
})

def init_module_fragment(self):
"""Init the module fragment."""
if not self.initialized:
return
# check if fragment exists in global context
op = "Module{}".format(self.pattern_id)
if op == "Module-1": # reset as Main Model's op name
op = "Model"
frag = GlobalContext().get_module_fragment(op)
if frag is not None: # use exists fragment
self._fragment = frag
else:
frag = ModuleFragment(operation=op,
actual_args=None,
input_shape=None,
output_shape=None,
settings=None)
self._fragment = frag
# set fragment pattern
self._fragment.pattern = self._node_structs
GlobalContext().add_module_fragment(op, frag)

def init_args_translator(self):
"""Initialize the Args Translator for the module."""
var_name = self.ms_var_name


+25 -1 mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/built_in_pattern.py

@@ -14,7 +14,10 @@
# ==============================================================================
"""Introduce some standard pattern into MindConverter."""

__all__ = ["BUILT_IN_PATTERN", "register_pattern", "is_built_in_pattern"]
__all__ = ["BUILT_IN_PATTERN", "register_pattern", "is_built_in_pattern",
"USER_DEFINED_PATTERN", "user_defined_pattern"]

from collections import OrderedDict

from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.known_module_name import register_module_name

@@ -22,6 +25,7 @@ from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common i
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern import Pattern

BUILT_IN_PATTERN = dict()
USER_DEFINED_PATTERN = OrderedDict()


def is_built_in_pattern(pattern: Pattern):
@@ -75,6 +79,26 @@ def register_pattern(ptn_name, in_degree, out_degree):
return _reg


def user_defined_pattern(pattern_name: str):
"""
Register a user-defined pattern with MindConverter.

Args:
pattern_name (str): Pattern name.
"""

def _f(ptn):
pattern = ptn()
if not pattern:
raise ValueError("`ptn` cannot be None.")
if not pattern_name:
raise ValueError("`pattern_name` cannot be None.")
USER_DEFINED_PATTERN[pattern_name] = pattern
return ptn

return _f


@register_pattern("ConvBnClip", 1, 1)
def _conv_bn_clip():
"""Add conv-bn-clip pattern."""


+38 -8 mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py

@@ -424,7 +424,18 @@ def _post_process_overlap(patterns) -> Dict:
return patterns


class SearchPath:
class BasePath:
"""Base class of SearchPath (auto-search) and ReplacePath (greedy-match)."""

def __init__(self, pattern, sequence: List, prev_path=None):
self.pattern = pattern
self.recursion_path = prev_path.recursion_path[:] if prev_path is not None else list()
if prev_path is not None:
self.recursion_path.append(prev_path)
self.topo_order_bef_repl = sequence


class SearchPath(BasePath):
"""
Use SearchPath to store the search path.

@@ -439,15 +450,9 @@ class SearchPath:

def __init__(self, pattern, sequence: List[BaseNode], prev_path=None,
graph=None, sub_graph_size: int = 2):
self.pattern = pattern
super(SearchPath, self).__init__(pattern, sequence, prev_path)
self.graph = copy.copy(prev_path.graph) if prev_path is not None \
else copy.copy(graph)
self.recursion_path = prev_path.recursion_path[:] \
if prev_path is not None else list()
if prev_path is not None:
self.recursion_path.append(prev_path)

self.topo_order_bef_repl = sequence
self.topo_order_aft_repl, self.inverted_index = self._create_new_order()
self.node_collection = dict()
self.hash_of_aft_repl = gen_hash_key(self.topo_order_aft_repl)
@@ -689,3 +694,28 @@ class SearchPath:
f"H: {self.heuristic_v}, G: {self.actual_v}, E: {self.evaluate_score()}"

return repr_str


class ReplacePath(BasePath):
"""Data struct of replacing path with greedy matching."""

def __init__(self, pattern, sequence: List, prev_path=None):
super(ReplacePath, self).__init__(pattern, sequence, prev_path)
self.topo_order_aft_repl = None

def replace(self, increment_idx):
"""
Greedy matching.

Args:
increment_idx (int): Index used to deduplicate module names.
"""
src = ",".join(self.topo_order_bef_repl)
tgt = self.pattern.pattern
md_name = f"Module{increment_idx}"
src_aft_repl = src.replace(tgt, md_name)
if src != src_aft_repl:
self.pattern.module_name = md_name
self.topo_order_aft_repl = src_aft_repl.split(",")
return md_name
return None
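
The greedy substitution above works on comma-joined op names; a minimal standalone sketch of the same string manipulation, without the Pattern/ReplacePath objects:

```python
topo_order = ["Conv", "BatchNormalization", "Relu", "Conv", "BatchNormalization", "Relu", "Gemm"]
pattern = "Conv,BatchNormalization,Relu"
module_name = "Module0"

src = ",".join(topo_order)
src_aft_repl = src.replace(pattern, module_name)
if src != src_aft_repl:
    print(src_aft_repl.split(","))
# ['Module0', 'Module0', 'Gemm']
```
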

+31 -1 mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py

@@ -16,6 +16,7 @@
from queue import PriorityQueue, Queue
from typing import Dict, List

from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.built_in_pattern import USER_DEFINED_PATTERN
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern_fuzzy_matching import \
pattern_fuzzy_matching
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import context, DagGraph, gen_hash_key, \
@@ -25,7 +26,7 @@ from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common i
from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext
from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import BaseNode
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.search_path import SearchPath, Pattern, \
generate_pattern, find_built_in_pattern
generate_pattern, find_built_in_pattern, ReplacePath
from mindinsight.mindconverter.common.exceptions import SubGraphSearchingError


@@ -371,6 +372,7 @@ def _build_connection(loader):
context.precursor_table[node_name] = list(node.get_precursor_dict().keys())
context.successor_table[node_name] = list(node.get_successor_dict().keys())
context.outputs_table[node_name] = node.output_name_list

# Record the model input count; it is used to control the search algorithm.
context.has_multi_inputs = len(loader.input_nodes) > 1
dag = DagGraph(nodes=context.node_collection.copy(),
@@ -426,6 +428,28 @@ def _add_known_module_name(search_path):
return ctx


def greedy_match(topo_order, user_defined_ptn):
"""
Greedily replace the topological order with user-defined patterns.

Args:
topo_order (list[str]): Topological order sequence.
user_defined_ptn (dict): User-defined patterns.
"""
increment_idx = 0
prev_path = None
for md_name, ptn in user_defined_ptn:
ptn = Pattern(",".join(ptn), len(ptn), -1, -1, ptn)
ptn.known_module_name = md_name
topo_order_aft_rpl = topo_order[:] if prev_path is None else prev_path.topo_order_aft_repl
repl_path = ReplacePath(ptn, topo_order_aft_rpl, prev_path=prev_path)
module_name = repl_path.replace(increment_idx)
if module_name is not None:
increment_idx += 1
prev_path = repl_path
return prev_path
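
A usage sketch under two stated assumptions: the topological order here is a list of plain op-name strings (the real call site uses the loader's node collection), and we pass `(name, ops)` pairs explicitly because each iteration unpacks two values. Pattern and module names are illustrative:

```python
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.searcher import greedy_match

user_patterns = {
    "ConvBnRelu": ["Conv", "BatchNormalization", "Relu"],
    "DenseRelu": ["Gemm", "Relu"],
}
topo = ["Conv", "BatchNormalization", "Relu", "Gemm", "Relu"]

repl_path = greedy_match(topo, user_patterns.items())
if repl_path is not None:
    print(repl_path.topo_order_aft_repl)  # ['Module0', 'Module1']
```
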


@SubGraphSearchingError.check_except("Sub-Graph pattern searching fail.")
def generate_scope_name(data_loader):
"""
@@ -439,6 +463,12 @@ def generate_scope_name(data_loader):
"""
init_dag = _build_connection(data_loader)
try:
if USER_DEFINED_PATTERN:
topo_order = [node for _, node in context.node_collection.items()]
repl_path = greedy_match(topo_order, USER_DEFINED_PATTERN)
topo_order_with_scope_name_list = _retrieve_scope_name(repl_path) if repl_path else flatten_graph(init_dag)
return topo_order_with_scope_name_list

result = _sub_graph_matching(init_dag, beam_width=5, sub_graph_size=6)
topo_order_with_scope_name_list = _retrieve_scope_name(result) if result else flatten_graph(init_dag)



+1 -57 mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py

@@ -15,14 +15,11 @@
"""Define graph entity."""
import abc
from collections import OrderedDict
from copy import deepcopy

from typing import List

from mindinsight.mindconverter.common.log import logger as log
from mindinsight.mindconverter.graph_based_converter.common.code_fragment import CodeFragment
from mindinsight.mindconverter.graph_based_converter.constant import NodeType, InputType
from mindinsight.mindconverter.graph_based_converter.mapper.base import Mapper
from mindinsight.mindconverter.graph_based_converter.constant import InputType
from mindinsight.mindconverter.common.exceptions import NodeInputTypeNotSupportError


@@ -574,56 +571,3 @@ class GraphNode(abc.ABC):
ipt_args_settings_in_construct = ', '.join((ipt_args_settings_in_construct, settings_in_construct))

return ipt_args_settings_in_construct

def param_transform(self, mapper: Mapper, variable_name):
"""
Transform param in PyTorch operation into MindSpore.

Args:
variable_name (str): Variable name.
mapper (ONNXToMindSporeMapper): Mapper between onnx operation
and MindSpore.

Returns:
dict, transformed params.
"""
if self._node_type != NodeType.OPERATION.value:
args = deepcopy(self._args_in_code)
self._args_in_code = dict()
for arg, value in args.items():
self._args_in_code[self._get_arg_name(arg, variable_name)] = value
return CodeFragment(operation="", actual_args=args, settings=None,
input_shape=self.input_shape, output_shape=self.output_shape)

if self.transformed:
raise ValueError("Already transformed.")

params = deepcopy(self._op_params)
params.update({"input_shape": self.input_shape,
"output_shape": self.output_shape})

ms_op, ms_params, ms_settings, ms_weights = mapper.convert(op_name=self.op_name,
params=params,
weights=self._weight)

if ms_op:
code_fragment = CodeFragment(operation=ms_op,
actual_args=ms_params,
settings=ms_settings,
input_shape=self.input_shape,
output_shape=self.output_shape,
trainable_params=ms_weights)
else:
code_fragment = CodeFragment(operation=self._op_name,
actual_args=self._op_params,
settings=None,
input_shape=self.input_shape,
output_shape=self.output_shape,
trainable_params=self._weight)

for arg, value in code_fragment.actual_args.items():
self._args_in_code[self._get_arg_name(arg, variable_name)] = value

self.transformed = True

return code_fragment

+15 -3 mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py

@@ -262,7 +262,6 @@ class OnnxDataLoader:
self.model = onnx_model_sim
self.graph = onnx_model_sim.graph
self.nodes = onnx_model_sim.graph.node
self.batch_size = list(input_nodes.values())[0][0]
self.input_nodes = input_nodes
self.output_nodes = output_nodes
# args for init
@@ -275,6 +274,9 @@ class OnnxDataLoader:
self.tensors_dict = {} # {tensor_name: OnnxTensor}
self.value_info_dict = {} # Not contains input and output nodes

# Record the weight names used multiple times.
self.repeated_weight = dict()

self.node_output_shape_dict = OrderedDict() # {node_name: [int]}

# Key is edge of ONNX ir graph, value is the corresponding precursor node.
@@ -393,6 +395,7 @@ class OnnxDataLoader:
def _parse_nodes(self):
"""Parse each onnx nodes in the model."""
nodes_topo_idx = []
record_tensors = dict()
for idx, node in enumerate(self.nodes):
if not node.name:
node.name = "_".join(node.output)
@@ -408,11 +411,21 @@ class OnnxDataLoader:
self._global_context.onnx_node_inputs[n.name].append(ipt_nd)
else:
self._global_context.onnx_node_inputs[n.name] = [ipt_nd]
if ipt_nd in self.tensors_dict:
if ipt_nd not in record_tensors:
record_tensors[ipt_nd] = [node.name]
continue
record_tensors[ipt_nd].append(node.name)
self.repeated_weight.setdefault(ipt_nd, [])

self._global_context.onnx_node_name_to_topo_idx[n.name] = idx

for k in self.repeated_weight:
self.repeated_weight[k] = record_tensors[k][:]

self._global_context.onnx_nodes_collection = self._nodes_dict
self._global_context.onnx_nodes_topo_index = nodes_topo_idx
self._global_context.repeated_weights = self.repeated_weight
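
Stripped of the loader plumbing, the bookkeeping above boils down to the following standalone sketch: a tensor (initializer) that appears as an input of more than one node is recorded as a shared weight. Tensor and node names are illustrative:

```python
tensors_dict = {"fc.weight": "initializer"}  # ONNX initializers known to the loader
node_inputs = {                              # node name -> its input names
    "Gemm_0": ["x", "fc.weight"],
    "Gemm_1": ["h", "fc.weight"],
}

record_tensors, repeated_weight = {}, {}
for node_name, inputs in node_inputs.items():
    for ipt in inputs:
        if ipt not in tensors_dict:
            continue
        if ipt not in record_tensors:
            record_tensors[ipt] = [node_name]
            continue
        record_tensors[ipt].append(node_name)
        repeated_weight.setdefault(ipt, [])

for k in repeated_weight:
    repeated_weight[k] = record_tensors[k][:]

print(repeated_weight)  # {'fc.weight': ['Gemm_0', 'Gemm_1']}
```
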

def _parse_tensors(self):
"""Parse each onnx tensors in the model."""
@@ -436,8 +449,7 @@ class OnnxDataLoader:
for i, s in enumerate(shape):
if 'unk' in s:
# Have to adapt user-define axis name, e.g. 'sequence', 'batch'.
shape[i] = self.batch_size if self.batch_size is not None else 1
continue
raise ValueError(f"cannot get shape of {node_opt_name}.")
if s == "scalar":
shape = SCALAR_WITHOUT_SHAPE
continue

