From: @liuchongming74
tags/v1.2.0-rc1
@@ -163,61 +163,6 @@ class Fragment(abc.ABC):
         return self._output_shape


-class CodeFragment(Fragment):
-    """
-    Manage the variables related with code generation.
-
-    For single operation type node, the variables in `CodeLine` stands for:
-
-    ```python
-    class Module(nn.Cell):
-        def __init__ (self, ...):
-            super(Module, self).__init__()
-            self.<CodeLine.declared_variable_name> = <CodeLine.operation>(<CodeLine.scalar_args>,
-                                                                          <CodeLine.init_trainable_params>)
-            self.<CodeLine.trainable_params[k].param_name> = Tensor(<CodeLine.trainable_params[k].shape>,
-                                                                    dtype=<CodeLine._trainable_params[k].dtype>)
-
-        def construct(self, x, ...):
-            <CodeLine.output_var_name> = self.<CodeLine.declared_variable_name>(<CodeLine.operation_inputs>)
-            ...
-            return output
-    ```
-
-    Args:
-        operation (str): Operation name in MindSpore.
-        actual_args (dict): Actual arg values.
-        settings (namedTuple): Code generation setting.
-    """
-
-    def __init__(self, operation, actual_args, settings, input_shape, output_shape,
-                 trainable_params=None, trainable_weights=None):
-        super(CodeFragment, self).__init__(operation=operation, actual_args=actual_args,
-                                           input_shape=input_shape, output_shape=output_shape,
-                                           settings=settings)
-        self._trainable_params = dict()  # External weights, like Matmul.
-        self._init_trainable_params = trainable_params  # Can put into operation init method, like Conv2d.
-        self._trainable_weights = trainable_weights
-
-    @property
-    def trainable_params(self):
-        """Return the trainable parameters."""
-        return self._trainable_params
-
-    @property
-    def trainable_weights(self):
-        return self._trainable_weights
-
-
-class ModuleFragment(Fragment):
-    """Manage module type code variables."""
-
-    def __init__(self, operation, actual_args, settings, input_shape, output_shape):
-        super(ModuleFragment, self).__init__(operation=operation, actual_args=actual_args,
-                                             input_shape=input_shape, output_shape=output_shape,
-                                             settings=settings)
-
-
 class NewFragment:
     """
     Fragment definition for MindSpore code generation.

@@ -310,6 +255,12 @@ class NewFragment:
             return f"{opt}[{inner_idx}]"
         return opt

+    @staticmethod
+    def create_parameter(weight_shape, weight_dtype):
+        """Create a parameter code line."""
+        return f"Parameter(Tensor(np.random.uniform(0, 1, {weight_shape}).astype(np.{weight_dtype})), " \
+               f"name=None)"
+
     def __call__(self) -> Tuple[List[str], List[str]]:
         """
         Define parameter rewrite function.
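For reviewers, a quick sketch of the string the new helper emits; the shape and dtype values below are invented for illustration and are not taken from this patch:

```python
# Hypothetical call to the new NewFragment.create_parameter static method.
line = NewFragment.create_parameter((64, 3, 7, 7), "float32")
assert line == ("Parameter(Tensor(np.random.uniform(0, 1, (64, 3, 7, 7))"
                ".astype(np.float32)), name=None)")
```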
@@ -334,6 +285,10 @@ class NewFragment:
                                              ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value)
         return init_stats, call_stats

+    def register_parameter(self, var, line):
+        """Append a new parameter into template."""
+        self._code_template[var][TemplateKeywords.INIT.value].append(line)
+
     @staticmethod
     def _rewrite(var, data, template: str) -> str:
         """
@@ -353,6 +308,12 @@ class NewFragment:
                                 data[ExchangeMessageKeywords.VariableScope.value.INPUTS.value])
         if ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value in data:
             rewrite_data.update(data[ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value])
+        if ExchangeMessageKeywords.VariableScope.value.PARAMETERS_DECLARED.value in data:
+            rewrite_params = {
+                f"{var}/{slot}": data[ExchangeMessageKeywords.VariableScope.value.PARAMETERS_DECLARED.value][slot]
+                for slot in data[ExchangeMessageKeywords.VariableScope.value.PARAMETERS_DECLARED.value]
+            }
+            rewrite_data.update(rewrite_params)
         rewrite_data.update(data[ExchangeMessageKeywords.VariableScope.value.ARGS.value])
         return template.format(**{
             k: str(rewrite_data[k]) for k in rewrite_data
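A toy view of the key namespacing this branch introduces, so it is clear what placeholder names the templates are expected to use; the variable and slot names are invented:

```python
# If the exchange message for variable "conv1" carries
# {"parameters": {"weight": "self.conv1_weight"}}, the new branch exposes the
# value under the namespaced key "conv1/weight" for template formatting.
declared = {"weight": "self.conv1_weight"}
var = "conv1"
rewrite_params = {f"{var}/{slot}": declared[slot] for slot in declared}
assert rewrite_params == {"conv1/weight": "self.conv1_weight"}
```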
@@ -81,6 +81,9 @@ class GlobalContext(metaclass=Singleton):
         self.outputs_storage = OutputStorage()

+        # Record weights name that used many times.
+        self.repeated_weights = dict()
+
     def get_onnx_node_from_identifier(self, identifier):
         """Return an OnnxUtils defined node by its identifier."""
         onnx_node_name = self.node_struct_to_onnx_node_map.get(identifier)

@@ -107,6 +107,7 @@ class ExchangeMessageKeywords(Enum):
         ARGS = "args"
         WEIGHTS = "weights"
         TRAINABLE_PARAMS = "trainable_params"
+        PARAMETERS_DECLARED = "parameters"

 BINARY_HEADER_PYTORCH_FILE = \

@@ -21,7 +21,6 @@ from mindinsight.mindconverter.graph_based_converter.generator.node_struct impor
 from mindinsight.mindconverter.graph_based_converter.generator.scope_utils import Scope
 from mindinsight.mindconverter.graph_based_converter.common.utils import get_dict_key_by_value
 from mindinsight.mindconverter.graph_based_converter.generator.args_translator import ArgsTranslation
-from mindinsight.mindconverter.graph_based_converter.common.code_fragment import ModuleFragment
 from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext
 from mindinsight.mindconverter.graph_based_converter.common.name_mgr import LocalVarNameMgr
@@ -184,28 +183,6 @@ class ModuleStruct:
             "initialized": self.initialized
         })

-    def init_module_fragment(self):
-        """Init the module fragment."""
-        if not self.initialized:
-            return
-        # check if fragment exists in global context
-        op = "Module{}".format(self.pattern_id)
-        if op == "Module-1":  # reset as Main Model's op name
-            op = "Model"
-        frag = GlobalContext().get_module_fragment(op)
-        if frag is not None:  # use exists fragment
-            self._fragment = frag
-        else:
-            frag = ModuleFragment(operation=op,
-                                  actual_args=None,
-                                  input_shape=None,
-                                  output_shape=None,
-                                  settings=None)
-            self._fragment = frag
-        # set fragment pattern
-        self._fragment.pattern = self._node_structs
-        GlobalContext().add_module_fragment(op, frag)
-
     def init_args_translator(self):
         """Initialize the Args Translator for the module."""
         var_name = self.ms_var_name
@@ -14,7 +14,10 @@
 # ==============================================================================
 """Introduce some standard pattern into MindConverter."""
-__all__ = ["BUILT_IN_PATTERN", "register_pattern", "is_built_in_pattern"]
+__all__ = ["BUILT_IN_PATTERN", "register_pattern", "is_built_in_pattern",
+           "USER_DEFINED_PATTERN", "user_defined_pattern"]
+
+from collections import OrderedDict

 from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.known_module_name import register_module_name

@@ -22,6 +25,7 @@ from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common i
 from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern import Pattern

 BUILT_IN_PATTERN = dict()
+USER_DEFINED_PATTERN = OrderedDict()


 def is_built_in_pattern(pattern: Pattern):

@@ -75,6 +79,26 @@ def register_pattern(ptn_name, in_degree, out_degree):
     return _reg


+def user_defined_pattern(pattern_name: str):
+    """
+    Register user define pattern to MindConverter.
+
+    Args:
+        pattern_name (str): Pattern name.
+    """
+
+    def _f(ptn):
+        pattern = ptn()
+        if not pattern:
+            raise ValueError("`ptn` cannot be None.")
+        if not pattern_name:
+            raise ValueError("`pattern_name` cannot be None.")
+        USER_DEFINED_PATTERN[pattern_name] = pattern
+        return ptn
+
+    return _f
+
+
 @register_pattern("ConvBnClip", 1, 1)
 def _conv_bn_clip():
     """Add conv-bn-clip pattern."""
@@ -424,7 +424,18 @@ def _post_process_overlap(patterns) -> Dict:
     return patterns


-class SearchPath:
+class BasePath:
+    """Base class of SearchPath (auto-search) and ReplacePath (greedy-match)."""
+
+    def __init__(self, pattern, sequence: List, prev_path=None):
+        self.pattern = pattern
+        self.recursion_path = prev_path.recursion_path[:] if prev_path is not None else list()
+        if prev_path is not None:
+            self.recursion_path.append(prev_path)
+        self.topo_order_bef_repl = sequence
+
+
+class SearchPath(BasePath):
     """
     Use SearchPath to store the search path.

@@ -439,15 +450,9 @@
     def __init__(self, pattern, sequence: List[BaseNode], prev_path=None,
                  graph=None, sub_graph_size: int = 2):
-        self.pattern = pattern
+        super(SearchPath, self).__init__(pattern, sequence, prev_path)
         self.graph = copy.copy(prev_path.graph) if prev_path is not None \
             else copy.copy(graph)
-        self.recursion_path = prev_path.recursion_path[:] \
-            if prev_path is not None else list()
-        if prev_path is not None:
-            self.recursion_path.append(prev_path)
-        self.topo_order_bef_repl = sequence
         self.topo_order_aft_repl, self.inverted_index = self._create_new_order()
         self.node_collection = dict()
         self.hash_of_aft_repl = gen_hash_key(self.topo_order_aft_repl)

@@ -689,3 +694,28 @@
             f"H: {self.heuristic_v}, G: {self.actual_v}, E: {self.evaluate_score()}"
         return repr_str
+
+
+class ReplacePath(BasePath):
+    """Data struct of replacing path with greedy matching."""
+
+    def __init__(self, pattern, sequence: List, prev_path=None):
+        super(ReplacePath, self).__init__(pattern, sequence, prev_path)
+        self.topo_order_aft_repl = None
+
+    def replace(self, increment_idx):
+        """
+        Greedy matching.
+
+        Args:
+            increment_idx (int): To deduplicate module name.
+        """
+        src = ",".join(self.topo_order_bef_repl)
+        tgt = self.pattern.pattern
+        md_name = f"Module{increment_idx}"
+        src_aft_repl = src.replace(tgt, md_name)
+        if src != src_aft_repl:
+            self.pattern.module_name = md_name
+            self.topo_order_aft_repl = src_aft_repl.split(",")
+            return md_name
+        return None
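A small walk-through of `replace`: because both the topological order and the pattern are comma-joined strings, the match is plain substring replacement. The node names below are invented, and the `Pattern` construction mirrors the one used by `greedy_match` later in this change:

```python
# Hypothetical trace of ReplacePath.replace with a two-op pattern.
ptn = Pattern(",".join(["Conv", "Relu"]), 2, -1, -1, ["Conv", "Relu"])
path = ReplacePath(ptn, ["Conv", "Relu", "Add"], prev_path=None)
assert path.replace(increment_idx=0) == "Module0"
assert path.topo_order_aft_repl == ["Module0", "Add"]
```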
@@ -16,6 +16,7 @@
 from queue import PriorityQueue, Queue
 from typing import Dict, List

+from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.built_in_pattern import USER_DEFINED_PATTERN
 from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern_fuzzy_matching import \
     pattern_fuzzy_matching
 from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import context, DagGraph, gen_hash_key, \

@@ -25,7 +26,7 @@ from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common i
 from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext
 from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import BaseNode
 from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.search_path import SearchPath, Pattern, \
-    generate_pattern, find_built_in_pattern
+    generate_pattern, find_built_in_pattern, ReplacePath
 from mindinsight.mindconverter.common.exceptions import SubGraphSearchingError

@@ -371,6 +372,7 @@ def _build_connection(loader):
         context.precursor_table[node_name] = list(node.get_precursor_dict().keys())
         context.successor_table[node_name] = list(node.get_successor_dict().keys())
         context.outputs_table[node_name] = node.output_name_list
+
     # Record the model inputs count, use it to control the search algorithm.
     context.has_multi_inputs = len(loader.input_nodes) > 1
     dag = DagGraph(nodes=context.node_collection.copy(),

@@ -426,6 +428,28 @@ def _add_known_module_name(search_path):
     return ctx

+def greedy_match(topo_order, user_defined_ptn):
+    """
+    Greedy replace topological order with given pattern by user.
+
+    Args:
+        topo_order (list[str]): Topological order sequence.
+        user_defined_ptn (dict): User defined pattern.
+    """
+    increment_idx = 0
+    prev_path = None
+    for md_name, ptn in user_defined_ptn.items():
+        ptn = Pattern(",".join(ptn), len(ptn), -1, -1, ptn)
+        ptn.known_module_name = md_name
+        topo_order_aft_rpl = topo_order[:] if prev_path is None else prev_path.topo_order_aft_repl
+        repl_path = ReplacePath(ptn, topo_order_aft_rpl, prev_path=prev_path)
+        module_name = repl_path.replace(increment_idx)
+        if module_name is not None:
+            increment_idx += 1
+            prev_path = repl_path
+    return prev_path
+
+
 @SubGraphSearchingError.check_except("Sub-Graph pattern searching fail.")
 def generate_scope_name(data_loader):
     """
@@ -439,6 +463,12 @@ def generate_scope_name(data_loader):
     """
     init_dag = _build_connection(data_loader)
     try:
+        if USER_DEFINED_PATTERN:
+            topo_order = [node for _, node in context.node_collection.items()]
+            repl_path = greedy_match(topo_order, USER_DEFINED_PATTERN)
+            topo_order_with_scope_name_list = _retrieve_scope_name(repl_path) if repl_path else flatten_graph(init_dag)
+            return topo_order_with_scope_name_list
+
         result = _sub_graph_matching(init_dag, beam_width=5, sub_graph_size=6)
         topo_order_with_scope_name_list = _retrieve_scope_name(result) if result else flatten_graph(init_dag)

@@ -15,14 +15,11 @@
 """Define graph entity."""
 import abc
 from collections import OrderedDict
-from copy import deepcopy
 from typing import List

 from mindinsight.mindconverter.common.log import logger as log
-from mindinsight.mindconverter.graph_based_converter.common.code_fragment import CodeFragment
-from mindinsight.mindconverter.graph_based_converter.constant import NodeType, InputType
-from mindinsight.mindconverter.graph_based_converter.mapper.base import Mapper
+from mindinsight.mindconverter.graph_based_converter.constant import InputType
 from mindinsight.mindconverter.common.exceptions import NodeInputTypeNotSupportError
@@ -574,56 +571,3 @@ class GraphNode(abc.ABC):
             ipt_args_settings_in_construct = ', '.join((ipt_args_settings_in_construct, settings_in_construct))
         return ipt_args_settings_in_construct
-
-    def param_transform(self, mapper: Mapper, variable_name):
-        """
-        Transform param in PyTorch operation into MindSpore.
-
-        Args:
-            variable_name (str): Variable name.
-            mapper (ONNXToMindSporeMapper): Mapper between onnx operation
-                and MindSpore.
-
-        Returns:
-            dict, transformed params.
-        """
-        if self._node_type != NodeType.OPERATION.value:
-            args = deepcopy(self._args_in_code)
-            self._args_in_code = dict()
-            for arg, value in args.items():
-                self._args_in_code[self._get_arg_name(arg, variable_name)] = value
-            return CodeFragment(operation="", actual_args=args, settings=None,
-                                input_shape=self.input_shape, output_shape=self.output_shape)
-
-        if self.transformed:
-            raise ValueError("Already transformed.")
-
-        params = deepcopy(self._op_params)
-        params.update({"input_shape": self.input_shape,
-                       "output_shape": self.output_shape})
-        ms_op, ms_params, ms_settings, ms_weights = mapper.convert(op_name=self.op_name,
-                                                                   params=params,
-                                                                   weights=self._weight)
-        if ms_op:
-            code_fragment = CodeFragment(operation=ms_op,
-                                         actual_args=ms_params,
-                                         settings=ms_settings,
-                                         input_shape=self.input_shape,
-                                         output_shape=self.output_shape,
-                                         trainable_params=ms_weights)
-        else:
-            code_fragment = CodeFragment(operation=self._op_name,
-                                         actual_args=self._op_params,
-                                         settings=None,
-                                         input_shape=self.input_shape,
-                                         output_shape=self.output_shape,
-                                         trainable_params=self._weight)
-
-        for arg, value in code_fragment.actual_args.items():
-            self._args_in_code[self._get_arg_name(arg, variable_name)] = value
-
-        self.transformed = True
-
-        return code_fragment
@@ -262,7 +262,6 @@
         self.model = onnx_model_sim
         self.graph = onnx_model_sim.graph
         self.nodes = onnx_model_sim.graph.node
-        self.batch_size = list(input_nodes.values())[0][0]
         self.input_nodes = input_nodes
         self.output_nodes = output_nodes
         # args for init

@@ -275,6 +274,9 @@
         self.tensors_dict = {}  # {tensor_name: OnnxTensor}
         self.value_info_dict = {}  # Not contains input and output nodes

+        # Record the weight names used many times.
+        self.repeated_weight = dict()
+
         self.node_output_shape_dict = OrderedDict()  # {node_name: [int]}

         # Key is edge of ONNX ir graph, value is the corresponding precursor node.

@@ -393,6 +395,7 @@
     def _parse_nodes(self):
         """Parse each onnx nodes in the model."""
         nodes_topo_idx = []
+        record_tensors = dict()
         for idx, node in enumerate(self.nodes):
             if not node.name:
                 node.name = "_".join(node.output)
@@ -408,11 +411,21 @@
                         self._global_context.onnx_node_inputs[n.name].append(ipt_nd)
                     else:
                         self._global_context.onnx_node_inputs[n.name] = [ipt_nd]
+                if ipt_nd in self.tensors_dict:
+                    if ipt_nd not in record_tensors:
+                        record_tensors[ipt_nd] = [node.name]
+                        continue
+                    record_tensors[ipt_nd].append(node.name)
+                    self.repeated_weight.setdefault(ipt_nd, [])

             self._global_context.onnx_node_name_to_topo_idx[n.name] = idx

+        for k in self.repeated_weight:
+            self.repeated_weight[k] = record_tensors[k][:]
+
         self._global_context.onnx_nodes_collection = self._nodes_dict
         self._global_context.onnx_nodes_topo_index = nodes_topo_idx
+        self._global_context.repeated_weights = self.repeated_weight

     def _parse_tensors(self):
         """Parse each onnx tensors in the model."""
@@ -436,8 +449,7 @@
             for i, s in enumerate(shape):
                 if 'unk' in s:
                     # Have to adapt user-define axis name, e.g. 'sequence', 'batch'.
-                    shape[i] = self.batch_size if self.batch_size is not None else 1
-                    continue
+                    raise ValueError(f"cannot get shape of {node_opt_name}.")
                 if s == "scalar":
                     shape = SCALAR_WITHOUT_SHAPE
                     continue