| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -14,26 +14,10 @@ | |||
| # ============================================================================== | |||
| """Define CodeLine object.""" | |||
| import abc | |||
| import re | |||
| from typing import List, Tuple | |||
| class TrainableParams: | |||
| """Trainable parameters.""" | |||
| def __init__(self, shape, dtype, reference): | |||
| self.param_name = None | |||
| self.shape = shape | |||
| self.dtype = dtype | |||
| self.reference = reference # Weight name in global npy. | |||
| class CodeSetting: | |||
| """Code generation settings.""" | |||
| def __init__(self): | |||
| self.output_vars_suffix = [] | |||
| self.operation_input_type = None # Construct input type, tensor or list. | |||
| self.operation_extra_input = dict() # `values` in original setting dict. | |||
| self.operation_extra_tensor = None # For `MatMul`, `BiasAdd` op, need a tensor | |||
| from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords | |||
| class Fragment(abc.ABC): | |||
| @@ -222,3 +206,144 @@ class ModuleFragment(Fragment): | |||
| super(ModuleFragment, self).__init__(operation=operation, actual_args=actual_args, | |||
| input_shape=input_shape, output_shape=output_shape, | |||
| settings=settings) | |||
| class NewFragment: | |||
| """ | |||
| Fragment definition for MindSpore code generation. | |||
| Args: | |||
| data_entity (dict): Data required by operations. The format of `data_entity` is as follows: | |||
| { | |||
| "var1": { | |||
| "metadata": { # ONNX Metadata | |||
| "operation": "Conv2d", | |||
| "source": "conv_pw_13/Conv2D", | |||
| "attributes": { | |||
| # Put original onnx attributes here. | |||
| } | |||
| }, | |||
| "variable_name": None, | |||
| "inputs": [], | |||
| "output_type": "tensor" | "array", | |||
| "args": {"in_channels": 768, "out_channels": 1024}, | |||
| "trainable_params": {"weight": "Parameter(Tensor(GLOBAL_W[NAME]))"} | |||
| }, | |||
| "var2": { | |||
| "variable_name": "pad", | |||
| "args": {"padding": [0, 1, 1, 0], "mode": "SAME"} | |||
| } | |||
| } | |||
| code_template (dict): Code template generated by the mapper. The format of `code_template` is as follows: | |||
| { | |||
| "var1": { | |||
| "init": [ | |||
| "self.{var1} = nn.Conv2d(in_channels={in_channels})", | |||
| "self.{var1}.weight = {weight}" | |||
| ], | |||
| "construct": [ | |||
| "opt_{var1} = self.{var1}({inputs}[, extra])" | |||
| ] | |||
| }, | |||
| "var2": { | |||
| "init": [ | |||
| "self.{var2} = nn.Pad(padding={padding}, mode={mode})" | |||
| ], | |||
| "construct": [ | |||
| "opt_{var2} = self.{var2}(opt_{var1}[, extra])" | |||
| ] | |||
| } | |||
| } | |||
| outputs (list[str]): List of output name slots. | |||
| outputs_mapping (tuple): Output index mapping between the IR node and the MindSpore operation. | |||
| """ | |||
| def __init__(self, data_entity: dict, code_template: dict, outputs: List[str], outputs_mapping): | |||
| self.exchange_msg = data_entity | |||
| self._code_template = code_template | |||
| self.inputs = [] | |||
| self._outputs = outputs | |||
| self.outputs_mapping = outputs_mapping | |||
| self.format_args = dict() | |||
| def _get_outputs(self): | |||
| """ | |||
| Get outputs of the code snippet. | |||
| Returns: | |||
| list[tuple[str, str]], output names and output types of the current code block. | |||
| """ | |||
| outputs = [] | |||
| variables = { | |||
| k: self.exchange_msg[k][ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value] | |||
| for k in self.exchange_msg if k != ExchangeMessageKeywords.METADATA.value | |||
| } | |||
| for o in self._outputs: | |||
| extractor = r".*\{(?P<var>.+)\}.*" | |||
| var_def = re.match(extractor, o) | |||
| if not var_def: | |||
| raise ValueError(f"Output variable name {o} is illegal.") | |||
| outputs.append( | |||
| ( | |||
| o.format(**variables), | |||
| self.exchange_msg[var_def.group("var")][ | |||
| ExchangeMessageKeywords.VariableScope.value.OUTPUT_TYPE.value] | |||
| ) | |||
| ) | |||
| return outputs | |||
| def get_outputs_by_idx(self, idx, inner_idx=-1): | |||
| """Get outputs by idx.""" | |||
| outputs = self._get_outputs() | |||
| opt, opt_type = outputs[idx] | |||
| if opt_type == ExchangeMessageKeywords.VariableScope.value.ARR_TYPE.value: | |||
| return f"{opt}[{inner_idx}]" | |||
| return opt | |||
| def __call__(self) -> Tuple[List[str], List[str]]: | |||
| """ | |||
| Rewrite code templates by backfilling parameters. | |||
| Returns: | |||
| tuple[list[str], list[str]], init statements and construct statements. | |||
| """ | |||
| init_stats, call_stats = [], [] | |||
| precursor_node_var = [None, None] | |||
| for op_var, template in self._code_template.items(): | |||
| if ExchangeMessageKeywords.VariableScope.value.INPUTS.value not in self.exchange_msg[op_var]: | |||
| # It's possible that inputs and the precursor node both exist. | |||
| self.exchange_msg[op_var][ExchangeMessageKeywords.VariableScope.value.ARGS.value][ | |||
| precursor_node_var[0]] = precursor_node_var[1] | |||
| for tpl in template[TemplateKeywords.INIT.value]: | |||
| init_stat = self._rewrite(op_var, self.exchange_msg[op_var], tpl) | |||
| init_stats.append(init_stat) | |||
| for tpl in template[TemplateKeywords.CONSTRUCT.value]: | |||
| call_stat = self._rewrite(op_var, self.exchange_msg[op_var], tpl) | |||
| call_stats.append(call_stat) | |||
| precursor_node_var = op_var, self.exchange_msg[op_var].get( | |||
| ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value) | |||
| return init_stats, call_stats | |||
| @staticmethod | |||
| def _rewrite(var, data, template: str) -> str: | |||
| """ | |||
| Backfill data into code template. | |||
| Args: | |||
| var (str): Current operation variable name. | |||
| data (dict): Data to be written. | |||
| template (str): Code template. | |||
| Returns: | |||
| str, single code line. | |||
| """ | |||
| rewrite_data = {var: data[ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value]} | |||
| if ExchangeMessageKeywords.VariableScope.value.INPUTS.value in data: | |||
| rewrite_data[ExchangeMessageKeywords.VariableScope.value.INPUTS.value] = ", ".join( | |||
| data[ExchangeMessageKeywords.VariableScope.value.INPUTS.value]) | |||
| if ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value in data: | |||
| rewrite_data.update(data[ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value]) | |||
| rewrite_data.update(data[ExchangeMessageKeywords.VariableScope.value.ARGS.value]) | |||
| return template.format(**{ | |||
| k: str(rewrite_data[k]) for k in rewrite_data | |||
| }) | |||
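For orientation (not part of the diff): `NewFragment._rewrite` merges the slot's variable name, joined inputs, trainable params, and args into one dict and backfills the template line with `str.format`. A minimal sketch with hand-made, illustrative data:

```python
# Illustrative only: mirrors the backfill done by NewFragment._rewrite.
template = "self.{var_0} = nn.Conv2d(in_channels={in_channels}, out_channels={out_channels})"
rewrite_data = {
    "var_0": "conv1",      # the slot's variable_name
    "in_channels": 3,      # taken from the "args" entry
    "out_channels": 64,
}
line = template.format(**{k: str(v) for k, v in rewrite_data.items()})
# -> "self.conv1 = nn.Conv2d(in_channels=3, out_channels=64)"
```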
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -209,3 +209,10 @@ def get_framework_type(model_path): | |||
| raise error | |||
| return framework_type | |||
| def reset_init_or_construct(template, variable_slot, new_data, scope): | |||
| """Reset init statement.""" | |||
| template[variable_slot][scope].clear() | |||
| template[variable_slot][scope] += new_data | |||
| return template | |||
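A minimal usage sketch of `reset_init_or_construct` with a hand-made template dict (illustrative data; `"construct"` corresponds to `TemplateKeywords.CONSTRUCT.value`):

```python
# Illustrative only: replace the construct statements of slot "var_0".
template = {
    "var_0": {
        "init": ["self.{var_0} = nn.Pad(padding={padding})"],
        "construct": ["opt_{var_0} = self.{var_0}({inputs})"],
    }
}
new_construct = ["opt_{var_0} = self.{var_0}(({inputs},))"]
template = reset_init_or_construct(template, "var_0", new_construct, "construct")
# template["var_0"]["construct"] is now ["opt_{var_0} = self.{var_0}(({inputs},))"]
```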
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -42,6 +42,47 @@ ONNX_MIN_VER = "1.8.0" | |||
| TF2ONNX_MIN_VER = "1.7.1" | |||
| ONNXRUNTIME_MIN_VER = "1.5.2" | |||
| @unique | |||
| class TemplateKeywords(Enum): | |||
| """Define keywords in template message.""" | |||
| INIT = "init" | |||
| CONSTRUCT = "construct" | |||
| @unique | |||
| class ExchangeMessageKeywords(Enum): | |||
| """Define keywords in exchange message.""" | |||
| METADATA = "metadata" | |||
| @unique | |||
| class MetadataScope(Enum): | |||
| """Define metadata scope keywords in exchange message.""" | |||
| SOURCE = "source" | |||
| OPERATION = "operation" | |||
| INPUTS = "inputs" | |||
| INPUTS_SHAPE = "inputs_shape" | |||
| OUTPUTS = "outputs" | |||
| OUTPUTS_SHAPE = "outputs_shape" | |||
| PRECURSOR = "precursor_nodes" | |||
| SUCCESSOR = "successor_nodes" | |||
| ATTRS = "attributes" | |||
| SCOPE = "scope" | |||
| @unique | |||
| class VariableScope(Enum): | |||
| """Define variable scope keywords in exchange message.""" | |||
| OPERATION = "operation" | |||
| VARIABLE_NAME = "variable_name" | |||
| OUTPUT_TYPE = "output_type" | |||
| TSR_TYPE = "tensor" | |||
| ARR_TYPE = "array" | |||
| INPUTS = "inputs" | |||
| ARGS = "args" | |||
| WEIGHTS = "weights" | |||
| TRAINABLE_PARAMS = "trainable_params" | |||
| BINARY_HEADER_PYTORCH_FILE = \ | |||
| b'\x80\x02\x8a\nl\xfc\x9cF\xf9 j\xa8P\x19.\x80\x02M\xe9\x03.\x80\x02}q\x00(X\x10\x00\x00\x00' | |||
| TENSORFLOW_MODEL_SUFFIX = "pb" | |||
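Because `MetadataScope` and `VariableScope` are nested inside `ExchangeMessageKeywords`, each outer member's value is the nested Enum class itself, which is why callers use two `.value` hops (e.g. `ExchangeMessageKeywords.VariableScope.value.INPUTS.value`). A standalone sketch of the same pattern, assuming the nested-class-becomes-member behavior of the Python versions this codebase targets (newer Python may require `enum.member`):

```python
from enum import Enum, unique

@unique
class Keywords(Enum):
    """Mimics ExchangeMessageKeywords: a nested Enum class becomes a member whose value is that class."""
    METADATA = "metadata"

    @unique
    class VariableScope(Enum):
        INPUTS = "inputs"
        ARGS = "args"

assert Keywords.METADATA.value == "metadata"
# Two .value hops: member -> nested Enum class -> string constant.
assert Keywords.VariableScope.value.INPUTS.value == "inputs"
```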
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -18,8 +18,9 @@ __all__ = ["batch_add_nodes"] | |||
| import re | |||
| import copy | |||
| from mindinsight.mindconverter.graph_based_converter.common.code_fragment import CodeFragment | |||
| from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords | |||
| from .generator import Generator, CodeStruct | |||
| from ..common.code_fragment import CodeFragment | |||
| def _tf_model_node_name_reformat(node, node_name): | |||
| @@ -34,7 +35,6 @@ def _tf_model_node_name_reformat(node, node_name): | |||
| str, re-formatted node name. | |||
| """ | |||
| scope_name = node.scope_name | |||
| new_name = None | |||
| regex = r"(?P<parent>.+/)(?P<op>\w+)" | |||
| match = re.match(regex, scope_name) | |||
| parent = match.group("parent") | |||
| @@ -79,6 +79,32 @@ def batch_add_nodes(graph_obj, mapper) -> Generator: | |||
| return generator_inst | |||
| def _supply_graph_info(node, external_inputs): | |||
| """ | |||
| Supply IR graph node info as metadata. | |||
| Args: | |||
| node (GraphNode): Graph node instance. | |||
| external_inputs (list[str]): External inputs in the ONNX IR. | |||
| Returns: | |||
| dict, metadata. | |||
| """ | |||
| precursors = _combine_external_inputs_with_precursor_nodes(node, external_inputs) | |||
| return { | |||
| ExchangeMessageKeywords.MetadataScope.value.SOURCE.value: node.ir_node_name, | |||
| ExchangeMessageKeywords.MetadataScope.value.OPERATION.value: node.ir_node_operation, | |||
| ExchangeMessageKeywords.MetadataScope.value.SCOPE.value: node.scope_name, | |||
| ExchangeMessageKeywords.MetadataScope.value.INPUTS.value: node.ir_node_inputs, | |||
| ExchangeMessageKeywords.MetadataScope.value.INPUTS_SHAPE.value: node.input_shape, | |||
| ExchangeMessageKeywords.MetadataScope.value.OUTPUTS.value: node.ir_node_outputs, | |||
| ExchangeMessageKeywords.MetadataScope.value.OUTPUTS_SHAPE.value: node.output_shape, | |||
| ExchangeMessageKeywords.MetadataScope.value.PRECURSOR.value: precursors, | |||
| ExchangeMessageKeywords.MetadataScope.value.SUCCESSOR.value: node.ir_node_successor, | |||
| ExchangeMessageKeywords.MetadataScope.value.ATTRS.value: node.node_params, | |||
| } | |||
| def _convert_params(node, mapper): | |||
| """ | |||
| Call mapper to convert node's params from ONNX to MindSpore. | |||
| @@ -88,10 +114,8 @@ def _convert_params(node, mapper): | |||
| mapper (Mapper): The mapper instance which indicates the conversion method. | |||
| Returns: | |||
| str, op name in MindSpore | |||
| dict, MindSpore parameters | |||
| dict, MindSpore settings | |||
| dict, weights of the node | |||
| tuple[str, dict, dict, dict], op name in MindSpore, MindSpore parameters, | |||
| MindSpore settings and weights of the node. | |||
| """ | |||
| params = copy.deepcopy(node.node_params) | |||
| params.update({"input_shape": node.input_shape, | |||
| @@ -109,3 +133,24 @@ def _convert_params(node, mapper): | |||
| return op_in_ms, ms_params, ms_settings, weights | |||
| return node.op_name, node.node_params, dict(), dict() | |||
| def _combine_external_inputs_with_precursor_nodes(node, external_inputs): | |||
| """ | |||
| Combine user-provided input nodes (external inputs) with precursor nodes. | |||
| Args: | |||
| node (OnnxGraphNode): Node instance. | |||
| external_inputs (list[str]): Inputs in the ONNX IR. | |||
| Returns: | |||
| list[str], precursor nodes list. | |||
| """ | |||
| inputs = set(node.ir_node_inputs) | |||
| to_be_added = list(inputs & set(external_inputs)) | |||
| precursor = node.ir_node_precursor | |||
| # Add external inputs to the precursor list, following the order of the node's inputs. | |||
| for item in to_be_added: | |||
| node_idx = node.ir_node_inputs.index(item) | |||
| precursor.insert(node_idx, item) | |||
| return precursor | |||
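A worked example of the merge above, with hand-made data standing in for a real `OnnxGraphNode`: any user-provided graph input that appears among the node's ONNX inputs is inserted into the precursor list at the index it occupies in the input list.

```python
# Illustrative data only; mirrors _combine_external_inputs_with_precursor_nodes.
ir_node_inputs = ["input_1:0", "conv1/weight:0"]   # node.ir_node_inputs
external_inputs = ["input_1:0"]                    # user-provided graph inputs
precursor = []                                     # node.ir_node_precursor (empty for the first node)

for item in set(ir_node_inputs) & set(external_inputs):
    precursor.insert(ir_node_inputs.index(item), item)
# precursor == ["input_1:0"]
```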
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -19,6 +19,7 @@ import json | |||
| import os | |||
| from typing import Dict | |||
| from mindinsight.mindconverter.common.log import logger as log | |||
| from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords | |||
| CONFIG_JSON = "onnx_to_ms.json" | |||
| OPERATION_TABLE = os.path.join( | |||
| @@ -36,6 +37,7 @@ GET_OP_NAME = "_operation_name_in_ms" | |||
| GET_OP_PARAMS = "_convert_params" | |||
| GET_OP_WEIGHTS = "_convert_trained_weights" | |||
| GET_OP_SETTINGS = "_convert_settings" | |||
| GET_OP_TEMPLATE = "_generate_snippet_template" | |||
| class Mapper(metaclass=abc.ABCMeta): | |||
| @@ -44,7 +46,7 @@ class Mapper(metaclass=abc.ABCMeta): | |||
| @staticmethod | |||
| @abc.abstractmethod | |||
| def _operation_name_in_ms(*args, **kwargs): | |||
| """Corresponding operation name in mindspore.""" | |||
| """Corresponding operation name in MindSpore.""" | |||
| @staticmethod | |||
| @abc.abstractmethod | |||
| @@ -66,6 +68,11 @@ class Mapper(metaclass=abc.ABCMeta): | |||
| def convert(cls, op_name: str, params: Dict, weights: Dict = None): | |||
| """Convert third party operation's param into MindSpore operation.""" | |||
| @staticmethod | |||
| @abc.abstractmethod | |||
| def _generate_snippet_template(**kwargs): | |||
| """Generate code template according to node info.""" | |||
| class ONNXToMindSporeMapper(Mapper, abc.ABC): | |||
| """ONNX operation to MindSpore.""" | |||
| @@ -131,3 +138,36 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC): | |||
| @staticmethod | |||
| def _convert_settings(**kwargs): | |||
| raise NotImplementedError | |||
| @staticmethod | |||
| def _generate_snippet_template(**kwargs): | |||
| op = kwargs.get("operation") | |||
| args = kwargs.get("converted_params") | |||
| weights = kwargs.get("weights") | |||
| if not op: | |||
| raise ValueError("Can not get MindSpore operation name.") | |||
| variable_slot = "var_0" | |||
| init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})" | |||
| construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ | |||
| f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}})" | |||
| template = { | |||
| variable_slot: { | |||
| TemplateKeywords.INIT.value: [init_template], | |||
| TemplateKeywords.CONSTRUCT.value: [construct_template] | |||
| } | |||
| } | |||
| exchange_msg = { | |||
| variable_slot: { | |||
| ExchangeMessageKeywords.VariableScope.value.OPERATION.value: op, | |||
| ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value: None, | |||
| ExchangeMessageKeywords.VariableScope.value.OUTPUT_TYPE.value: | |||
| ExchangeMessageKeywords.VariableScope.value.TSR_TYPE.value, | |||
| ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [], | |||
| ExchangeMessageKeywords.VariableScope.value.ARGS.value: args, | |||
| ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights, | |||
| ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: {} | |||
| } | |||
| } | |||
| outputs_list = [f"opt_{{{variable_slot}}}"] | |||
| outputs_mapping = ((0, 0),) | |||
| return template, exchange_msg, outputs_list, outputs_mapping | |||
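To make the templating above concrete: for a hypothetical operation and converted-parameter dict, the generated `init` and `construct` template strings expand as shown below (illustrative values, not a real mapper run):

```python
# Illustrative expansion of the init/construct templates built in _generate_snippet_template.
op = "nn.Conv2d"
args = {"in_channels": 3, "out_channels": 64}
variable_slot = "var_0"

init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})"
# In the real code the "{inputs}" slot is written via ExchangeMessageKeywords.VariableScope.value.INPUTS.value.
construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}({{inputs}})"

# init_template      -> "self.{var_0} = nn.Conv2d(in_channels={in_channels}, out_channels={out_channels})"
# construct_template -> "opt_{var_0} = self.{var_0}({inputs})"
# The remaining {var_0}, {in_channels}, ... slots are later backfilled by NewFragment._rewrite.
```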
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -13,8 +13,8 @@ | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from ...base import ONNXToMindSporeMapper | |||
| from ...gen_setting import Setting | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | |||
| class BatchNormMapper(ONNXToMindSporeMapper): | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -14,9 +14,9 @@ | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| import numpy as np | |||
| from ...base import ONNXToMindSporeMapper | |||
| from ...gen_setting import Setting | |||
| from ....common import utils | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| from mindinsight.mindconverter.graph_based_converter.common.utils import convert_bytes_string_to_string | |||
| def _convert_padding(**kwargs): | |||
| @@ -83,7 +83,7 @@ class ConvMapper(ONNXToMindSporeMapper): | |||
| auto_pad = None | |||
| if params.get("auto_pad") is not None: | |||
| auto_pad = utils.convert_bytes_string_to_string(params.get("auto_pad")) | |||
| auto_pad = convert_bytes_string_to_string(params.get("auto_pad")) | |||
| # tmp tf translated ver. mapping | |||
| if isinstance(params.get('dilations'), list): | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -13,8 +13,8 @@ | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from ...base import ONNXToMindSporeMapper | |||
| from ...gen_setting import Setting | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | |||
| class DenseMapper(ONNXToMindSporeMapper): | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -13,8 +13,8 @@ | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from ...base import ONNXToMindSporeMapper | |||
| from ...gen_setting import Setting | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | |||
| class FlattenMapper(ONNXToMindSporeMapper): | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -13,8 +13,8 @@ | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from ...base import ONNXToMindSporeMapper | |||
| from ...gen_setting import Setting | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | |||
| class GlobalPoolMapper(ONNXToMindSporeMapper): | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -13,8 +13,10 @@ | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from ...base import ONNXToMindSporeMapper | |||
| from ...gen_setting import Setting, Tensor, get_dtype | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting, Tensor, get_dtype | |||
| from mindinsight.mindconverter.graph_based_converter.common.utils import reset_init_or_construct | |||
| from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords | |||
| class MatMulMapper(ONNXToMindSporeMapper): | |||
| @@ -43,3 +45,30 @@ class MatMulMapper(ONNXToMindSporeMapper): | |||
| ref = t_name | |||
| return Setting(op_extra_tensor=Tensor(shape=tensor.shape, | |||
| dtype=get_dtype(tensor), reference=ref)) | |||
| @staticmethod | |||
| def _generate_snippet_template(**kwargs): | |||
| template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template( | |||
| **kwargs) | |||
| op = kwargs.get("operation") | |||
| args = kwargs.get("converted_params") | |||
| weights = kwargs.get("weights") | |||
| if not weights: | |||
| return template, exchange_msg, outputs_list, outputs_mapping | |||
| weight = list(weights.items())[0] | |||
| _, tensor = weight | |||
| variable_slot = "var_0" | |||
| init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})" | |||
| init_tensor = f"self.{{{variable_slot}}}_w = " \ | |||
| f"Tensor(np.random.uniform(0, 1, {tensor.shape}).astype(np.{tensor.dtype}))" | |||
| construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ | |||
| f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}," \ | |||
| f"self.{{{variable_slot}}}_w)" | |||
| template = reset_init_or_construct(template, variable_slot, [init_template, init_tensor], | |||
| TemplateKeywords.INIT.value) | |||
| template = reset_init_or_construct(template, variable_slot, [construct_template], | |||
| TemplateKeywords.CONSTRUCT.value) | |||
| return template, exchange_msg, outputs_list, outputs_mapping | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -13,8 +13,8 @@ | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from ...base import ONNXToMindSporeMapper | |||
| from ...gen_setting import Setting | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | |||
| def _padding_format_convert(padding: list): | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -13,8 +13,8 @@ | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from ...base import ONNXToMindSporeMapper | |||
| from ...gen_setting import Setting | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | |||
| class PoolMapper(ONNXToMindSporeMapper): | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -13,8 +13,8 @@ | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from ...base import ONNXToMindSporeMapper | |||
| from ...gen_setting import Setting | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | |||
| class ReLUMapper(ONNXToMindSporeMapper): | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -13,8 +13,8 @@ | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from ...base import ONNXToMindSporeMapper | |||
| from ...gen_setting import Setting | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | |||
| class SigmoidMapper(ONNXToMindSporeMapper): | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -13,8 +13,8 @@ | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from ...base import ONNXToMindSporeMapper | |||
| from ...gen_setting import Setting | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | |||
| class SoftmaxMapper(ONNXToMindSporeMapper): | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -13,8 +13,10 @@ | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from ...base import ONNXToMindSporeMapper | |||
| from ...gen_setting import Setting, Tensor, get_dtype | |||
| from mindinsight.mindconverter.graph_based_converter.common.utils import reset_init_or_construct | |||
| from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting, Tensor, get_dtype | |||
| class AddMapper(ONNXToMindSporeMapper): | |||
| @@ -43,3 +45,30 @@ class AddMapper(ONNXToMindSporeMapper): | |||
| ref = t_name | |||
| return Setting(op_extra_tensor=Tensor(shape=tensor.shape, | |||
| dtype=get_dtype(tensor), reference=ref)) | |||
| @staticmethod | |||
| def _generate_snippet_template(**kwargs): | |||
| template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template( | |||
| **kwargs) | |||
| op = kwargs.get("operation") | |||
| args = kwargs.get("converted_params") | |||
| weights = kwargs.get("weights") | |||
| if not weights: | |||
| return template, exchange_msg, outputs_list, outputs_mapping | |||
| bias = list(weights.items())[0] | |||
| _, tensor = bias | |||
| variable_slot = "var_0" | |||
| init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})" | |||
| init_tensor = f"self.{{{variable_slot}}}_bias = " \ | |||
| f"Tensor(np.random.uniform(0, 1, {tensor.shape}).astype(np.{tensor.dtype}))" | |||
| construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ | |||
| f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}," \ | |||
| f"self.{{{variable_slot}}}_bias)" | |||
| template = reset_init_or_construct(template, variable_slot, [init_template, init_tensor], | |||
| TemplateKeywords.INIT.value) | |||
| template = reset_init_or_construct(template, variable_slot, [construct_template], | |||
| TemplateKeywords.CONSTRUCT.value) | |||
| return template, exchange_msg, outputs_list, outputs_mapping | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -14,8 +14,10 @@ | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from mindinsight.mindconverter.graph_based_converter.constant import InputType | |||
| from ...base import ONNXToMindSporeMapper | |||
| from ...gen_setting import Setting | |||
| from mindinsight.mindconverter.graph_based_converter.common.utils import reset_init_or_construct | |||
| from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | |||
| class ConcatMapper(ONNXToMindSporeMapper): | |||
| @@ -38,3 +40,14 @@ class ConcatMapper(ONNXToMindSporeMapper): | |||
| def _convert_settings(**kwargs): | |||
| input_type = InputType.LIST.value | |||
| return Setting(op_ipt_type=input_type) | |||
| @staticmethod | |||
| def _generate_snippet_template(**kwargs): | |||
| template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template( | |||
| **kwargs) | |||
| variable_slot = "var_0" | |||
| construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ | |||
| f"(({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}},))" | |||
| template = reset_init_or_construct(template, variable_slot, [construct_template], | |||
| TemplateKeywords.CONSTRUCT.value) | |||
| return template, exchange_msg, outputs_list, outputs_mapping | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -13,8 +13,10 @@ | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from ...base import ONNXToMindSporeMapper | |||
| from ...gen_setting import Setting, Tensor, get_dtype | |||
| from mindinsight.mindconverter.graph_based_converter.common.utils import reset_init_or_construct | |||
| from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting, Tensor, get_dtype | |||
| class MulMapper(ONNXToMindSporeMapper): | |||
| @@ -40,3 +42,29 @@ class MulMapper(ONNXToMindSporeMapper): | |||
| ref, tensor = list(weights.items())[0] | |||
| return Setting(op_extra_tensor=Tensor(shape=tensor.shape, | |||
| dtype=get_dtype(tensor), reference=ref)) | |||
| @staticmethod | |||
| def _generate_snippet_template(**kwargs): | |||
| template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template( | |||
| **kwargs) | |||
| op = kwargs.get("operation") | |||
| args = kwargs.get("converted_params") | |||
| weights = kwargs.get("weights") | |||
| if not weights: | |||
| return template, exchange_msg, outputs_list, outputs_mapping | |||
| weight = list(weights.items())[0] | |||
| _, tensor = weight | |||
| variable_slot = "var_0" | |||
| init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})" | |||
| init_tensor = f"self.{{{variable_slot}}}_w = Tensor(np.random.uniform(0, 1, {tensor.shape})" \ | |||
| f".astype(np.{tensor.dtype}))" | |||
| construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ | |||
| f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}," \ | |||
| f"self.{{{variable_slot}}}_w)" | |||
| template = reset_init_or_construct(template, variable_slot, [init_template, init_tensor], | |||
| TemplateKeywords.INIT.value) | |||
| template = reset_init_or_construct(template, variable_slot, [construct_template], | |||
| TemplateKeywords.CONSTRUCT.value) | |||
| return template, exchange_msg, outputs_list, outputs_mapping | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -13,8 +13,10 @@ | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from ...base import ONNXToMindSporeMapper | |||
| from ...gen_setting import Setting | |||
| from mindinsight.mindconverter.graph_based_converter.common.utils import reset_init_or_construct | |||
| from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | |||
| class ReduceMeanMapper(ONNXToMindSporeMapper): | |||
| @@ -42,3 +44,20 @@ class ReduceMeanMapper(ONNXToMindSporeMapper): | |||
| else: | |||
| axis = tuple() | |||
| return Setting(op_extra_input={'axis': axis}) | |||
| @staticmethod | |||
| def _generate_snippet_template(**kwargs): | |||
| template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template( | |||
| **kwargs) | |||
| raw_params = kwargs.get("raw_params") | |||
| if raw_params.get('axes'): | |||
| axis = raw_params['axes'][0] if len(raw_params['axes']) == 1 else tuple(raw_params['axes']) | |||
| else: | |||
| axis = tuple() | |||
| variable_slot = "var_0" | |||
| construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ | |||
| f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}, {axis})" | |||
| template = reset_init_or_construct(template, variable_slot, [construct_template], | |||
| TemplateKeywords.CONSTRUCT.value) | |||
| return template, exchange_msg, outputs_list, outputs_mapping | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -13,8 +13,10 @@ | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from ...base import ONNXToMindSporeMapper | |||
| from ...gen_setting import Setting | |||
| from mindinsight.mindconverter.graph_based_converter.common.utils import reset_init_or_construct | |||
| from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | |||
| class ReshapeMapper(ONNXToMindSporeMapper): | |||
| @@ -52,3 +54,20 @@ class ReshapeMapper(ONNXToMindSporeMapper): | |||
| shape = [-1] | |||
| shape += list(weights.values())[0][1:].tolist() | |||
| return Setting(op_extra_input={"shape": tuple(shape)}) | |||
| @staticmethod | |||
| def _generate_snippet_template(**kwargs): | |||
| template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template( | |||
| **kwargs) | |||
| weights = kwargs.get("weights") | |||
| if len(weights) > 1: | |||
| raise ValueError("For reshape, `weights` length should equal to 1.") | |||
| shape = [-1] | |||
| shape += list(weights.values())[0][1:].tolist() | |||
| variable_slot = "var_0" | |||
| construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ | |||
| f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}, {tuple(shape)})" | |||
| template = reset_init_or_construct(template, variable_slot, [construct_template], | |||
| TemplateKeywords.CONSTRUCT.value) | |||
| return template, exchange_msg, outputs_list, outputs_mapping | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -13,9 +13,9 @@ | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from ...base import ONNXToMindSporeMapper | |||
| from ...gen_setting import Setting | |||
| from ....common import utils | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | |||
| from mindinsight.mindconverter.graph_based_converter.common.utils import convert_bytes_string_to_string | |||
| class ResizeMapper(ONNXToMindSporeMapper): | |||
| @@ -26,11 +26,11 @@ class ResizeMapper(ONNXToMindSporeMapper): | |||
| params = kwargs.get("params") | |||
| onnx_coordinate_transform = params.get("coordinate_transformation_mode") | |||
| if onnx_coordinate_transform is not None: | |||
| onnx_coordinate_transform = utils.convert_bytes_string_to_string(onnx_coordinate_transform) | |||
| onnx_coordinate_transform = convert_bytes_string_to_string(onnx_coordinate_transform) | |||
| interpolation_mode = params.get("mode") | |||
| if interpolation_mode is not None: | |||
| interpolation_mode = utils.convert_bytes_string_to_string(interpolation_mode) | |||
| interpolation_mode = convert_bytes_string_to_string(interpolation_mode) | |||
| # Define which MindSpore Resize operator to be used | |||
| if interpolation_mode == "linear": | |||
| @@ -54,7 +54,7 @@ class ResizeMapper(ONNXToMindSporeMapper): | |||
| onnx_coordinate_transform = params.get("coordinate_transformation_mode") | |||
| if onnx_coordinate_transform is not None: | |||
| onnx_coordinate_transform = utils.convert_bytes_string_to_string(onnx_coordinate_transform) | |||
| onnx_coordinate_transform = convert_bytes_string_to_string(onnx_coordinate_transform) | |||
| if onnx_coordinate_transform == "align_corners" or "half_pixel" in onnx_coordinate_transform: | |||
| align_corners = True | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -13,8 +13,10 @@ | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from ...base import ONNXToMindSporeMapper | |||
| from ...gen_setting import Setting | |||
| from mindinsight.mindconverter.graph_based_converter.common.utils import reset_init_or_construct | |||
| from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | |||
| class SliceMapper(ONNXToMindSporeMapper): | |||
| @@ -41,3 +43,21 @@ class SliceMapper(ONNXToMindSporeMapper): | |||
| starts = sorted(zip(weights[0].tolist(), weights[2].tolist()), key=lambda x: x[1], reverse=False) | |||
| return Setting(op_extra_input={"begin": tuple([i[0] for i in starts]), | |||
| "size": tuple(opt_shape)}) | |||
| @staticmethod | |||
| def _generate_snippet_template(**kwargs): | |||
| template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template( | |||
| **kwargs) | |||
| weights = list(kwargs.get("weights").values()) # start, end, axis | |||
| opt_shape = kwargs["raw_params"]["output_shape"] | |||
| if not weights: | |||
| raise ValueError("Cannot get required params from slice.") | |||
| starts = sorted(zip(weights[0].tolist(), weights[2].tolist()), key=lambda x: x[1], reverse=False) | |||
| variable_slot = "var_0" | |||
| construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ | |||
| f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}, " \ | |||
| f"{tuple([i[0] for i in starts])}, {tuple(opt_shape)})" | |||
| template = reset_init_or_construct(template, variable_slot, [construct_template], | |||
| TemplateKeywords.CONSTRUCT.value) | |||
| return template, exchange_msg, outputs_list, outputs_mapping | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -13,8 +13,8 @@ | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from ...base import ONNXToMindSporeMapper | |||
| from ...gen_setting import Setting | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | |||
| class SplitMapper(ONNXToMindSporeMapper): | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -13,8 +13,10 @@ | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from ...base import ONNXToMindSporeMapper | |||
| from ...gen_setting import Setting | |||
| from mindinsight.mindconverter.graph_based_converter.common.utils import reset_init_or_construct | |||
| from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | |||
| class TransposeMapper(ONNXToMindSporeMapper): | |||
| @@ -42,3 +44,17 @@ class TransposeMapper(ONNXToMindSporeMapper): | |||
| converted_params['input_perm'] = perm | |||
| return Setting(op_extra_input=converted_params) | |||
| @staticmethod | |||
| def _generate_snippet_template(**kwargs): | |||
| template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template( | |||
| **kwargs) | |||
| raw_params = kwargs.get("raw_params") | |||
| perm = raw_params["perm"] | |||
| variable_slot = "var_0" | |||
| construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ | |||
| f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}, {tuple(perm)})" | |||
| template = reset_init_or_construct(template, variable_slot, [construct_template], | |||
| TemplateKeywords.CONSTRUCT.value) | |||
| return template, exchange_msg, outputs_list, outputs_mapping | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -17,11 +17,13 @@ import abc | |||
| from collections import OrderedDict | |||
| from copy import deepcopy | |||
| from typing import List | |||
| from mindinsight.mindconverter.common.log import logger as log | |||
| from ..common.code_fragment import CodeFragment | |||
| from ..constant import NodeType, InputType | |||
| from ..mapper.base import Mapper | |||
| from ...common.exceptions import NodeInputTypeNotSupportError | |||
| from mindinsight.mindconverter.graph_based_converter.common.code_fragment import CodeFragment | |||
| from mindinsight.mindconverter.graph_based_converter.constant import NodeType, InputType | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import Mapper | |||
| from mindinsight.mindconverter.common.exceptions import NodeInputTypeNotSupportError | |||
| class GraphParser(metaclass=abc.ABCMeta): | |||
| @@ -97,7 +99,6 @@ class Graph(BaseGraph, abc.ABC): | |||
| self.model = model | |||
| self._raw_input_nodes = kwargs.get("input_nodes") | |||
| self._raw_output_nodes = kwargs.get("output_nodes") | |||
| self.checkpoint = kwargs.get("checkpoint", None) | |||
| self._nodes_collection = OrderedDict() | |||
| self._nodes_record = dict() | |||
| self._shape_dict = dict() | |||
| @@ -107,6 +108,13 @@ class Graph(BaseGraph, abc.ABC): | |||
| self._input_shape = dict() | |||
| self._is_multi_opt_graph = False | |||
| @property | |||
| def user_provided_input_nodes(self) -> List[str]: | |||
| """User provided input_nodes in CLI.""" | |||
| if not isinstance(self._raw_input_nodes, list): | |||
| return [self._raw_input_nodes] | |||
| return self._raw_input_nodes | |||
| def get_input_shape(self, name): | |||
| """ | |||
| Get node input shape. | |||
| @@ -285,9 +293,9 @@ class GraphNode(abc.ABC): | |||
| self.successor_nodes = [] | |||
| # Control dependency. | |||
| self._deleted_in_edge = 0 | |||
| # Source node in pytorch. | |||
| self._src_node = str(node) if node else None | |||
| # Original operation name in pytorch. | |||
| # Source node in ONNX. | |||
| self._src_node = node if node else None | |||
| # Original operation name in ONNX. | |||
| self._op_name = None | |||
| self._op_params = dict() | |||
| self._scope_name = None | |||
| @@ -311,6 +319,40 @@ class GraphNode(abc.ABC): | |||
| # Is in multi output graph. | |||
| self._is_in_multi_opt_graph = False | |||
| @property | |||
| def ir_node_name(self): | |||
| """Getter of ir node's name.""" | |||
| return self._src_node.name | |||
| @property | |||
| def ir_node_operation(self): | |||
| """Getter of ir node's operation.""" | |||
| return self._src_node.op_type | |||
| @property | |||
| def ir_node_inputs(self): | |||
| """Getter of ir node's inputs.""" | |||
| return list(self._src_node.input_name_list) | |||
| @property | |||
| def ir_node_outputs(self): | |||
| """Getter of ir node's outputs.""" | |||
| return list(self._src_node.output_name_list) | |||
| @property | |||
| def ir_node_precursor(self): | |||
| """Getter of ir node's precursor.""" | |||
| return [ | |||
| v.name for _, v in self._src_node.precursor_onnx_node_dict.items() | |||
| ] | |||
| @property | |||
| def ir_node_successor(self): | |||
| """Getter of ir node's successor.""" | |||
| return [ | |||
| v.name for _, v in self._src_node.successor_onnx_node_dict.items() | |||
| ] | |||
| @property | |||
| def weight(self): | |||
| return self._weight | |||
| @@ -0,0 +1,15 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Unit test for mindconverter.graph_based_converter.common interface.""" | |||
| @@ -0,0 +1,145 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Test fragment.""" | |||
| from unittest import TestCase | |||
| from mindinsight.mindconverter.graph_based_converter.common.code_fragment import NewFragment as Fragment | |||
| class TestFragment(TestCase): | |||
| """Tester of fragment.""" | |||
| def test_matmul(self): | |||
| """Test matmul like operation's template.""" | |||
| template = { | |||
| 'var_0': { | |||
| 'init': [ | |||
| 'self.{var_0} = nn.MatMul()', | |||
| 'self.{var_0}_w = Tensor(np.random.rand(*(2048, 1000)).astype(np.float32))' | |||
| ], | |||
| 'construct': ['opt_{var_0} = self.{var_0}({inputs},self.{var_0}_w)'] | |||
| } | |||
| } | |||
| rewrite_data = { | |||
| 'var_0': { | |||
| 'operation': 'nn.MatMul', | |||
| 'output_type': 'tensor', | |||
| 'variable_name': "matmul", 'inputs': ["x"], 'args': {}, | |||
| 'weights': {}, | |||
| 'trainable_params': {} | |||
| }, | |||
| 'metadata': { | |||
| 'source': 'probs/MatMul', 'operation': 'MatMul', 'scope': 'Model/MatMul', | |||
| 'inputs': ['avg_pool/Mean:0', 'probs/MatMul/ReadVariableOp:0'], | |||
| 'inputs_shape': (1, 2048), 'outputs': ['probs/MatMul:0'], 'outputs_shape': [1, 1000], | |||
| 'precursor_nodes': ['avg_pool/Mean'], 'successor_nodes': ['probs/BiasAdd'], | |||
| 'attributes': {} | |||
| } | |||
| } | |||
| fragment = Fragment(data_entity=rewrite_data, code_template=template, outputs=["opt_{var_0}"], | |||
| outputs_mapping=((0, 0),)) | |||
| code = fragment() | |||
| init = code[0] | |||
| construct = code[1] | |||
| self.assertEqual(init, ['self.matmul = nn.MatMul()', | |||
| 'self.matmul_w = Tensor(np.random.rand(*(2048, 1000)).astype(np.float32))']) | |||
| self.assertEqual(construct, ['opt_matmul = self.matmul(x,self.matmul_w)']) | |||
| self.assertEqual(fragment.get_outputs_by_idx(0), "opt_matmul") | |||
| def test_biasadd(self): | |||
| """Test biasadd like operation's template.""" | |||
| template = { | |||
| 'var_0': { | |||
| 'init': [ | |||
| 'self.{var_0} = P.TensorAdd()', | |||
| 'self.{var_0}_bias = Tensor(np.random.rand(*(1000,)).astype(np.float32))' | |||
| ], | |||
| 'construct': ['opt_{var_0} = self.{var_0}({inputs},self.{var_0}_bias)'] | |||
| } | |||
| } | |||
| rewrite_data = { | |||
| 'var_0': { | |||
| 'operation': 'P.TensorAdd', | |||
| 'output_type': 'tensor', | |||
| 'variable_name': "add", 'inputs': ["x"], 'args': {}, 'weights': {}, | |||
| 'trainable_params': {} | |||
| }, | |||
| 'metadata': { | |||
| 'source': 'probs/BiasAdd', 'operation': 'Add', 'scope': 'Model/Add', | |||
| 'inputs': ['probs/MatMul:0', 'probs/BiasAdd/ReadVariableOp:0'], 'inputs_shape': (1, 1000), | |||
| 'outputs': ['probs/BiasAdd:0'], 'outputs_shape': [1, 1000], 'precursor_nodes': ['probs/MatMul'], | |||
| 'successor_nodes': ['probs/Softmax'], 'attributes': {} | |||
| } | |||
| } | |||
| fragment = Fragment(data_entity=rewrite_data, code_template=template, outputs=["opt_{var_0}"], | |||
| outputs_mapping=((0, 0),)) | |||
| code = fragment() | |||
| init = code[0] | |||
| construct = code[1] | |||
| self.assertEqual(init, ['self.add = P.TensorAdd()', | |||
| 'self.add_bias = Tensor(np.random.rand(*(1000,)).astype(np.float32))']) | |||
| self.assertEqual(construct, ['opt_add = self.add(x,self.add_bias)']) | |||
| self.assertEqual(fragment.get_outputs_by_idx(0), "opt_add") | |||
| def test_transpose(self): | |||
| """Test transpose like operation's template.""" | |||
| template = { | |||
| 'var_0': { | |||
| 'init': ['self.{var_0} = P.Transpose()'], | |||
| 'construct': ['opt_{var_0} = self.{var_0}({inputs}, (0, 3, 1, 2))'] | |||
| } | |||
| } | |||
| rewrite_data = { | |||
| 'var_0': { | |||
| 'operation': 'P.Transpose', | |||
| 'output_type': 'tensor', | |||
| 'variable_name': "transpose", 'inputs': ["x"], 'args': {}, 'weights': {}, | |||
| 'trainable_params': {} | |||
| } | |||
| } | |||
| fragment = Fragment(data_entity=rewrite_data, code_template=template, outputs=["opt_{var_0}"], | |||
| outputs_mapping=((0, 0),)) | |||
| code = fragment() | |||
| init = code[0] | |||
| construct = code[1] | |||
| self.assertEqual(init, ['self.transpose = P.Transpose()']) | |||
| self.assertEqual(construct, ['opt_transpose = self.transpose(x, (0, 3, 1, 2))']) | |||
| self.assertEqual(fragment.get_outputs_by_idx(0), "opt_transpose") | |||
| def test_split(self): | |||
| """Test split like operation's template.""" | |||
| template = { | |||
| 'var_0': { | |||
| 'init': ['self.{var_0} = P.Split(axis={axis}, output_num={output_num})'], | |||
| 'construct': ['opt_{var_0} = self.{var_0}({inputs})'] | |||
| } | |||
| } | |||
| rewrite_data = { | |||
| 'var_0': { | |||
| 'operation': 'P.Split', | |||
| 'variable_name': "split", | |||
| 'output_type': 'array', | |||
| 'inputs': ["x"], | |||
| 'args': {"axis": 1, "output_num": 2}, 'weights': {}, | |||
| 'trainable_params': {} | |||
| } | |||
| } | |||
| fragment = Fragment(data_entity=rewrite_data, code_template=template, outputs=["opt_{var_0}"], | |||
| outputs_mapping=((0, 0),)) | |||
| code = fragment() | |||
| init = code[0] | |||
| construct = code[1] | |||
| self.assertEqual(init, ['self.split = P.Split(axis=1, output_num=2)']) | |||
| self.assertEqual(construct, ['opt_split = self.split(x)']) | |||
| self.assertEqual(fragment.get_outputs_by_idx(0, 1), 'opt_split[1]') | |||