diff --git a/mindinsight/mindconverter/graph_based_converter/common/code_fragment.py b/mindinsight/mindconverter/graph_based_converter/common/code_fragment.py
index 9d2fbd0c..f94db8c2 100644
--- a/mindinsight/mindconverter/graph_based_converter/common/code_fragment.py
+++ b/mindinsight/mindconverter/graph_based_converter/common/code_fragment.py
@@ -163,61 +163,6 @@ class Fragment(abc.ABC):
         return self._output_shape
 
 
-class CodeFragment(Fragment):
-    """
-    Manage the variables related with code generation.
-
-    For single operation type node, the variables in `CodeLine` stands for:
-    ```python
-    class Module(nn.Cell):
-        def __init__(self, ...):
-            super(Module, self).__init__()
-            self.<var> = <operation>(<args>,
-                                     <settings>)
-            self.<var_w> = Tensor(<weight>,
-                                  dtype=<dtype>)
-
-        def construct(self, x, ...):
-            <opt_var> = self.<var>(<inputs>)
-            ...
-            return output
-    ```
-
-    Args:
-        operation (str): Operation name in MindSpore.
-        actual_args (dict): Actual arg values.
-        settings (namedTuple): Code generation setting.
-
-    """
-
-    def __init__(self, operation, actual_args, settings, input_shape, output_shape,
-                 trainable_params=None, trainable_weights=None):
-        super(CodeFragment, self).__init__(operation=operation, actual_args=actual_args,
-                                           input_shape=input_shape, output_shape=output_shape,
-                                           settings=settings)
-        self._trainable_params = dict()  # External weights, like Matmul.
-        self._init_trainable_params = trainable_params  # Can put into operation init method, like Conv2d.
-        self._trainable_weights = trainable_weights
-
-    @property
-    def trainable_params(self):
-        """Return the trainable parameters."""
-        return self._trainable_params
-
-    @property
-    def trainable_weights(self):
-        return self._trainable_weights
-
-
-class ModuleFragment(Fragment):
-    """Manage module type code variables."""
-
-    def __init__(self, operation, actual_args, settings, input_shape, output_shape):
-        super(ModuleFragment, self).__init__(operation=operation, actual_args=actual_args,
-                                             input_shape=input_shape, output_shape=output_shape,
-                                             settings=settings)
-
-
 class NewFragment:
     """
     Fragment definition for MindSpore code generation.
@@ -310,6 +255,12 @@ class NewFragment:
             return f"{opt}[{inner_idx}]"
         return opt
 
+    @staticmethod
+    def create_parameter(weight_shape, weight_dtype):
+        """Create a parameter code line."""
+        return f"Parameter(Tensor(np.random.uniform(0, 1, {weight_shape}).astype(np.{weight_dtype})), " \
+               f"name=None)"
+
     def __call__(self) -> Tuple[List[str], List[str]]:
         """
        Define parameter rewrite function.
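Reviewer note: `create_parameter` returns a line of *generated* code, not a live `Parameter` object. A minimal stand-alone sketch of what it emits, assuming the generated network file imports `numpy as np` and `Parameter`/`Tensor` from MindSpore (the shape and dtype below are made-up examples):

```python
def create_parameter(weight_shape, weight_dtype):
    """Stand-alone copy of NewFragment.create_parameter, for illustration only."""
    return f"Parameter(Tensor(np.random.uniform(0, 1, {weight_shape}).astype(np.{weight_dtype})), " \
           f"name=None)"

print(create_parameter((64, 3, 7, 7), "float32"))
# Parameter(Tensor(np.random.uniform(0, 1, (64, 3, 7, 7)).astype(np.float32)), name=None)
```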
@@ -334,6 +285,10 @@ class NewFragment:
                 ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value)
         return init_stats, call_stats
 
+    def register_parameter(self, var, line):
+        """Append a new parameter line into the template."""
+        self._code_template[var][TemplateKeywords.INIT.value].append(line)
+
     @staticmethod
     def _rewrite(var, data, template: str) -> str:
         """
@@ -353,6 +308,12 @@ class NewFragment:
                 data[ExchangeMessageKeywords.VariableScope.value.INPUTS.value])
         if ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value in data:
             rewrite_data.update(data[ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value])
+        if ExchangeMessageKeywords.VariableScope.value.PARAMETERS_DECLARED.value in data:
+            rewrite_params = {
+                f"{var}/{slot}": data[ExchangeMessageKeywords.VariableScope.value.PARAMETERS_DECLARED.value][slot]
+                for slot in data[ExchangeMessageKeywords.VariableScope.value.PARAMETERS_DECLARED.value]
+            }
+            rewrite_data.update(rewrite_params)
         rewrite_data.update(data[ExchangeMessageKeywords.VariableScope.value.ARGS.value])
         return template.format(**{
             k: str(rewrite_data[k]) for k in rewrite_data
diff --git a/mindinsight/mindconverter/graph_based_converter/common/global_context.py b/mindinsight/mindconverter/graph_based_converter/common/global_context.py
index 212c72d5..e92ad953 100644
--- a/mindinsight/mindconverter/graph_based_converter/common/global_context.py
+++ b/mindinsight/mindconverter/graph_based_converter/common/global_context.py
@@ -81,6 +81,9 @@ class GlobalContext(metaclass=Singleton):
 
         self.outputs_storage = OutputStorage()
 
+        # Record weight names that are used many times.
+        self.repeated_weights = dict()
+
     def get_onnx_node_from_identifier(self, identifier):
         """Return an OnnxUtils defined node by its identifier."""
         onnx_node_name = self.node_struct_to_onnx_node_map.get(identifier)
diff --git a/mindinsight/mindconverter/graph_based_converter/constant.py b/mindinsight/mindconverter/graph_based_converter/constant.py
index 0e89745d..bb7507f4 100644
--- a/mindinsight/mindconverter/graph_based_converter/constant.py
+++ b/mindinsight/mindconverter/graph_based_converter/constant.py
@@ -86,6 +86,7 @@ class ExchangeMessageKeywords(Enum):
         ARGS = "args"
         WEIGHTS = "weights"
         TRAINABLE_PARAMS = "trainable_params"
+        PARAMETERS_DECLARED = "parameters"
 
 
 BINARY_HEADER_PYTORCH_FILE = \
diff --git a/mindinsight/mindconverter/graph_based_converter/generator/module_struct.py b/mindinsight/mindconverter/graph_based_converter/generator/module_struct.py
index 3b366f08..a0729bf9 100644
--- a/mindinsight/mindconverter/graph_based_converter/generator/module_struct.py
+++ b/mindinsight/mindconverter/graph_based_converter/generator/module_struct.py
@@ -21,7 +21,6 @@ from mindinsight.mindconverter.graph_based_converter.generator.node_struct impor
 from mindinsight.mindconverter.graph_based_converter.generator.scope_utils import Scope
 from mindinsight.mindconverter.graph_based_converter.common.utils import get_dict_key_by_value
 from mindinsight.mindconverter.graph_based_converter.generator.args_translator import ArgsTranslation
-from mindinsight.mindconverter.graph_based_converter.common.code_fragment import ModuleFragment
 from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext
 from mindinsight.mindconverter.graph_based_converter.common.name_mgr import LocalVarNameMgr
 
@@ -184,28 +183,6 @@ class ModuleStruct:
             "initialized": self.initialized
         })
 
-    def init_module_fragment(self):
-        """Init the module fragment."""
-        if not self.initialized:
-            return
-        # check if fragment exists in global context
-        op = "Module{}".format(self.pattern_id)
-        if op == "Module-1":  # reset as Main Model's op name
-            op = "Model"
-        frag = GlobalContext().get_module_fragment(op)
-        if frag is not None:  # use exists fragment
-            self._fragment = frag
-        else:
-            frag = ModuleFragment(operation=op,
-                                  actual_args=None,
-                                  input_shape=None,
-                                  output_shape=None,
-                                  settings=None)
-            self._fragment = frag
-        # set fragment pattern
-        self._fragment.pattern = self._node_structs
-        GlobalContext().add_module_fragment(op, frag)
-
     def init_args_translator(self):
         """Initialize the Args Translator for the module."""
         var_name = self.ms_var_name
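Reviewer note: the new `parameters` scope namespaces each declared parameter as `<var>/<slot>`, and `str.format(**mapping)` resolves such non-identifier keys verbatim. A hedged sketch with invented names:

```python
# Invented example data; the real keys come from ExchangeMessageKeywords.VariableScope.
var = "conv1"
parameters_declared = {"weight": "self.conv1_weight"}  # slot -> declared variable name

rewrite_data = {f"{var}/{slot}": val for slot, val in parameters_declared.items()}
rewrite_data["var"] = var

# Field names like "conv1/weight" are not Python identifiers, but str.format
# looks them up literally in the keyword mapping, which is what _rewrite relies on.
template = "opt_{var} = self.conv1(x, {conv1/weight})"
print(template.format(**{k: str(v) for k, v in rewrite_data.items()}))
# opt_conv1 = self.conv1(x, self.conv1_weight)
```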
diff --git a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/built_in_pattern.py b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/built_in_pattern.py
index 1a7f3dfb..5389e083 100644
--- a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/built_in_pattern.py
+++ b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/built_in_pattern.py
@@ -14,7 +14,10 @@
 # ==============================================================================
 """Introduce some standard pattern into MindConverter."""
-__all__ = ["BUILT_IN_PATTERN", "register_pattern", "is_built_in_pattern"]
+__all__ = ["BUILT_IN_PATTERN", "register_pattern", "is_built_in_pattern",
+           "USER_DEFINED_PATTERN", "user_defined_pattern"]
+
+from collections import OrderedDict
 
 from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.known_module_name import register_module_name
@@ -22,6 +25,7 @@ from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common i
 from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern import Pattern
 
 BUILT_IN_PATTERN = dict()
+USER_DEFINED_PATTERN = OrderedDict()
 
 
 def is_built_in_pattern(pattern: Pattern):
@@ -75,6 +79,26 @@ def register_pattern(ptn_name, in_degree, out_degree):
     return _reg
 
 
+def user_defined_pattern(pattern_name: str):
+    """
+    Register a user-defined pattern to MindConverter.
+
+    Args:
+        pattern_name (str): Pattern name.
+    """
+
+    def _f(ptn):
+        pattern = ptn()
+        if not pattern:
+            raise ValueError("`ptn` cannot be None.")
+        if not pattern_name:
+            raise ValueError("`pattern_name` cannot be None.")
+        USER_DEFINED_PATTERN[pattern_name] = pattern
+        return ptn
+
+    return _f
+
+
 @register_pattern("ConvBnClip", 1, 1)
 def _conv_bn_clip():
     """Add conv-bn-clip pattern."""
diff --git a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py
index 6c197d96..7d11fe8d 100644
--- a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py
+++ b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py
@@ -424,7 +424,18 @@ def _post_process_overlap(patterns) -> Dict:
     return patterns
 
 
-class SearchPath:
+class BasePath:
+    """Base class of SearchPath (auto-search) and ReplacePath (greedy-match)."""
+
+    def __init__(self, pattern, sequence: List, prev_path=None):
+        self.pattern = pattern
+        self.recursion_path = prev_path.recursion_path[:] if prev_path is not None else list()
+        if prev_path is not None:
+            self.recursion_path.append(prev_path)
+        self.topo_order_bef_repl = sequence
+
+
+class SearchPath(BasePath):
     """
     Use SearchPath to store the search path.
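Reviewer note: a hedged usage sketch of the new decorator (requires this patch applied; the op sequence is an invented example, assumed to be a list of ONNX op types, which matches how `greedy_match` later joins it with commas):

```python
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.built_in_pattern import \
    user_defined_pattern

@user_defined_pattern("ConvBlock")
def _conv_block():
    """Fold every Conv-BatchNormalization-Relu run into one generated module."""
    return ["Conv", "BatchNormalization", "Relu"]
```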
@@ -439,15 +450,9 @@ class SearchPath:
 
     def __init__(self, pattern, sequence: List[BaseNode], prev_path=None, graph=None,
                  sub_graph_size: int = 2):
-        self.pattern = pattern
+        super(SearchPath, self).__init__(pattern, sequence, prev_path)
         self.graph = copy.copy(prev_path.graph) if prev_path is not None \
             else copy.copy(graph)
-        self.recursion_path = prev_path.recursion_path[:] \
-            if prev_path is not None else list()
-        if prev_path is not None:
-            self.recursion_path.append(prev_path)
-
-        self.topo_order_bef_repl = sequence
         self.topo_order_aft_repl, self.inverted_index = self._create_new_order()
         self.node_collection = dict()
         self.hash_of_aft_repl = gen_hash_key(self.topo_order_aft_repl)
@@ -689,3 +694,28 @@
                    f"H: {self.heuristic_v}, G: {self.actual_v}, E: {self.evaluate_score()}"
         return repr_str
+
+
+class ReplacePath(BasePath):
+    """Data structure of a replacement path built by greedy matching."""
+
+    def __init__(self, pattern, sequence: List, prev_path=None):
+        super(ReplacePath, self).__init__(pattern, sequence, prev_path)
+        self.topo_order_aft_repl = None
+
+    def replace(self, increment_idx):
+        """
+        Greedy matching.
+
+        Args:
+            increment_idx (int): Index used to deduplicate module names.
+        """
+        src = ",".join(self.topo_order_bef_repl)
+        tgt = self.pattern.pattern
+        md_name = f"Module{increment_idx}"
+        src_aft_repl = src.replace(tgt, md_name)
+        if src != src_aft_repl:
+            self.pattern.module_name = md_name
+            self.topo_order_aft_repl = src_aft_repl.split(",")
+            return md_name
+        return None
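Reviewer note: `ReplacePath.replace` works on the comma-joined topological order with plain string replacement. A self-contained toy run of the same logic:

```python
topo_order = ["Conv", "BatchNormalization", "Relu",
              "Conv", "BatchNormalization", "Relu", "Add"]
pattern = "Conv,BatchNormalization,Relu"  # Pattern.pattern is a comma-joined op sequence

src = ",".join(topo_order)
src_aft_repl = src.replace(pattern, "Module0")  # every occurrence collapses to one name
print(src_aft_repl.split(","))
# ['Module0', 'Module0', 'Add']
```

Because this is substring replacement, a pattern such as `"Conv"` would also hit `"ConvTranspose"`; user patterns should therefore spell out complete op sequences.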
diff --git a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py
index d92b3bf8..edc86247 100644
--- a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py
+++ b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py
@@ -16,6 +16,7 @@
 from queue import PriorityQueue, Queue
 from typing import Dict, List
 
+from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.built_in_pattern import USER_DEFINED_PATTERN
 from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern_fuzzy_matching import \
     pattern_fuzzy_matching
 from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import context, DagGraph, gen_hash_key, \
@@ -25,7 +26,7 @@ from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common i
 from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext
 from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import BaseNode
 from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.search_path import SearchPath, Pattern, \
-    generate_pattern, find_built_in_pattern
+    generate_pattern, find_built_in_pattern, ReplacePath
 from mindinsight.mindconverter.common.exceptions import SubGraphSearchingError
 
 
@@ -371,6 +372,7 @@ def _build_connection(loader):
         context.precursor_table[node_name] = list(node.get_precursor_dict().keys())
         context.successor_table[node_name] = list(node.get_successor_dict().keys())
         context.outputs_table[node_name] = node.output_name_list
+    # Record the model input count; it is used to control the search algorithm.
     context.has_multi_inputs = len(loader.input_nodes) > 1
 
     dag = DagGraph(nodes=context.node_collection.copy(),
@@ -426,6 +428,28 @@ def _add_known_module_name(search_path):
     return ctx
 
 
+def greedy_match(topo_order, user_defined_ptn):
+    """
+    Greedily replace the topological order with patterns given by the user.
+
+    Args:
+        topo_order (list[str]): Topological order sequence.
+        user_defined_ptn (dict): User-defined patterns.
+    """
+    increment_idx = 0
+    prev_path = None
+    for md_name, ptn in user_defined_ptn.items():
+        ptn = Pattern(",".join(ptn), len(ptn), -1, -1, ptn)
+        ptn.known_module_name = md_name
+        topo_order_aft_rpl = topo_order[:] if prev_path is None else prev_path.topo_order_aft_repl
+        repl_path = ReplacePath(ptn, topo_order_aft_rpl, prev_path=prev_path)
+        module_name = repl_path.replace(increment_idx)
+        if module_name is not None:
+            increment_idx += 1
+            prev_path = repl_path
+    return prev_path
+
+
 @SubGraphSearchingError.check_except("Sub-Graph pattern searching fail.")
 def generate_scope_name(data_loader):
     """
@@ -439,6 +463,12 @@ def generate_scope_name(data_loader):
     """
     init_dag = _build_connection(data_loader)
     try:
+        if USER_DEFINED_PATTERN:
+            topo_order = [node.op_type for _, node in context.node_collection.items()]
+            repl_path = greedy_match(topo_order, USER_DEFINED_PATTERN)
+            topo_order_with_scope_name_list = _retrieve_scope_name(repl_path) if repl_path else flatten_graph(init_dag)
+            return topo_order_with_scope_name_list
+
         result = _sub_graph_matching(init_dag, beam_width=5, sub_graph_size=6)
         topo_order_with_scope_name_list = _retrieve_scope_name(result) if result else flatten_graph(init_dag)
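Reviewer note: two fixes were applied to the new code above. `user_defined_ptn` is an `OrderedDict`, so iterating it directly yields only keys; unpacking `md_name, ptn` requires `.items()`. Likewise, `ReplacePath.replace` joins the sequence with `",".join(...)`, and the docstring declares `topo_order (list[str])`, so `node.op_type` is collected rather than the node objects themselves. A dependency-free sketch of the greedy loop's overall effect (toy ops standing in for Pattern/ReplacePath):

```python
from collections import OrderedDict

user_defined = OrderedDict([("MyBlock", ["Conv", "Relu"])])  # name -> op sequence
topo_order = ["Conv", "Relu", "Conv", "Relu", "MaxPool"]     # node.op_type per node

increment_idx = 0
seq = topo_order[:]
for md_name, ptn in user_defined.items():
    src = ",".join(seq)
    replaced = src.replace(",".join(ptn), f"Module{increment_idx}")
    if replaced != src:  # the pattern matched at least once
        seq = replaced.split(",")
        increment_idx += 1

print(seq)  # ['Module0', 'Module0', 'MaxPool']
```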
- """ - if self._node_type != NodeType.OPERATION.value: - args = deepcopy(self._args_in_code) - self._args_in_code = dict() - for arg, value in args.items(): - self._args_in_code[self._get_arg_name(arg, variable_name)] = value - return CodeFragment(operation="", actual_args=args, settings=None, - input_shape=self.input_shape, output_shape=self.output_shape) - - if self.transformed: - raise ValueError("Already transformed.") - - params = deepcopy(self._op_params) - params.update({"input_shape": self.input_shape, - "output_shape": self.output_shape}) - - ms_op, ms_params, ms_settings, ms_weights = mapper.convert(op_name=self.op_name, - params=params, - weights=self._weight) - - if ms_op: - code_fragment = CodeFragment(operation=ms_op, - actual_args=ms_params, - settings=ms_settings, - input_shape=self.input_shape, - output_shape=self.output_shape, - trainable_params=ms_weights) - else: - code_fragment = CodeFragment(operation=self._op_name, - actual_args=self._op_params, - settings=None, - input_shape=self.input_shape, - output_shape=self.output_shape, - trainable_params=self._weight) - - for arg, value in code_fragment.actual_args.items(): - self._args_in_code[self._get_arg_name(arg, variable_name)] = value - - self.transformed = True - - return code_fragment diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py index fad0f900..0a9bcc30 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py @@ -261,7 +261,6 @@ class OnnxDataLoader: self.model = onnx_model_sim self.graph = onnx_model_sim.graph self.nodes = onnx_model_sim.graph.node - self.batch_size = list(input_nodes.values())[0][0] self.input_nodes = input_nodes self.output_nodes = output_nodes # args for init @@ -274,6 +273,9 @@ class OnnxDataLoader: self.tensors_dict = {} # {tensor_name: OnnxTensor} self.value_info_dict = {} # Not contains input and output nodes + # Record the weight names used many times. + self.repeated_weight = dict() + self.node_output_shape_dict = OrderedDict() # {node_name: [int]} # Key is edge of ONNX ir graph, value is the corresponding precursor node. @@ -362,6 +364,7 @@ class OnnxDataLoader: def _parse_nodes(self): """Parse each onnx nodes in the model.""" nodes_topo_idx = [] + record_tensors = dict() for idx, node in enumerate(self.nodes): if not node.name: node.name = "_".join(node.output) @@ -377,11 +380,21 @@ class OnnxDataLoader: self._global_context.onnx_node_inputs[n.name].append(ipt_nd) else: self._global_context.onnx_node_inputs[n.name] = [ipt_nd] + if ipt_nd in self.tensors_dict: + if ipt_nd not in record_tensors: + record_tensors[ipt_nd] = [node.name] + continue + record_tensors[ipt_nd].append(node.name) + self.repeated_weight.setdefault(ipt_nd, []) self._global_context.onnx_node_name_to_topo_idx[n.name] = idx + for k in self.repeated_weight: + self.repeated_weight[k] = record_tensors[k][:] + self._global_context.onnx_nodes_collection = self._nodes_dict self._global_context.onnx_nodes_topo_index = nodes_topo_idx + self._global_context.repeated_weights = self.repeated_weight def _parse_tensors(self): """Parse each onnx tensors in the model.""" @@ -405,8 +418,7 @@ class OnnxDataLoader: for i, s in enumerate(shape): if 'unk' in s: # Have to adapt user-define axis name, e.g. 'sequence', 'batch'. 
@@ -405,8 +418,7 @@
             for i, s in enumerate(shape):
                 if 'unk' in s:
                     # Have to adapt user-define axis name, e.g. 'sequence', 'batch'.
-                    shape[i] = self.batch_size if self.batch_size is not None else 1
-                    continue
+                    raise ValueError(f"cannot get shape of {node_opt_name}.")
                 if s == "scalar":
                     shape = SCALAR_WITHOUT_SHAPE
                     continue
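Reviewer note: this last hunk is a behavior change. The removed `self.batch_size` (whose initialization was also deleted above) silently substituted a guessed value for dynamic `'unk*'` axes; the new code fails fast, so users must export or simplify the ONNX model with static shapes rather than convert with wrong dimensions. A minimal sketch of the new control flow (names assumed from the hunk's context):

```python
shape = ["unk__123", "256"]      # a dynamic dim as serialized in ONNX value_info
node_opt_name = "conv_0_output"  # hypothetical tensor name

try:
    for i, s in enumerate(shape):
        if 'unk' in s:
            raise ValueError(f"cannot get shape of {node_opt_name}.")
except ValueError as err:
    print(err)  # cannot get shape of conv_0_output.
```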