From: @liuchongming74
Reviewed-by:
Signed-off-by:
tags/v1.2.0-rc1
@@ -163,61 +163,6 @@ class Fragment(abc.ABC):
        return self._output_shape


class CodeFragment(Fragment):
    """
    Manage the variables related to code generation.

    For a single-operation-type node, the variables in `CodeLine` stand for:

    ```python
    class Module(nn.Cell):
        def __init__(self, ...):
            super(Module, self).__init__()
            self.<CodeLine.declared_variable_name> = <CodeLine.operation>(<CodeLine.scalar_args>,
                                                                          <CodeLine.init_trainable_params>)
            self.<CodeLine.trainable_params[k].param_name> = Tensor(<CodeLine.trainable_params[k].shape>,
                                                                    dtype=<CodeLine._trainable_params[k].dtype>)

        def construct(self, x, ...):
            <CodeLine.output_var_name> = self.<CodeLine.declared_variable_name>(<CodeLine.operation_inputs>)
            ...
            return output
    ```

    Args:
        operation (str): Operation name in MindSpore.
        actual_args (dict): Actual argument values.
        settings (namedtuple): Code generation settings.
    """
    def __init__(self, operation, actual_args, settings, input_shape, output_shape,
                 trainable_params=None, trainable_weights=None):
        super(CodeFragment, self).__init__(operation=operation, actual_args=actual_args,
                                           input_shape=input_shape, output_shape=output_shape,
                                           settings=settings)
        self._trainable_params = dict()  # External weights, like Matmul.
        self._init_trainable_params = trainable_params  # Can be passed into the operation's init method, like Conv2d.
        self._trainable_weights = trainable_weights

    @property
    def trainable_params(self):
        """Return the trainable parameters."""
        return self._trainable_params

    @property
    def trainable_weights(self):
        """Return the trainable weights."""
        return self._trainable_weights


class ModuleFragment(Fragment):
    """Manage module-type code variables."""

    def __init__(self, operation, actual_args, settings, input_shape, output_shape):
        super(ModuleFragment, self).__init__(operation=operation, actual_args=actual_args,
                                             input_shape=input_shape, output_shape=output_shape,
                                             settings=settings)

class NewFragment:
    """
    Fragment definition for MindSpore code generation.

@@ -310,6 +255,12 @@ class NewFragment:
            return f"{opt}[{inner_idx}]"
        return opt

    @staticmethod
    def create_parameter(weight_shape, weight_dtype):
        """Create a parameter code line."""
        return f"Parameter(Tensor(np.random.uniform(0, 1, {weight_shape}).astype(np.{weight_dtype})), " \
               f"name=None)"
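Note that the helper above only builds a source string; it does not allocate a tensor. A self-contained sketch of the string it would emit (the shape and dtype below are arbitrary examples, not values from this PR):

```python
def create_parameter(weight_shape, weight_dtype):
    """Standalone stand-in mirroring NewFragment.create_parameter."""
    return (f"Parameter(Tensor(np.random.uniform(0, 1, {weight_shape}).astype(np.{weight_dtype})), "
            f"name=None)")

print(create_parameter((1, 3, 224, 224), "float32"))
# -> Parameter(Tensor(np.random.uniform(0, 1, (1, 3, 224, 224)).astype(np.float32)), name=None)
```

The generated line assumes the emitted network imports `numpy as np` and MindSpore's `Parameter`/`Tensor`.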

    def __call__(self) -> Tuple[List[str], List[str]]:
        """
        Define parameter rewrite function.

@@ -334,6 +285,10 @@ class NewFragment:
                                ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value)
        return init_stats, call_stats

    def register_parameter(self, var, line):
        """Append a new parameter line into the template."""
        self._code_template[var][TemplateKeywords.INIT.value].append(line)
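How the two new hooks are meant to compose is easiest to see with a stand-in for the code template; the real `_code_template` is keyed by variable name and `TemplateKeywords` values, so the layout below is assumed for illustration only, and `matmul_0_w` is an invented parameter name (`create_parameter` here is the standalone stand-in from the previous sketch):

```python
# Stand-in for the fragment's code template: per-variable init/construct statement lists.
code_template = {"matmul_0": {"init": [], "construct": []}}

def register_parameter(var, line):
    """Mirror NewFragment.register_parameter for this sketch."""
    code_template[var]["init"].append(line)

register_parameter("matmul_0", "self.matmul_0_w = " + create_parameter((256, 10), "float32"))
print(code_template["matmul_0"]["init"])
```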

    @staticmethod
    def _rewrite(var, data, template: str) -> str:
        """

@@ -353,6 +308,12 @@ class NewFragment:
                                data[ExchangeMessageKeywords.VariableScope.value.INPUTS.value])
        if ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value in data:
            rewrite_data.update(data[ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value])
        if ExchangeMessageKeywords.VariableScope.value.PARAMETERS_DECLARED.value in data:
            rewrite_params = {
                f"{var}/{slot}": data[ExchangeMessageKeywords.VariableScope.value.PARAMETERS_DECLARED.value][slot]
                for slot in data[ExchangeMessageKeywords.VariableScope.value.PARAMETERS_DECLARED.value]
            }
            rewrite_data.update(rewrite_params)
        rewrite_data.update(data[ExchangeMessageKeywords.VariableScope.value.ARGS.value])
        return template.format(**{
            k: str(rewrite_data[k]) for k in rewrite_data
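The new `PARAMETERS_DECLARED` branch namespaces each declared slot as `<var>/<slot>` so it can be addressed from the code template. A self-contained sketch with invented names (it uses `str.format_map` to stay standalone, whereas the fragment itself feeds the full `rewrite_data` into `str.format`):

```python
var = "conv1"
declared_parameters = {"weight": "self.conv1_weight", "bias": "self.conv1_bias"}

# Each declared slot becomes a "<var>/<slot>" key, matching the template placeholders.
rewrite_data = {f"{var}/{slot}": value for slot, value in declared_parameters.items()}

template = "opt_conv1 = self.conv1(x, {conv1/weight}, {conv1/bias})"
print(template.format_map(rewrite_data))
# -> opt_conv1 = self.conv1(x, self.conv1_weight, self.conv1_bias)
```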
@@ -81,6 +81,9 @@ class GlobalContext(metaclass=Singleton):
        self.outputs_storage = OutputStorage()
        # Record weight names that are used multiple times.
        self.repeated_weights = dict()

    def get_onnx_node_from_identifier(self, identifier):
        """Return an OnnxUtils-defined node by its identifier."""
        onnx_node_name = self.node_struct_to_onnx_node_map.get(identifier)
@@ -107,6 +107,7 @@ class ExchangeMessageKeywords(Enum):
        ARGS = "args"
        WEIGHTS = "weights"
        TRAINABLE_PARAMS = "trainable_params"
        PARAMETERS_DECLARED = "parameters"


BINARY_HEADER_PYTORCH_FILE = \
@@ -21,7 +21,6 @@ from mindinsight.mindconverter.graph_based_converter.generator.node_struct import
from mindinsight.mindconverter.graph_based_converter.generator.scope_utils import Scope
from mindinsight.mindconverter.graph_based_converter.common.utils import get_dict_key_by_value
from mindinsight.mindconverter.graph_based_converter.generator.args_translator import ArgsTranslation
from mindinsight.mindconverter.graph_based_converter.common.code_fragment import ModuleFragment
from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext
from mindinsight.mindconverter.graph_based_converter.common.name_mgr import LocalVarNameMgr
@@ -184,28 +183,6 @@ class ModuleStruct:
            "initialized": self.initialized
        })

    def init_module_fragment(self):
        """Init the module fragment."""
        if not self.initialized:
            return
        # Check whether the fragment already exists in the global context.
        op = "Module{}".format(self.pattern_id)
        if op == "Module-1":  # reset as the main model's op name
            op = "Model"
        frag = GlobalContext().get_module_fragment(op)
        if frag is not None:  # use the existing fragment
            self._fragment = frag
        else:
            frag = ModuleFragment(operation=op,
                                  actual_args=None,
                                  input_shape=None,
                                  output_shape=None,
                                  settings=None)
            self._fragment = frag
        # Set the fragment pattern.
        self._fragment.pattern = self._node_structs
        GlobalContext().add_module_fragment(op, frag)

    def init_args_translator(self):
        """Initialize the Args Translator for the module."""
        var_name = self.ms_var_name
@@ -14,7 +14,10 @@
# ==============================================================================
"""Introduce some standard patterns into MindConverter."""
__all__ = ["BUILT_IN_PATTERN", "register_pattern", "is_built_in_pattern"]
__all__ = ["BUILT_IN_PATTERN", "register_pattern", "is_built_in_pattern",
           "USER_DEFINED_PATTERN", "user_defined_pattern"]

from collections import OrderedDict

from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.known_module_name import register_module_name

@@ -22,6 +25,7 @@ from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common i
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern import Pattern

BUILT_IN_PATTERN = dict()
USER_DEFINED_PATTERN = OrderedDict()


def is_built_in_pattern(pattern: Pattern):

@@ -75,6 +79,26 @@ def register_pattern(ptn_name, in_degree, out_degree):
    return _reg

def user_defined_pattern(pattern_name: str):
    """
    Register a user-defined pattern to MindConverter.

    Args:
        pattern_name (str): Pattern name.
    """

    def _f(ptn):
        pattern = ptn()
        if not pattern:
            raise ValueError("`ptn` cannot be None.")
        if not pattern_name:
            raise ValueError("`pattern_name` cannot be None.")
        USER_DEFINED_PATTERN[pattern_name] = pattern
        return ptn

    return _f
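For reference, a hedged usage sketch of the new decorator. The pattern name and op sequence below are invented; this hunk does not spell out what the decorated callable must return, but `greedy_match` (further down) joins the value with commas, so an ordered list of ONNX op types is assumed:

```python
@user_defined_pattern("ConvBnRelu")  # hypothetical pattern name
def _conv_bn_relu():
    # Invoked once at registration; the return value is what gets stored in
    # USER_DEFINED_PATTERN["ConvBnRelu"].
    return ["Conv", "BatchNormalization", "Relu"]
```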

@register_pattern("ConvBnClip", 1, 1)
def _conv_bn_clip():
    """Add conv-bn-clip pattern."""

@@ -424,7 +424,18 @@ def _post_process_overlap(patterns) -> Dict:
    return patterns
class SearchPath:
class BasePath:
    """Base class of SearchPath (auto-search) and ReplacePath (greedy-match)."""

    def __init__(self, pattern, sequence: List, prev_path=None):
        self.pattern = pattern
        self.recursion_path = prev_path.recursion_path[:] if prev_path is not None else list()
        if prev_path is not None:
            self.recursion_path.append(prev_path)
        self.topo_order_bef_repl = sequence


class SearchPath(BasePath):
    """
    Use SearchPath to store the search path.

@@ -439,15 +450,9 @@ class SearchPath:
    def __init__(self, pattern, sequence: List[BaseNode], prev_path=None,
                 graph=None, sub_graph_size: int = 2):
        self.pattern = pattern
        super(SearchPath, self).__init__(pattern, sequence, prev_path)
        self.graph = copy.copy(prev_path.graph) if prev_path is not None \
            else copy.copy(graph)
        self.recursion_path = prev_path.recursion_path[:] \
            if prev_path is not None else list()
        if prev_path is not None:
            self.recursion_path.append(prev_path)
        self.topo_order_bef_repl = sequence
        self.topo_order_aft_repl, self.inverted_index = self._create_new_order()
        self.node_collection = dict()
        self.hash_of_aft_repl = gen_hash_key(self.topo_order_aft_repl)

@@ -689,3 +694,28 @@ class SearchPath:
            f"H: {self.heuristic_v}, G: {self.actual_v}, E: {self.evaluate_score()}"
        return repr_str

class ReplacePath(BasePath):
    """Data structure of the replace path used in greedy matching."""

    def __init__(self, pattern, sequence: List, prev_path=None):
        super(ReplacePath, self).__init__(pattern, sequence, prev_path)
        self.topo_order_aft_repl = None

    def replace(self, increment_idx):
        """
        Do greedy matching.

        Args:
            increment_idx (int): Index used to deduplicate module names.
        """
        src = ",".join(self.topo_order_bef_repl)
        tgt = self.pattern.pattern
        md_name = f"Module{increment_idx}"
        src_aft_repl = src.replace(tgt, md_name)
        if src != src_aft_repl:
            self.pattern.module_name = md_name
            self.topo_order_aft_repl = src_aft_repl.split(",")
            return md_name
        return None
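The replacement works purely at string level on the comma-joined topological order. A self-contained sketch with invented op names:

```python
# Topological order as op-type names; the registered pattern is the comma-joined target.
topo_order = ["Conv", "BatchNormalization", "Relu", "MaxPool",
              "Conv", "BatchNormalization", "Relu"]
src = ",".join(topo_order)
tgt = "Conv,BatchNormalization,Relu"          # pattern.pattern
src_aft_repl = src.replace(tgt, "Module0")    # md_name for increment_idx == 0
print(src_aft_repl.split(","))
# -> ['Module0', 'MaxPool', 'Module0']
```

Both occurrences collapse in one pass because `str.replace` substitutes every match; the match is exact on op names, so a sequence interrupted by any other node is left untouched.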
@@ -16,6 +16,7 @@
from queue import PriorityQueue, Queue
from typing import Dict, List

from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.built_in_pattern import USER_DEFINED_PATTERN
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern_fuzzy_matching import \
    pattern_fuzzy_matching
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import context, DagGraph, gen_hash_key, \

@@ -25,7 +26,7 @@ from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common i
from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext
from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import BaseNode
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.search_path import SearchPath, Pattern, \
    generate_pattern, find_built_in_pattern
    generate_pattern, find_built_in_pattern, ReplacePath
from mindinsight.mindconverter.common.exceptions import SubGraphSearchingError
@@ -371,6 +372,7 @@ def _build_connection(loader):
        context.precursor_table[node_name] = list(node.get_precursor_dict().keys())
        context.successor_table[node_name] = list(node.get_successor_dict().keys())
        context.outputs_table[node_name] = node.output_name_list

    # Record the model input count; it is used to control the search algorithm.
    context.has_multi_inputs = len(loader.input_nodes) > 1
    dag = DagGraph(nodes=context.node_collection.copy(),

@@ -426,6 +428,28 @@ def _add_known_module_name(search_path):
    return ctx

def greedy_match(topo_order, user_defined_ptn):
    """
    Greedily replace the topological order with patterns given by the user.

    Args:
        topo_order (list[str]): Topological order sequence.
        user_defined_ptn (dict): User-defined patterns.
    """
    increment_idx = 0
    prev_path = None
    for md_name, ptn in user_defined_ptn.items():
        ptn = Pattern(",".join(ptn), len(ptn), -1, -1, ptn)
        ptn.known_module_name = md_name
        topo_order_aft_rpl = topo_order[:] if prev_path is None else prev_path.topo_order_aft_repl
        repl_path = ReplacePath(ptn, topo_order_aft_rpl, prev_path=prev_path)
        module_name = repl_path.replace(increment_idx)
        if module_name is not None:
            increment_idx += 1
            prev_path = repl_path
    return prev_path
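As a rough end-to-end sketch (op names and pattern names are invented, and it assumes `greedy_match` is importable from the searcher module shown here), the patterns registered via `user_defined_pattern` drive the replacement like this:

```python
from collections import OrderedDict

# Assumed shape of USER_DEFINED_PATTERN: pattern name -> ordered list of ONNX op types.
user_defined_ptn = OrderedDict([
    ("ConvBnRelu", ["Conv", "BatchNormalization", "Relu"]),
    ("GemmRelu", ["Gemm", "Relu"]),
])
topo_order = ["Conv", "BatchNormalization", "Relu", "Gemm", "Relu", "Softmax"]

repl_path = greedy_match(topo_order, user_defined_ptn)
# Expected, following ReplacePath.replace above: the first pattern folds into "Module0",
# the second into "Module1", so repl_path.topo_order_aft_repl == ["Module0", "Module1", "Softmax"].
```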

@SubGraphSearchingError.check_except("Sub-Graph pattern searching fail.")
def generate_scope_name(data_loader):
    """

@@ -439,6 +463,12 @@ def generate_scope_name(data_loader):
    """
    init_dag = _build_connection(data_loader)
    try:
        if USER_DEFINED_PATTERN:
            topo_order = [node for _, node in context.node_collection.items()]
            repl_path = greedy_match(topo_order, USER_DEFINED_PATTERN)
            topo_order_with_scope_name_list = _retrieve_scope_name(repl_path) if repl_path else flatten_graph(init_dag)
            return topo_order_with_scope_name_list

        result = _sub_graph_matching(init_dag, beam_width=5, sub_graph_size=6)
        topo_order_with_scope_name_list = _retrieve_scope_name(result) if result else flatten_graph(init_dag)
@@ -15,14 +15,11 @@
"""Define graph entity."""
import abc
from collections import OrderedDict
from copy import deepcopy
from typing import List

from mindinsight.mindconverter.common.log import logger as log
from mindinsight.mindconverter.graph_based_converter.common.code_fragment import CodeFragment
from mindinsight.mindconverter.graph_based_converter.constant import NodeType, InputType
from mindinsight.mindconverter.graph_based_converter.mapper.base import Mapper
from mindinsight.mindconverter.graph_based_converter.constant import InputType
from mindinsight.mindconverter.common.exceptions import NodeInputTypeNotSupportError

@@ -574,56 +571,3 @@ class GraphNode(abc.ABC):
        ipt_args_settings_in_construct = ', '.join((ipt_args_settings_in_construct, settings_in_construct))
        return ipt_args_settings_in_construct
    def param_transform(self, mapper: Mapper, variable_name):
        """
        Transform params in a PyTorch operation into MindSpore.

        Args:
            variable_name (str): Variable name.
            mapper (ONNXToMindSporeMapper): Mapper between ONNX operations
                and MindSpore.

        Returns:
            dict, transformed params.
        """
        if self._node_type != NodeType.OPERATION.value:
            args = deepcopy(self._args_in_code)
            self._args_in_code = dict()
            for arg, value in args.items():
                self._args_in_code[self._get_arg_name(arg, variable_name)] = value
            return CodeFragment(operation="", actual_args=args, settings=None,
                                input_shape=self.input_shape, output_shape=self.output_shape)

        if self.transformed:
            raise ValueError("Already transformed.")

        params = deepcopy(self._op_params)
        params.update({"input_shape": self.input_shape,
                       "output_shape": self.output_shape})
        ms_op, ms_params, ms_settings, ms_weights = mapper.convert(op_name=self.op_name,
                                                                   params=params,
                                                                   weights=self._weight)
        if ms_op:
            code_fragment = CodeFragment(operation=ms_op,
                                         actual_args=ms_params,
                                         settings=ms_settings,
                                         input_shape=self.input_shape,
                                         output_shape=self.output_shape,
                                         trainable_params=ms_weights)
        else:
            code_fragment = CodeFragment(operation=self._op_name,
                                         actual_args=self._op_params,
                                         settings=None,
                                         input_shape=self.input_shape,
                                         output_shape=self.output_shape,
                                         trainable_params=self._weight)

        for arg, value in code_fragment.actual_args.items():
            self._args_in_code[self._get_arg_name(arg, variable_name)] = value

        self.transformed = True
        return code_fragment
@@ -262,7 +262,6 @@ class OnnxDataLoader:
        self.model = onnx_model_sim
        self.graph = onnx_model_sim.graph
        self.nodes = onnx_model_sim.graph.node
        self.batch_size = list(input_nodes.values())[0][0]
        self.input_nodes = input_nodes
        self.output_nodes = output_nodes
        # args for init

@@ -275,6 +274,9 @@ class OnnxDataLoader:
        self.tensors_dict = {}  # {tensor_name: OnnxTensor}
        self.value_info_dict = {}  # Does not contain input and output nodes.

        # Record the weight names used multiple times.
        self.repeated_weight = dict()

        self.node_output_shape_dict = OrderedDict()  # {node_name: [int]}

        # Key is an edge of the ONNX IR graph; value is the corresponding precursor node.
@@ -393,6 +395,7 @@ class OnnxDataLoader:
    def _parse_nodes(self):
        """Parse each ONNX node in the model."""
        nodes_topo_idx = []
        record_tensors = dict()
        for idx, node in enumerate(self.nodes):
            if not node.name:
                node.name = "_".join(node.output)

@@ -408,11 +411,21 @@ class OnnxDataLoader:
                        self._global_context.onnx_node_inputs[n.name].append(ipt_nd)
                    else:
                        self._global_context.onnx_node_inputs[n.name] = [ipt_nd]
                    if ipt_nd in self.tensors_dict:
                        if ipt_nd not in record_tensors:
                            record_tensors[ipt_nd] = [node.name]
                            continue
                        record_tensors[ipt_nd].append(node.name)
                        self.repeated_weight.setdefault(ipt_nd, [])

            self._global_context.onnx_node_name_to_topo_idx[n.name] = idx

        for k in self.repeated_weight:
            self.repeated_weight[k] = record_tensors[k][:]

        self._global_context.onnx_nodes_collection = self._nodes_dict
        self._global_context.onnx_nodes_topo_index = nodes_topo_idx
        self._global_context.repeated_weights = self.repeated_weight
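Reduced to its core, the repeated-weight bookkeeping added here records, for every initializer tensor that feeds more than one node, the list of consuming node names. A standalone sketch with toy data:

```python
# Toy inputs: known weight tensors and the inputs of each node.
tensors_dict = {"w_shared": None, "w_conv1": None}
node_inputs = {
    "conv1": ["x", "w_conv1"],
    "fc1": ["h1", "w_shared"],
    "fc2": ["h2", "w_shared"],
}

record_tensors, repeated_weight = {}, {}
for node_name, inputs in node_inputs.items():
    for ipt_nd in inputs:
        if ipt_nd not in tensors_dict:
            continue
        if ipt_nd not in record_tensors:
            record_tensors[ipt_nd] = [node_name]
            continue
        record_tensors[ipt_nd].append(node_name)
        repeated_weight.setdefault(ipt_nd, [])

for k in repeated_weight:
    repeated_weight[k] = record_tensors[k][:]

print(repeated_weight)  # -> {'w_shared': ['fc1', 'fc2']}
```

The result is what ends up in `GlobalContext().repeated_weights`, which the code generator can presumably consult to declare a shared weight once and reuse it.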
    def _parse_tensors(self):
        """Parse each ONNX tensor in the model."""

@@ -436,8 +449,7 @@ class OnnxDataLoader:
            for i, s in enumerate(shape):
                if 'unk' in s:
                    # Have to adapt user-defined axis names, e.g. 'sequence', 'batch'.
                    shape[i] = self.batch_size if self.batch_size is not None else 1
                    continue
                    raise ValueError(f"cannot get shape of {node_opt_name}.")
                if s == "scalar":
                    shape = SCALAR_WITHOUT_SHAPE
                    continue