diff --git a/mindinsight/mindconverter/graph_based_converter/common/code_fragment.py b/mindinsight/mindconverter/graph_based_converter/common/code_fragment.py index 2a04a570..e0a228bf 100644 --- a/mindinsight/mindconverter/graph_based_converter/common/code_fragment.py +++ b/mindinsight/mindconverter/graph_based_converter/common/code_fragment.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,26 +14,10 @@ # ============================================================================== """Define CodeLine object.""" import abc +import re +from typing import List, Tuple - -class TrainableParams: - """Trainable parameters.""" - - def __init__(self, shape, dtype, reference): - self.param_name = None - self.shape = shape - self.dtype = dtype - self.reference = reference # Weight name in global npy. - - -class CodeSetting: - """Code generation settings.""" - - def __init__(self): - self.output_vars_suffix = [] - self.operation_input_type = None # Construct input type, tensor or list. - self.operation_extra_input = dict() # `values` in original setting dict. - self.operation_extra_tensor = None # For `MatMul`, `BiasAdd` op, need a tensor +from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords class Fragment(abc.ABC): @@ -222,3 +206,144 @@ class ModuleFragment(Fragment): super(ModuleFragment, self).__init__(operation=operation, actual_args=actual_args, input_shape=input_shape, output_shape=output_shape, settings=settings) + + +class NewFragment: + """ + Fragment definition for MindSpore code generation. + + Args: + data_entity (dict): Required data by operations. The format of `data_entity` is as follow: + { + "var1": { + "metadata": { # ONNX Metadata + "operation": "Conv2d", + "source": "conv_pw_13/Conv2D", + "attributes": { + # Put original onnx attributes here. + } + }, + "variable_name": None, + "inputs": [], + "output_type": "tensor" | "array", + "args": {"in_channels": 768, "out_channels": 1024}, + "trainable_params": {"weight": "Parameter(Tensor(GLOBAL_W[NAME]))"} + }, + "var2": { + "variable_name": "pad", + "args": {"padding": [0, 1, 1, 0], "mode": "SAME"} + } + } + code_template (dict): Code template generated by mapper. The format of `code_template` is as follow: + { + "var1": { + "init": [ + "self.{var1} = nn.Conv2d(in_channels={in_channels})", + "self.{var1}.weight = {weight}" + ], + "construct": [ + "opt_{var1} = self.{var1}({inputs}[, extra])" + ] + }, + "var2": { + "init": [ + "self.{var2} = nn.Pad(padding={padding}, mode={mode})" + ], + "construct": [ + "opt_{var2} = self.{var2}(opt_{var1}[, extra])" + ] + } + } + outputs (list[str]): Outputs name slot list. + outputs_mapping (tuple): Outputs index mapping between ir node and MindSpore operation. + """ + + def __init__(self, data_entity: dict, code_template: dict, outputs: List[str], outputs_mapping): + self.exchange_msg = data_entity + self._code_template = code_template + self.inputs = [] + self._outputs = outputs + self.outputs_mapping = outputs_mapping + self.format_args = dict() + + def _get_outputs(self): + """ + Get outputs of the code snippet. + + Returns: + list[str], outputs of current code block. + """ + outputs = [] + variables = { + k: self.exchange_msg[k][ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value] + for k in self.exchange_msg if k != ExchangeMessageKeywords.METADATA.value + } + for o in self._outputs: + extractor = r".*\{(?P.+)\}.*" + var_def = re.match(extractor, o) + if not var_def: + raise ValueError(f"Output variable name {o} is illegal.") + outputs.append( + ( + o.format(**variables), + self.exchange_msg[var_def.group("var")][ + ExchangeMessageKeywords.VariableScope.value.OUTPUT_TYPE.value] + ) + ) + return outputs + + def get_outputs_by_idx(self, idx, inner_idx=-1): + """Get outputs by idx.""" + outputs = self._get_outputs() + opt, opt_type = outputs[idx] + if opt_type == ExchangeMessageKeywords.VariableScope.value.ARR_TYPE.value: + return f"{opt}[{inner_idx}]" + return opt + + def __call__(self) -> Tuple[List[str], List[str]]: + """ + Define parameter rewrite function. + + Returns: + tuple[list[str], list[str]], init statement and construct statement. + """ + init_stats, call_stats = [], [] + precursor_node_var = [None, None] + for op_var, template in self._code_template.items(): + if ExchangeMessageKeywords.VariableScope.value.INPUTS.value not in self.exchange_msg[op_var]: + # It's possible inputs and precursor node both exists. + self.exchange_msg[op_var][ExchangeMessageKeywords.VariableScope.value.ARGS.value][ + precursor_node_var[0]] = precursor_node_var[1] + for tpl in template[TemplateKeywords.INIT.value]: + init_stat = self._rewrite(op_var, self.exchange_msg[op_var], tpl) + init_stats.append(init_stat) + for tpl in template[TemplateKeywords.CONSTRUCT.value]: + call_stat = self._rewrite(op_var, self.exchange_msg[op_var], tpl) + call_stats.append(call_stat) + precursor_node_var = op_var, self.exchange_msg[op_var].get( + ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value) + return init_stats, call_stats + + @staticmethod + def _rewrite(var, data, template: str) -> str: + """ + Backfill data into code template. + + Args: + var (str): Current operation variable name. + data (dict): Data to be written. + template (str): Code template. + + Returns: + str, single code line. + """ + rewrite_data = {var: data[ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value]} + if ExchangeMessageKeywords.VariableScope.value.INPUTS.value in data: + rewrite_data[ExchangeMessageKeywords.VariableScope.value.INPUTS.value] = ", ".join( + data[ExchangeMessageKeywords.VariableScope.value.INPUTS.value]) + if ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value in data: + rewrite_data.update(data[ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value]) + rewrite_data.update(data[ExchangeMessageKeywords.VariableScope.value.ARGS.value]) + return template.format(**{ + k: str(rewrite_data[k]) for k in rewrite_data + }) diff --git a/mindinsight/mindconverter/graph_based_converter/common/utils.py b/mindinsight/mindconverter/graph_based_converter/common/utils.py index 4849d860..b9e86f6e 100644 --- a/mindinsight/mindconverter/graph_based_converter/common/utils.py +++ b/mindinsight/mindconverter/graph_based_converter/common/utils.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -209,3 +209,10 @@ def get_framework_type(model_path): raise error return framework_type + + +def reset_init_or_construct(template, variable_slot, new_data, scope): + """Reset init statement.""" + template[variable_slot][scope].clear() + template[variable_slot][scope] += new_data + return template diff --git a/mindinsight/mindconverter/graph_based_converter/constant.py b/mindinsight/mindconverter/graph_based_converter/constant.py index 0dafb94f..9041a940 100644 --- a/mindinsight/mindconverter/graph_based_converter/constant.py +++ b/mindinsight/mindconverter/graph_based_converter/constant.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -42,6 +42,47 @@ ONNX_MIN_VER = "1.8.0" TF2ONNX_MIN_VER = "1.7.1" ONNXRUNTIME_MIN_VER = "1.5.2" + +@unique +class TemplateKeywords(Enum): + """Define keywords in template message.""" + INIT = "init" + CONSTRUCT = "construct" + + +@unique +class ExchangeMessageKeywords(Enum): + """Define keywords in exchange message.""" + METADATA = "metadata" + + @unique + class MetadataScope(Enum): + """Define metadata scope keywords in exchange message.""" + SOURCE = "source" + OPERATION = "operation" + INPUTS = "inputs" + INPUTS_SHAPE = "inputs_shape" + OUTPUTS = "outputs" + OUTPUTS_SHAPE = "outputs_shape" + PRECURSOR = "precursor_nodes" + SUCCESSOR = "successor_nodes" + ATTRS = "attributes" + SCOPE = "scope" + + @unique + class VariableScope(Enum): + """Define variable scope keywords in exchange message.""" + OPERATION = "operation" + VARIABLE_NAME = "variable_name" + OUTPUT_TYPE = "output_type" + TSR_TYPE = "tensor" + ARR_TYPE = "array" + INPUTS = "inputs" + ARGS = "args" + WEIGHTS = "weights" + TRAINABLE_PARAMS = "trainable_params" + + BINARY_HEADER_PYTORCH_FILE = \ b'\x80\x02\x8a\nl\xfc\x9cF\xf9 j\xa8P\x19.\x80\x02M\xe9\x03.\x80\x02}q\x00(X\x10\x00\x00\x00' TENSORFLOW_MODEL_SUFFIX = "pb" diff --git a/mindinsight/mindconverter/graph_based_converter/generator/__init__.py b/mindinsight/mindconverter/graph_based_converter/generator/__init__.py index 483a93b1..517cc6b2 100644 --- a/mindinsight/mindconverter/graph_based_converter/generator/__init__.py +++ b/mindinsight/mindconverter/graph_based_converter/generator/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,8 +18,9 @@ __all__ = ["batch_add_nodes"] import re import copy +from mindinsight.mindconverter.graph_based_converter.common.code_fragment import CodeFragment +from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords from .generator import Generator, CodeStruct -from ..common.code_fragment import CodeFragment def _tf_model_node_name_reformat(node, node_name): @@ -34,7 +35,6 @@ def _tf_model_node_name_reformat(node, node_name): str, re-formatted node name. """ scope_name = node.scope_name - new_name = None regex = r"(?P.+/)(?P\w+)" match = re.match(regex, scope_name) parent = match.group("parent") @@ -79,6 +79,32 @@ def batch_add_nodes(graph_obj, mapper) -> Generator: return generator_inst +def _supply_graph_info(node, external_inputs): + """ + Supply IR graph node info into metadata. + + Args: + node (GraphNode): Graph node instance. + external_inputs (list[str]): External inputs in ONNX ir. + + Returns: + dict, metadata. + """ + precursors = _combine_external_inputs_with_precursor_nodes(node, external_inputs) + return { + ExchangeMessageKeywords.MetadataScope.value.SOURCE.value: node.ir_node_name, + ExchangeMessageKeywords.MetadataScope.value.OPERATION.value: node.ir_node_operation, + ExchangeMessageKeywords.MetadataScope.value.SCOPE.value: node.scope_name, + ExchangeMessageKeywords.MetadataScope.value.INPUTS.value: node.ir_node_inputs, + ExchangeMessageKeywords.MetadataScope.value.INPUTS_SHAPE.value: node.input_shape, + ExchangeMessageKeywords.MetadataScope.value.OUTPUTS.value: node.ir_node_outputs, + ExchangeMessageKeywords.MetadataScope.value.OUTPUTS_SHAPE.value: node.output_shape, + ExchangeMessageKeywords.MetadataScope.value.PRECURSOR.value: precursors, + ExchangeMessageKeywords.MetadataScope.value.SUCCESSOR.value: node.ir_node_successor, + ExchangeMessageKeywords.MetadataScope.value.ATTRS.value: node.node_params, + } + + def _convert_params(node, mapper): """ Call mapper to convert node's params from ONNX to MindSpore. @@ -88,10 +114,8 @@ def _convert_params(node, mapper): mapper (Mapper): The mapper instance which indicating conversion method. Returns: - str, op name in MindSpore - dict, MindSpore parameters - dict, MindSpore settings - dict, weights of the node + tuple[str, dict, dict, dict], op name in MindSpore, MindSpore parameters, + MindSpore settings and weights of the node. """ params = copy.deepcopy(node.node_params) params.update({"input_shape": node.input_shape, @@ -109,3 +133,24 @@ def _convert_params(node, mapper): return op_in_ms, ms_params, ms_settings, weights return node.op_name, node.node_params, dict(), dict() + + +def _combine_external_inputs_with_precursor_nodes(node, external_inputs): + """ + User_provided_input_nodes. + + Args: + node (OnnxGraphNode): Node instance. + external_inputs (list[str]): Inputs in onnx ir. + + Returns: + list[str], precursor nodes list. + """ + inputs = set(node.ir_node_inputs) + to_be_added = list(inputs & set(external_inputs)) + precursor = node.ir_node_precursor + # Add external inputs to precursor as the order of its inputs. + for item in to_be_added: + node_idx = node.ir_node_inputs.index(item) + precursor.insert(node_idx, item) + return precursor diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/base.py b/mindinsight/mindconverter/graph_based_converter/mapper/base.py index e576e2e9..8bd6c96b 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/base.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/base.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ import json import os from typing import Dict from mindinsight.mindconverter.common.log import logger as log +from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords CONFIG_JSON = "onnx_to_ms.json" OPERATION_TABLE = os.path.join( @@ -36,6 +37,7 @@ GET_OP_NAME = "_operation_name_in_ms" GET_OP_PARAMS = "_convert_params" GET_OP_WEIGHTS = "_convert_trained_weights" GET_OP_SETTINGS = "_convert_settings" +GET_OP_TEMPLATE = "_generate_snippet_template" class Mapper(metaclass=abc.ABCMeta): @@ -44,7 +46,7 @@ class Mapper(metaclass=abc.ABCMeta): @staticmethod @abc.abstractmethod def _operation_name_in_ms(*args, **kwargs): - """Corresponding operation name in mindspore.""" + """Corresponding operation name in MindSpore.""" @staticmethod @abc.abstractmethod @@ -66,6 +68,11 @@ class Mapper(metaclass=abc.ABCMeta): def convert(cls, op_name: str, params: Dict, weights: Dict = None): """Convert third party operation's param into MindSpore operation.""" + @staticmethod + @abc.abstractmethod + def _generate_snippet_template(**kwargs): + """Generate code template according to node info.""" + class ONNXToMindSporeMapper(Mapper, abc.ABC): """ONNX operation to MindSpore.""" @@ -131,3 +138,36 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC): @staticmethod def _convert_settings(**kwargs): raise NotImplementedError + + @staticmethod + def _generate_snippet_template(**kwargs): + op = kwargs.get("operation") + args = kwargs.get("converted_params") + weights = kwargs.get("weights") + if not op: + raise ValueError("Can not get MindSpore operation name.") + variable_slot = "var_0" + init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})" + construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ + f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}})" + template = { + variable_slot: { + TemplateKeywords.INIT.value: [init_template], + TemplateKeywords.CONSTRUCT.value: [construct_template] + } + } + exchange_msg = { + variable_slot: { + ExchangeMessageKeywords.VariableScope.value.OPERATION.value: op, + ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value: None, + ExchangeMessageKeywords.VariableScope.value.OUTPUT_TYPE.value: + ExchangeMessageKeywords.VariableScope.value.TSR_TYPE.value, + ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [], + ExchangeMessageKeywords.VariableScope.value.ARGS.value: args, + ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights, + ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: {} + } + } + outputs_list = [f"opt_{{{variable_slot}}}"] + outputs_mapping = ((0, 0),) + return template, exchange_msg, outputs_list, outputs_mapping diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/batch_norm_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/batch_norm_mapper.py index 5d02b9b2..e7c51548 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/batch_norm_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/batch_norm_mapper.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,8 +13,8 @@ # limitations under the License. # ============================================================================== """Mapper module.""" -from ...base import ONNXToMindSporeMapper -from ...gen_setting import Setting +from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper +from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting class BatchNormMapper(ONNXToMindSporeMapper): diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py index 440393a1..1cb09a61 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,9 +14,9 @@ # ============================================================================== """Mapper module.""" import numpy as np -from ...base import ONNXToMindSporeMapper -from ...gen_setting import Setting -from ....common import utils +from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting +from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper +from mindinsight.mindconverter.graph_based_converter.common.utils import convert_bytes_string_to_string def _convert_padding(**kwargs): @@ -83,7 +83,7 @@ class ConvMapper(ONNXToMindSporeMapper): auto_pad = None if params.get("auto_pad") is not None: - auto_pad = utils.convert_bytes_string_to_string(params.get("auto_pad")) + auto_pad = convert_bytes_string_to_string(params.get("auto_pad")) # tmp tf translated ver. mapping if isinstance(params.get('dilations'), list): diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/dense_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/dense_mapper.py index 2f4eb387..b11240d5 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/dense_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/dense_mapper.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,8 +13,8 @@ # limitations under the License. # ============================================================================== """Mapper module.""" -from ...base import ONNXToMindSporeMapper -from ...gen_setting import Setting +from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper +from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting class DenseMapper(ONNXToMindSporeMapper): diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/flatten_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/flatten_mapper.py index 024cf499..2df0610e 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/flatten_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/flatten_mapper.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,8 +13,8 @@ # limitations under the License. # ============================================================================== """Mapper module.""" -from ...base import ONNXToMindSporeMapper -from ...gen_setting import Setting +from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper +from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting class FlattenMapper(ONNXToMindSporeMapper): diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/global_pool_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/global_pool_mapper.py index 29bc8550..4a6196dc 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/global_pool_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/global_pool_mapper.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,8 +13,8 @@ # limitations under the License. # ============================================================================== """Mapper module.""" -from ...base import ONNXToMindSporeMapper -from ...gen_setting import Setting +from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper +from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting class GlobalPoolMapper(ONNXToMindSporeMapper): diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/mat_mul_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/mat_mul_mapper.py index 603a4fd8..ce9abe0f 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/mat_mul_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/mat_mul_mapper.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,8 +13,10 @@ # limitations under the License. # ============================================================================== """Mapper module.""" -from ...base import ONNXToMindSporeMapper -from ...gen_setting import Setting, Tensor, get_dtype +from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper +from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting, Tensor, get_dtype +from mindinsight.mindconverter.graph_based_converter.common.utils import reset_init_or_construct +from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords class MatMulMapper(ONNXToMindSporeMapper): @@ -43,3 +45,30 @@ class MatMulMapper(ONNXToMindSporeMapper): ref = t_name return Setting(op_extra_tensor=Tensor(shape=tensor.shape, dtype=get_dtype(tensor), reference=ref)) + + @staticmethod + def _generate_snippet_template(**kwargs): + template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template( + **kwargs) + op = kwargs.get("operation") + args = kwargs.get("converted_params") + weights = kwargs.get("weights") + if not weights: + return template, exchange_msg, outputs_list, outputs_mapping + + weight = list(weights.items())[0] + _, tensor = weight + + variable_slot = "var_0" + init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})" + init_tensor = f"self.{{{variable_slot}}}_w = " \ + f"Tensor(np.random.uniform(0, 1, {tensor.shape}).astype(np.{tensor.dtype}))" + construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ + f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}," \ + f"self.{{{variable_slot}}}_w)" + template = reset_init_or_construct(template, variable_slot, [init_template, init_tensor], + TemplateKeywords.INIT.value) + template = reset_init_or_construct(template, variable_slot, [construct_template], + TemplateKeywords.CONSTRUCT.value) + + return template, exchange_msg, outputs_list, outputs_mapping diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pad_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pad_mapper.py index bb1a2fb9..d99fcff5 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pad_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pad_mapper.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,8 +13,8 @@ # limitations under the License. # ============================================================================== """Mapper module.""" -from ...base import ONNXToMindSporeMapper -from ...gen_setting import Setting +from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper +from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting def _padding_format_convert(padding: list): diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pool_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pool_mapper.py index b33ed715..f3baa01d 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pool_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pool_mapper.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,8 +13,8 @@ # limitations under the License. # ============================================================================== """Mapper module.""" -from ...base import ONNXToMindSporeMapper -from ...gen_setting import Setting +from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper +from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting class PoolMapper(ONNXToMindSporeMapper): diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/relu_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/relu_mapper.py index e89052ca..3ed70996 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/relu_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/relu_mapper.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,8 +13,8 @@ # limitations under the License. # ============================================================================== """Mapper module.""" -from ...base import ONNXToMindSporeMapper -from ...gen_setting import Setting +from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper +from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting class ReLUMapper(ONNXToMindSporeMapper): diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/sigmoid_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/sigmoid_mapper.py index 6113e27d..e2f769a1 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/sigmoid_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/sigmoid_mapper.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,8 +13,8 @@ # limitations under the License. # ============================================================================== """Mapper module.""" -from ...base import ONNXToMindSporeMapper -from ...gen_setting import Setting +from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper +from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting class SigmoidMapper(ONNXToMindSporeMapper): diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/softmax_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/softmax_mapper.py index be029109..949a6e85 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/softmax_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/softmax_mapper.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,8 +13,8 @@ # limitations under the License. # ============================================================================== """Mapper module.""" -from ...base import ONNXToMindSporeMapper -from ...gen_setting import Setting +from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper +from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting class SoftmaxMapper(ONNXToMindSporeMapper): diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/add_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/add_mapper.py index 83808984..b3600b4a 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/add_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/add_mapper.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,8 +13,10 @@ # limitations under the License. # ============================================================================== """Mapper module.""" -from ...base import ONNXToMindSporeMapper -from ...gen_setting import Setting, Tensor, get_dtype +from mindinsight.mindconverter.graph_based_converter.common.utils import reset_init_or_construct +from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords +from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper +from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting, Tensor, get_dtype class AddMapper(ONNXToMindSporeMapper): @@ -43,3 +45,30 @@ class AddMapper(ONNXToMindSporeMapper): ref = t_name return Setting(op_extra_tensor=Tensor(shape=tensor.shape, dtype=get_dtype(tensor), reference=ref)) + + @staticmethod + def _generate_snippet_template(**kwargs): + template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template( + **kwargs) + op = kwargs.get("operation") + args = kwargs.get("converted_params") + weights = kwargs.get("weights") + if not weights: + return template, exchange_msg, outputs_list, outputs_mapping + + bias = list(weights.items())[0] + _, tensor = bias + + variable_slot = "var_0" + init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})" + init_tensor = f"self.{{{variable_slot}}}_bias = " \ + f"Tensor(np.random.uniform(0, 1, {tensor.shape}).astype(np.{tensor.dtype}))" + construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ + f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}," \ + f"self.{{{variable_slot}}}_bias)" + template = reset_init_or_construct(template, variable_slot, [init_template, init_tensor], + TemplateKeywords.INIT.value) + template = reset_init_or_construct(template, variable_slot, [construct_template], + TemplateKeywords.CONSTRUCT.value) + + return template, exchange_msg, outputs_list, outputs_mapping diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/concat_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/concat_mapper.py index eb1205f9..a22547d8 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/concat_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/concat_mapper.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,8 +14,10 @@ # ============================================================================== """Mapper module.""" from mindinsight.mindconverter.graph_based_converter.constant import InputType -from ...base import ONNXToMindSporeMapper -from ...gen_setting import Setting +from mindinsight.mindconverter.graph_based_converter.common.utils import reset_init_or_construct +from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords +from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper +from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting class ConcatMapper(ONNXToMindSporeMapper): @@ -38,3 +40,14 @@ class ConcatMapper(ONNXToMindSporeMapper): def _convert_settings(**kwargs): input_type = InputType.LIST.value return Setting(op_ipt_type=input_type) + + @staticmethod + def _generate_snippet_template(**kwargs): + template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template( + **kwargs) + variable_slot = "var_0" + construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ + f"(({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}},))" + template = reset_init_or_construct(template, variable_slot, [construct_template], + TemplateKeywords.CONSTRUCT.value) + return template, exchange_msg, outputs_list, outputs_mapping diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/mul_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/mul_mapper.py index 0501b714..5b3c5d1b 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/mul_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/mul_mapper.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,8 +13,10 @@ # limitations under the License. # ============================================================================== """Mapper module.""" -from ...base import ONNXToMindSporeMapper -from ...gen_setting import Setting, Tensor, get_dtype +from mindinsight.mindconverter.graph_based_converter.common.utils import reset_init_or_construct +from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords +from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper +from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting, Tensor, get_dtype class MulMapper(ONNXToMindSporeMapper): @@ -40,3 +42,29 @@ class MulMapper(ONNXToMindSporeMapper): ref, tensor = list(weights.items())[0] return Setting(op_extra_tensor=Tensor(shape=tensor.shape, dtype=get_dtype(tensor), reference=ref)) + + @staticmethod + def _generate_snippet_template(**kwargs): + template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template( + **kwargs) + op = kwargs.get("operation") + args = kwargs.get("converted_params") + weights = kwargs.get("weights") + if not weights: + return template, exchange_msg, outputs_list, outputs_mapping + + weight = list(weights.items())[0] + _, tensor = weight + + variable_slot = "var_0" + init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})" + init_tensor = f"self.{{{variable_slot}}}_w = Tensor(np.random.uniform(0, 1, {tensor.shape})" \ + f".astype(np.{tensor.dtype}))" + construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ + f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}," \ + f"self.{{{variable_slot}}}_w)" + template = reset_init_or_construct(template, variable_slot, [init_template, init_tensor], + TemplateKeywords.INIT.value) + template = reset_init_or_construct(template, variable_slot, [construct_template], + TemplateKeywords.CONSTRUCT.value) + return template, exchange_msg, outputs_list, outputs_mapping diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/reduce_mean_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/reduce_mean_mapper.py index 239d07a6..edbbd50a 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/reduce_mean_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/reduce_mean_mapper.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,8 +13,10 @@ # limitations under the License. # ============================================================================== """Mapper module.""" -from ...base import ONNXToMindSporeMapper -from ...gen_setting import Setting +from mindinsight.mindconverter.graph_based_converter.common.utils import reset_init_or_construct +from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords +from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper +from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting class ReduceMeanMapper(ONNXToMindSporeMapper): @@ -42,3 +44,20 @@ class ReduceMeanMapper(ONNXToMindSporeMapper): else: axis = tuple() return Setting(op_extra_input={'axis': axis}) + + @staticmethod + def _generate_snippet_template(**kwargs): + template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template( + **kwargs) + raw_params = kwargs.get("raw_params") + if raw_params.get('axes'): + axis = raw_params['axes'][0] if len(raw_params['axes']) == 1 else tuple(raw_params['axes']) + else: + axis = tuple() + variable_slot = "var_0" + construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ + f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}, {axis})" + template = reset_init_or_construct(template, variable_slot, [construct_template], + TemplateKeywords.CONSTRUCT.value) + + return template, exchange_msg, outputs_list, outputs_mapping diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/reshape_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/reshape_mapper.py index 96daf770..30d0e555 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/reshape_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/reshape_mapper.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,8 +13,10 @@ # limitations under the License. # ============================================================================== """Mapper module.""" -from ...base import ONNXToMindSporeMapper -from ...gen_setting import Setting +from mindinsight.mindconverter.graph_based_converter.common.utils import reset_init_or_construct +from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords +from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper +from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting class ReshapeMapper(ONNXToMindSporeMapper): @@ -52,3 +54,20 @@ class ReshapeMapper(ONNXToMindSporeMapper): shape = [-1] shape += list(weights.values())[0][1:].tolist() return Setting(op_extra_input={"shape": tuple(shape)}) + + @staticmethod + def _generate_snippet_template(**kwargs): + template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template( + **kwargs) + weights = kwargs.get("weights") + if len(weights) > 1: + raise ValueError("For reshape, `weights` length should equal to 1.") + shape = [-1] + shape += list(weights.values())[0][1:].tolist() + variable_slot = "var_0" + construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ + f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}, {tuple(shape)})" + template = reset_init_or_construct(template, variable_slot, [construct_template], + TemplateKeywords.CONSTRUCT.value) + + return template, exchange_msg, outputs_list, outputs_mapping diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/resize_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/resize_mapper.py index 8b587029..9aa7a3cf 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/resize_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/resize_mapper.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,9 +13,9 @@ # limitations under the License. # ============================================================================== """Mapper module.""" -from ...base import ONNXToMindSporeMapper -from ...gen_setting import Setting -from ....common import utils +from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper +from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting +from mindinsight.mindconverter.graph_based_converter.common.utils import convert_bytes_string_to_string class ResizeMapper(ONNXToMindSporeMapper): @@ -26,11 +26,11 @@ class ResizeMapper(ONNXToMindSporeMapper): params = kwargs.get("params") onnx_coordinate_transform = params.get("coordinate_transformation_mode") if onnx_coordinate_transform is not None: - onnx_coordinate_transform = utils.convert_bytes_string_to_string(onnx_coordinate_transform) + onnx_coordinate_transform = convert_bytes_string_to_string(onnx_coordinate_transform) interpolation_mode = params.get("mode") if interpolation_mode is not None: - interpolation_mode = utils.convert_bytes_string_to_string(interpolation_mode) + interpolation_mode = convert_bytes_string_to_string(interpolation_mode) # Define which MindSpore Resize operator to be used if interpolation_mode == "linear": @@ -54,7 +54,7 @@ class ResizeMapper(ONNXToMindSporeMapper): onnx_coordinate_transform = params.get("coordinate_transformation_mode") if onnx_coordinate_transform is not None: - onnx_coordinate_transform = utils.convert_bytes_string_to_string(onnx_coordinate_transform) + onnx_coordinate_transform = convert_bytes_string_to_string(onnx_coordinate_transform) if onnx_coordinate_transform == "align_corners" or "half_pixel" in onnx_coordinate_transform: align_corners = True diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/slice_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/slice_mapper.py index 4fa870e6..24b27943 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/slice_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/slice_mapper.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,8 +13,10 @@ # limitations under the License. # ============================================================================== """Mapper module.""" -from ...base import ONNXToMindSporeMapper -from ...gen_setting import Setting +from mindinsight.mindconverter.graph_based_converter.common.utils import reset_init_or_construct +from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords +from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper +from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting class SliceMapper(ONNXToMindSporeMapper): @@ -41,3 +43,21 @@ class SliceMapper(ONNXToMindSporeMapper): starts = sorted(zip(weights[0].tolist(), weights[2].tolist()), key=lambda x: x[1], reverse=False) return Setting(op_extra_input={"begin": tuple([i[0] for i in starts]), "size": tuple(opt_shape)}) + + @staticmethod + def _generate_snippet_template(**kwargs): + template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template( + **kwargs) + weights = list(kwargs.get("weights").values()) # start, end, axis + opt_shape = kwargs["raw_params"]["output_shape"] + if not weights: + raise ValueError("Cannot get required params from slice.") + starts = sorted(zip(weights[0].tolist(), weights[2].tolist()), key=lambda x: x[1], reverse=False) + variable_slot = "var_0" + construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ + f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}, " \ + f"{tuple([i[0] for i in starts])}, {tuple(opt_shape)})" + template = reset_init_or_construct(template, variable_slot, [construct_template], + TemplateKeywords.CONSTRUCT.value) + + return template, exchange_msg, outputs_list, outputs_mapping diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/split_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/split_mapper.py index bec85ff4..40b710e4 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/split_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/split_mapper.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,8 +13,8 @@ # limitations under the License. # ============================================================================== """Mapper module.""" -from ...base import ONNXToMindSporeMapper -from ...gen_setting import Setting +from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper +from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting class SplitMapper(ONNXToMindSporeMapper): diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/transpose_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/transpose_mapper.py index d294d9d1..5b0999fa 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/transpose_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/transpose_mapper.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,8 +13,10 @@ # limitations under the License. # ============================================================================== """Mapper module.""" -from ...base import ONNXToMindSporeMapper -from ...gen_setting import Setting +from mindinsight.mindconverter.graph_based_converter.common.utils import reset_init_or_construct +from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords +from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper +from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting class TransposeMapper(ONNXToMindSporeMapper): @@ -42,3 +44,17 @@ class TransposeMapper(ONNXToMindSporeMapper): converted_params['input_perm'] = perm return Setting(op_extra_input=converted_params) + + @staticmethod + def _generate_snippet_template(**kwargs): + template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template( + **kwargs) + raw_params = kwargs.get("raw_params") + perm = raw_params["perm"] + variable_slot = "var_0" + construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ + f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}, {tuple(perm)})" + template = reset_init_or_construct(template, variable_slot, [construct_template], + TemplateKeywords.CONSTRUCT.value) + + return template, exchange_msg, outputs_list, outputs_mapping diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py index e29d7960..af6e5c69 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,11 +17,13 @@ import abc from collections import OrderedDict from copy import deepcopy +from typing import List + from mindinsight.mindconverter.common.log import logger as log -from ..common.code_fragment import CodeFragment -from ..constant import NodeType, InputType -from ..mapper.base import Mapper -from ...common.exceptions import NodeInputTypeNotSupportError +from mindinsight.mindconverter.graph_based_converter.common.code_fragment import CodeFragment +from mindinsight.mindconverter.graph_based_converter.constant import NodeType, InputType +from mindinsight.mindconverter.graph_based_converter.mapper.base import Mapper +from mindinsight.mindconverter.common.exceptions import NodeInputTypeNotSupportError class GraphParser(metaclass=abc.ABCMeta): @@ -97,7 +99,6 @@ class Graph(BaseGraph, abc.ABC): self.model = model self._raw_input_nodes = kwargs.get("input_nodes") self._raw_output_nodes = kwargs.get("output_nodes") - self.checkpoint = kwargs.get("checkpoint", None) self._nodes_collection = OrderedDict() self._nodes_record = dict() self._shape_dict = dict() @@ -107,6 +108,13 @@ class Graph(BaseGraph, abc.ABC): self._input_shape = dict() self._is_multi_opt_graph = False + @property + def user_provided_input_nodes(self) -> List[str]: + """User provided input_nodes in CLI.""" + if not isinstance(self._raw_input_nodes, list): + return [self._raw_input_nodes] + return self._raw_input_nodes + def get_input_shape(self, name): """ Get node input shape. @@ -285,9 +293,9 @@ class GraphNode(abc.ABC): self.successor_nodes = [] # Control dependency. self._deleted_in_edge = 0 - # Source node in pytorch. - self._src_node = str(node) if node else None - # Original operation name in pytorch. + # Source node in ONNX. + self._src_node = node if node else None + # Original operation name in ONNX. self._op_name = None self._op_params = dict() self._scope_name = None @@ -311,6 +319,40 @@ class GraphNode(abc.ABC): # Is in multi output graph. self._is_in_multi_opt_graph = False + @property + def ir_node_name(self): + """Getter of ir node's name.""" + return self._src_node.name + + @property + def ir_node_operation(self): + """Getter of ir node's operation.""" + return self._src_node.op_type + + @property + def ir_node_inputs(self): + """Getter of ir node's inputs.""" + return list(self._src_node.input_name_list) + + @property + def ir_node_outputs(self): + """Getter of ir node's outputs.""" + return list(self._src_node.output_name_list) + + @property + def ir_node_precursor(self): + """Getter of ir node's precursor.""" + return [ + v.name for _, v in self._src_node.precursor_onnx_node_dict.items() + ] + + @property + def ir_node_successor(self): + """Getter of ir node's successor.""" + return [ + v.name for _, v in self._src_node.successor_onnx_node_dict.items() + ] + @property def weight(self): return self._weight diff --git a/tests/ut/mindconverter/graph_based_converter/common/__init__.py b/tests/ut/mindconverter/graph_based_converter/common/__init__.py new file mode 100644 index 00000000..15a14248 --- /dev/null +++ b/tests/ut/mindconverter/graph_based_converter/common/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Unit test for mindconverter.graph_based_converter.common interface.""" diff --git a/tests/ut/mindconverter/graph_based_converter/common/test_fragment.py b/tests/ut/mindconverter/graph_based_converter/common/test_fragment.py new file mode 100644 index 00000000..de020531 --- /dev/null +++ b/tests/ut/mindconverter/graph_based_converter/common/test_fragment.py @@ -0,0 +1,145 @@ +# Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test fragment.""" +from unittest import TestCase +from mindinsight.mindconverter.graph_based_converter.common.code_fragment import NewFragment as Fragment + + +class TestFragment(TestCase): + """Tester of fragment.""" + + def test_matmul(self): + """Test matmul like operation's template.""" + template = { + 'var_0': { + 'init': [ + 'self.{var_0} = nn.MatMul()', + 'self.{var_0}_w = Tensor(np.random.rand(*(2048, 1000)).astype(np.float32))' + ], + 'construct': ['opt_{var_0} = self.{var_0}({inputs},self.{var_0}_w)'] + } + } + rewrite_data = { + 'var_0': { + 'operation': 'nn.MatMul', + 'output_type': 'tensor', + 'variable_name': "matmul", 'inputs': ["x"], 'args': {}, + 'weights': {}, + 'trainable_params': {} + }, + 'metadata': { + 'source': 'probs/MatMul', 'operation': 'MatMul', 'scope': 'Model/MatMul', + 'inputs': ['avg_pool/Mean:0', 'probs/MatMul/ReadVariableOp:0'], + 'inputs_shape': (1, 2048), 'outputs': ['probs/MatMul:0'], 'outputs_shape': [1, 1000], + 'precursor_nodes': ['avg_pool/Mean'], 'successor_nodes': ['probs/BiasAdd'], + 'attributes': {} + } + } + fragment = Fragment(data_entity=rewrite_data, code_template=template, outputs=["opt_{var_0}"], + outputs_mapping=((0, 0),)) + code = fragment() + init = code[0] + construct = code[1] + self.assertEqual(init, ['self.matmul = nn.MatMul()', + 'self.matmul_w = Tensor(np.random.rand(*(2048, 1000)).astype(np.float32))']) + self.assertEqual(construct, ['opt_matmul = self.matmul(x,self.matmul_w)']) + self.assertEqual(fragment.get_outputs_by_idx(0), "opt_matmul") + + def test_biasadd(self): + """Test biasadd like operation's template.""" + template = { + 'var_0': { + 'init': [ + 'self.{var_0} = P.TensorAdd()', + 'self.{var_0}_bias = Tensor(np.random.rand(*(1000,)).astype(np.float32))' + ], + 'construct': ['opt_{var_0} = self.{var_0}({inputs},self.{var_0}_bias)'] + } + } + rewrite_data = { + 'var_0': { + 'operation': 'P.TensorAdd', + 'output_type': 'tensor', + 'variable_name': "add", 'inputs': ["x"], 'args': {}, 'weights': {}, + 'trainable_params': {} + }, + 'metadata': { + 'source': 'probs/BiasAdd', 'operation': 'Add', 'scope': 'Model/Add', + 'inputs': ['probs/MatMul:0', 'probs/BiasAdd/ReadVariableOp:0'], 'inputs_shape': (1, 1000), + 'outputs': ['probs/BiasAdd:0'], 'outputs_shape': [1, 1000], 'precursor_nodes': ['probs/MatMul'], + 'successor_nodes': ['probs/Softmax'], 'attributes': {} + } + } + fragment = Fragment(data_entity=rewrite_data, code_template=template, outputs=["opt_{var_0}"], + outputs_mapping=((0, 0),)) + code = fragment() + init = code[0] + construct = code[1] + self.assertEqual(init, ['self.add = P.TensorAdd()', + 'self.add_bias = Tensor(np.random.rand(*(1000,)).astype(np.float32))']) + self.assertEqual(construct, ['opt_add = self.add(x,self.add_bias)']) + self.assertEqual(fragment.get_outputs_by_idx(0), "opt_add") + + def test_transpose(self): + """Test transpose like operation's template.""" + template = { + 'var_0': { + 'init': ['self.{var_0} = P.Transpose()'], + 'construct': ['opt_{var_0} = self.{var_0}({inputs}, (0, 3, 1, 2))'] + } + } + rewrite_data = { + 'var_0': { + 'operation': 'P.Transpose', + 'output_type': 'tensor', + 'variable_name': "transpose", 'inputs': ["x"], 'args': {}, 'weights': {}, + 'trainable_params': {} + } + } + fragment = Fragment(data_entity=rewrite_data, code_template=template, outputs=["opt_{var_0}"], + outputs_mapping=((0, 0),)) + code = fragment() + init = code[0] + construct = code[1] + self.assertEqual(init, ['self.transpose = P.Transpose()']) + self.assertEqual(construct, ['opt_transpose = self.transpose(x, (0, 3, 1, 2))']) + self.assertEqual(fragment.get_outputs_by_idx(0), "opt_transpose") + + def test_split(self): + """Test split like operation's template.""" + template = { + 'var_0': { + 'init': ['self.{var_0} = P.Split(axis={axis}, output_num={output_num})'], + 'construct': ['opt_{var_0} = self.{var_0}({inputs})'] + } + } + rewrite_data = { + 'var_0': { + 'operation': 'P.Split', + 'variable_name': "split", + 'output_type': 'array', + 'inputs': ["x"], + 'args': {"axis": 1, "output_num": 2}, 'weights': {}, + 'trainable_params': {} + } + } + fragment = Fragment(data_entity=rewrite_data, code_template=template, outputs=["opt_{var_0}"], + outputs_mapping=((0, 0),)) + code = fragment() + init = code[0] + construct = code[1] + self.assertEqual(init, ['self.split = P.Split(axis=1, output_num=2)']) + self.assertEqual(construct, ['opt_split = self.split(x)']) + self.assertEqual(fragment.get_outputs_by_idx(0, 1), 'opt_split[1]')