| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -14,26 +14,10 @@ | |||||
| # ============================================================================== | # ============================================================================== | ||||
| """Define CodeLine object.""" | """Define CodeLine object.""" | ||||
| import abc | import abc | ||||
| import re | |||||
| from typing import List, Tuple | |||||
| class TrainableParams: | |||||
| """Trainable parameters.""" | |||||
| def __init__(self, shape, dtype, reference): | |||||
| self.param_name = None | |||||
| self.shape = shape | |||||
| self.dtype = dtype | |||||
| self.reference = reference # Weight name in global npy. | |||||
| class CodeSetting: | |||||
| """Code generation settings.""" | |||||
| def __init__(self): | |||||
| self.output_vars_suffix = [] | |||||
| self.operation_input_type = None # Construct input type, tensor or list. | |||||
| self.operation_extra_input = dict() # `values` in original setting dict. | |||||
| self.operation_extra_tensor = None # For `MatMul`, `BiasAdd` op, need a tensor | |||||
| from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords | |||||
| class Fragment(abc.ABC): | class Fragment(abc.ABC): | ||||
| @@ -222,3 +206,144 @@ class ModuleFragment(Fragment): | |||||
| super(ModuleFragment, self).__init__(operation=operation, actual_args=actual_args, | super(ModuleFragment, self).__init__(operation=operation, actual_args=actual_args, | ||||
| input_shape=input_shape, output_shape=output_shape, | input_shape=input_shape, output_shape=output_shape, | ||||
| settings=settings) | settings=settings) | ||||
| class NewFragment: | |||||
| """ | |||||
| Fragment definition for MindSpore code generation. | |||||
| Args: | |||||
| data_entity (dict): Required data by operations. The format of `data_entity` is as follow: | |||||
| { | |||||
| "var1": { | |||||
| "metadata": { # ONNX Metadata | |||||
| "operation": "Conv2d", | |||||
| "source": "conv_pw_13/Conv2D", | |||||
| "attributes": { | |||||
| # Put original onnx attributes here. | |||||
| } | |||||
| }, | |||||
| "variable_name": None, | |||||
| "inputs": [], | |||||
| "output_type": "tensor" | "array", | |||||
| "args": {"in_channels": 768, "out_channels": 1024}, | |||||
| "trainable_params": {"weight": "Parameter(Tensor(GLOBAL_W[NAME]))"} | |||||
| }, | |||||
| "var2": { | |||||
| "variable_name": "pad", | |||||
| "args": {"padding": [0, 1, 1, 0], "mode": "SAME"} | |||||
| } | |||||
| } | |||||
| code_template (dict): Code template generated by mapper. The format of `code_template` is as follow: | |||||
| { | |||||
| "var1": { | |||||
| "init": [ | |||||
| "self.{var1} = nn.Conv2d(in_channels={in_channels})", | |||||
| "self.{var1}.weight = {weight}" | |||||
| ], | |||||
| "construct": [ | |||||
| "opt_{var1} = self.{var1}({inputs}[, extra])" | |||||
| ] | |||||
| }, | |||||
| "var2": { | |||||
| "init": [ | |||||
| "self.{var2} = nn.Pad(padding={padding}, mode={mode})" | |||||
| ], | |||||
| "construct": [ | |||||
| "opt_{var2} = self.{var2}(opt_{var1}[, extra])" | |||||
| ] | |||||
| } | |||||
| } | |||||
| outputs (list[str]): Outputs name slot list. | |||||
| outputs_mapping (tuple): Outputs index mapping between ir node and MindSpore operation. | |||||
| """ | |||||
| def __init__(self, data_entity: dict, code_template: dict, outputs: List[str], outputs_mapping): | |||||
| self.exchange_msg = data_entity | |||||
| self._code_template = code_template | |||||
| self.inputs = [] | |||||
| self._outputs = outputs | |||||
| self.outputs_mapping = outputs_mapping | |||||
| self.format_args = dict() | |||||
| def _get_outputs(self): | |||||
| """ | |||||
| Get outputs of the code snippet. | |||||
| Returns: | |||||
| list[str], outputs of current code block. | |||||
| """ | |||||
| outputs = [] | |||||
| variables = { | |||||
| k: self.exchange_msg[k][ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value] | |||||
| for k in self.exchange_msg if k != ExchangeMessageKeywords.METADATA.value | |||||
| } | |||||
| for o in self._outputs: | |||||
| extractor = r".*\{(?P<var>.+)\}.*" | |||||
| var_def = re.match(extractor, o) | |||||
| if not var_def: | |||||
| raise ValueError(f"Output variable name {o} is illegal.") | |||||
| outputs.append( | |||||
| ( | |||||
| o.format(**variables), | |||||
| self.exchange_msg[var_def.group("var")][ | |||||
| ExchangeMessageKeywords.VariableScope.value.OUTPUT_TYPE.value] | |||||
| ) | |||||
| ) | |||||
| return outputs | |||||
| def get_outputs_by_idx(self, idx, inner_idx=-1): | |||||
| """Get outputs by idx.""" | |||||
| outputs = self._get_outputs() | |||||
| opt, opt_type = outputs[idx] | |||||
| if opt_type == ExchangeMessageKeywords.VariableScope.value.ARR_TYPE.value: | |||||
| return f"{opt}[{inner_idx}]" | |||||
| return opt | |||||
| def __call__(self) -> Tuple[List[str], List[str]]: | |||||
| """ | |||||
| Define parameter rewrite function. | |||||
| Returns: | |||||
| tuple[list[str], list[str]], init statement and construct statement. | |||||
| """ | |||||
| init_stats, call_stats = [], [] | |||||
| precursor_node_var = [None, None] | |||||
| for op_var, template in self._code_template.items(): | |||||
| if ExchangeMessageKeywords.VariableScope.value.INPUTS.value not in self.exchange_msg[op_var]: | |||||
| # It's possible inputs and precursor node both exists. | |||||
| self.exchange_msg[op_var][ExchangeMessageKeywords.VariableScope.value.ARGS.value][ | |||||
| precursor_node_var[0]] = precursor_node_var[1] | |||||
| for tpl in template[TemplateKeywords.INIT.value]: | |||||
| init_stat = self._rewrite(op_var, self.exchange_msg[op_var], tpl) | |||||
| init_stats.append(init_stat) | |||||
| for tpl in template[TemplateKeywords.CONSTRUCT.value]: | |||||
| call_stat = self._rewrite(op_var, self.exchange_msg[op_var], tpl) | |||||
| call_stats.append(call_stat) | |||||
| precursor_node_var = op_var, self.exchange_msg[op_var].get( | |||||
| ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value) | |||||
| return init_stats, call_stats | |||||
| @staticmethod | |||||
| def _rewrite(var, data, template: str) -> str: | |||||
| """ | |||||
| Backfill data into code template. | |||||
| Args: | |||||
| var (str): Current operation variable name. | |||||
| data (dict): Data to be written. | |||||
| template (str): Code template. | |||||
| Returns: | |||||
| str, single code line. | |||||
| """ | |||||
| rewrite_data = {var: data[ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value]} | |||||
| if ExchangeMessageKeywords.VariableScope.value.INPUTS.value in data: | |||||
| rewrite_data[ExchangeMessageKeywords.VariableScope.value.INPUTS.value] = ", ".join( | |||||
| data[ExchangeMessageKeywords.VariableScope.value.INPUTS.value]) | |||||
| if ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value in data: | |||||
| rewrite_data.update(data[ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value]) | |||||
| rewrite_data.update(data[ExchangeMessageKeywords.VariableScope.value.ARGS.value]) | |||||
| return template.format(**{ | |||||
| k: str(rewrite_data[k]) for k in rewrite_data | |||||
| }) | |||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -209,3 +209,10 @@ def get_framework_type(model_path): | |||||
| raise error | raise error | ||||
| return framework_type | return framework_type | ||||
| def reset_init_or_construct(template, variable_slot, new_data, scope): | |||||
| """Reset init statement.""" | |||||
| template[variable_slot][scope].clear() | |||||
| template[variable_slot][scope] += new_data | |||||
| return template | |||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -42,6 +42,47 @@ ONNX_MIN_VER = "1.8.0" | |||||
| TF2ONNX_MIN_VER = "1.7.1" | TF2ONNX_MIN_VER = "1.7.1" | ||||
| ONNXRUNTIME_MIN_VER = "1.5.2" | ONNXRUNTIME_MIN_VER = "1.5.2" | ||||
| @unique | |||||
| class TemplateKeywords(Enum): | |||||
| """Define keywords in template message.""" | |||||
| INIT = "init" | |||||
| CONSTRUCT = "construct" | |||||
| @unique | |||||
| class ExchangeMessageKeywords(Enum): | |||||
| """Define keywords in exchange message.""" | |||||
| METADATA = "metadata" | |||||
| @unique | |||||
| class MetadataScope(Enum): | |||||
| """Define metadata scope keywords in exchange message.""" | |||||
| SOURCE = "source" | |||||
| OPERATION = "operation" | |||||
| INPUTS = "inputs" | |||||
| INPUTS_SHAPE = "inputs_shape" | |||||
| OUTPUTS = "outputs" | |||||
| OUTPUTS_SHAPE = "outputs_shape" | |||||
| PRECURSOR = "precursor_nodes" | |||||
| SUCCESSOR = "successor_nodes" | |||||
| ATTRS = "attributes" | |||||
| SCOPE = "scope" | |||||
| @unique | |||||
| class VariableScope(Enum): | |||||
| """Define variable scope keywords in exchange message.""" | |||||
| OPERATION = "operation" | |||||
| VARIABLE_NAME = "variable_name" | |||||
| OUTPUT_TYPE = "output_type" | |||||
| TSR_TYPE = "tensor" | |||||
| ARR_TYPE = "array" | |||||
| INPUTS = "inputs" | |||||
| ARGS = "args" | |||||
| WEIGHTS = "weights" | |||||
| TRAINABLE_PARAMS = "trainable_params" | |||||
| BINARY_HEADER_PYTORCH_FILE = \ | BINARY_HEADER_PYTORCH_FILE = \ | ||||
| b'\x80\x02\x8a\nl\xfc\x9cF\xf9 j\xa8P\x19.\x80\x02M\xe9\x03.\x80\x02}q\x00(X\x10\x00\x00\x00' | b'\x80\x02\x8a\nl\xfc\x9cF\xf9 j\xa8P\x19.\x80\x02M\xe9\x03.\x80\x02}q\x00(X\x10\x00\x00\x00' | ||||
| TENSORFLOW_MODEL_SUFFIX = "pb" | TENSORFLOW_MODEL_SUFFIX = "pb" | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -18,8 +18,9 @@ __all__ = ["batch_add_nodes"] | |||||
| import re | import re | ||||
| import copy | import copy | ||||
| from mindinsight.mindconverter.graph_based_converter.common.code_fragment import CodeFragment | |||||
| from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords | |||||
| from .generator import Generator, CodeStruct | from .generator import Generator, CodeStruct | ||||
| from ..common.code_fragment import CodeFragment | |||||
| def _tf_model_node_name_reformat(node, node_name): | def _tf_model_node_name_reformat(node, node_name): | ||||
| @@ -34,7 +35,6 @@ def _tf_model_node_name_reformat(node, node_name): | |||||
| str, re-formatted node name. | str, re-formatted node name. | ||||
| """ | """ | ||||
| scope_name = node.scope_name | scope_name = node.scope_name | ||||
| new_name = None | |||||
| regex = r"(?P<parent>.+/)(?P<op>\w+)" | regex = r"(?P<parent>.+/)(?P<op>\w+)" | ||||
| match = re.match(regex, scope_name) | match = re.match(regex, scope_name) | ||||
| parent = match.group("parent") | parent = match.group("parent") | ||||
| @@ -79,6 +79,32 @@ def batch_add_nodes(graph_obj, mapper) -> Generator: | |||||
| return generator_inst | return generator_inst | ||||
| def _supply_graph_info(node, external_inputs): | |||||
| """ | |||||
| Supply IR graph node info into metadata. | |||||
| Args: | |||||
| node (GraphNode): Graph node instance. | |||||
| external_inputs (list[str]): External inputs in ONNX ir. | |||||
| Returns: | |||||
| dict, metadata. | |||||
| """ | |||||
| precursors = _combine_external_inputs_with_precursor_nodes(node, external_inputs) | |||||
| return { | |||||
| ExchangeMessageKeywords.MetadataScope.value.SOURCE.value: node.ir_node_name, | |||||
| ExchangeMessageKeywords.MetadataScope.value.OPERATION.value: node.ir_node_operation, | |||||
| ExchangeMessageKeywords.MetadataScope.value.SCOPE.value: node.scope_name, | |||||
| ExchangeMessageKeywords.MetadataScope.value.INPUTS.value: node.ir_node_inputs, | |||||
| ExchangeMessageKeywords.MetadataScope.value.INPUTS_SHAPE.value: node.input_shape, | |||||
| ExchangeMessageKeywords.MetadataScope.value.OUTPUTS.value: node.ir_node_outputs, | |||||
| ExchangeMessageKeywords.MetadataScope.value.OUTPUTS_SHAPE.value: node.output_shape, | |||||
| ExchangeMessageKeywords.MetadataScope.value.PRECURSOR.value: precursors, | |||||
| ExchangeMessageKeywords.MetadataScope.value.SUCCESSOR.value: node.ir_node_successor, | |||||
| ExchangeMessageKeywords.MetadataScope.value.ATTRS.value: node.node_params, | |||||
| } | |||||
| def _convert_params(node, mapper): | def _convert_params(node, mapper): | ||||
| """ | """ | ||||
| Call mapper to convert node's params from ONNX to MindSpore. | Call mapper to convert node's params from ONNX to MindSpore. | ||||
| @@ -88,10 +114,8 @@ def _convert_params(node, mapper): | |||||
| mapper (Mapper): The mapper instance which indicating conversion method. | mapper (Mapper): The mapper instance which indicating conversion method. | ||||
| Returns: | Returns: | ||||
| str, op name in MindSpore | |||||
| dict, MindSpore parameters | |||||
| dict, MindSpore settings | |||||
| dict, weights of the node | |||||
| tuple[str, dict, dict, dict], op name in MindSpore, MindSpore parameters, | |||||
| MindSpore settings and weights of the node. | |||||
| """ | """ | ||||
| params = copy.deepcopy(node.node_params) | params = copy.deepcopy(node.node_params) | ||||
| params.update({"input_shape": node.input_shape, | params.update({"input_shape": node.input_shape, | ||||
| @@ -109,3 +133,24 @@ def _convert_params(node, mapper): | |||||
| return op_in_ms, ms_params, ms_settings, weights | return op_in_ms, ms_params, ms_settings, weights | ||||
| return node.op_name, node.node_params, dict(), dict() | return node.op_name, node.node_params, dict(), dict() | ||||
| def _combine_external_inputs_with_precursor_nodes(node, external_inputs): | |||||
| """ | |||||
| User_provided_input_nodes. | |||||
| Args: | |||||
| node (OnnxGraphNode): Node instance. | |||||
| external_inputs (list[str]): Inputs in onnx ir. | |||||
| Returns: | |||||
| list[str], precursor nodes list. | |||||
| """ | |||||
| inputs = set(node.ir_node_inputs) | |||||
| to_be_added = list(inputs & set(external_inputs)) | |||||
| precursor = node.ir_node_precursor | |||||
| # Add external inputs to precursor as the order of its inputs. | |||||
| for item in to_be_added: | |||||
| node_idx = node.ir_node_inputs.index(item) | |||||
| precursor.insert(node_idx, item) | |||||
| return precursor | |||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -19,6 +19,7 @@ import json | |||||
| import os | import os | ||||
| from typing import Dict | from typing import Dict | ||||
| from mindinsight.mindconverter.common.log import logger as log | from mindinsight.mindconverter.common.log import logger as log | ||||
| from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords | |||||
| CONFIG_JSON = "onnx_to_ms.json" | CONFIG_JSON = "onnx_to_ms.json" | ||||
| OPERATION_TABLE = os.path.join( | OPERATION_TABLE = os.path.join( | ||||
| @@ -36,6 +37,7 @@ GET_OP_NAME = "_operation_name_in_ms" | |||||
| GET_OP_PARAMS = "_convert_params" | GET_OP_PARAMS = "_convert_params" | ||||
| GET_OP_WEIGHTS = "_convert_trained_weights" | GET_OP_WEIGHTS = "_convert_trained_weights" | ||||
| GET_OP_SETTINGS = "_convert_settings" | GET_OP_SETTINGS = "_convert_settings" | ||||
| GET_OP_TEMPLATE = "_generate_snippet_template" | |||||
| class Mapper(metaclass=abc.ABCMeta): | class Mapper(metaclass=abc.ABCMeta): | ||||
| @@ -44,7 +46,7 @@ class Mapper(metaclass=abc.ABCMeta): | |||||
| @staticmethod | @staticmethod | ||||
| @abc.abstractmethod | @abc.abstractmethod | ||||
| def _operation_name_in_ms(*args, **kwargs): | def _operation_name_in_ms(*args, **kwargs): | ||||
| """Corresponding operation name in mindspore.""" | |||||
| """Corresponding operation name in MindSpore.""" | |||||
| @staticmethod | @staticmethod | ||||
| @abc.abstractmethod | @abc.abstractmethod | ||||
| @@ -66,6 +68,11 @@ class Mapper(metaclass=abc.ABCMeta): | |||||
| def convert(cls, op_name: str, params: Dict, weights: Dict = None): | def convert(cls, op_name: str, params: Dict, weights: Dict = None): | ||||
| """Convert third party operation's param into MindSpore operation.""" | """Convert third party operation's param into MindSpore operation.""" | ||||
| @staticmethod | |||||
| @abc.abstractmethod | |||||
| def _generate_snippet_template(**kwargs): | |||||
| """Generate code template according to node info.""" | |||||
| class ONNXToMindSporeMapper(Mapper, abc.ABC): | class ONNXToMindSporeMapper(Mapper, abc.ABC): | ||||
| """ONNX operation to MindSpore.""" | """ONNX operation to MindSpore.""" | ||||
| @@ -131,3 +138,36 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC): | |||||
| @staticmethod | @staticmethod | ||||
| def _convert_settings(**kwargs): | def _convert_settings(**kwargs): | ||||
| raise NotImplementedError | raise NotImplementedError | ||||
| @staticmethod | |||||
| def _generate_snippet_template(**kwargs): | |||||
| op = kwargs.get("operation") | |||||
| args = kwargs.get("converted_params") | |||||
| weights = kwargs.get("weights") | |||||
| if not op: | |||||
| raise ValueError("Can not get MindSpore operation name.") | |||||
| variable_slot = "var_0" | |||||
| init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})" | |||||
| construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ | |||||
| f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}})" | |||||
| template = { | |||||
| variable_slot: { | |||||
| TemplateKeywords.INIT.value: [init_template], | |||||
| TemplateKeywords.CONSTRUCT.value: [construct_template] | |||||
| } | |||||
| } | |||||
| exchange_msg = { | |||||
| variable_slot: { | |||||
| ExchangeMessageKeywords.VariableScope.value.OPERATION.value: op, | |||||
| ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value: None, | |||||
| ExchangeMessageKeywords.VariableScope.value.OUTPUT_TYPE.value: | |||||
| ExchangeMessageKeywords.VariableScope.value.TSR_TYPE.value, | |||||
| ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [], | |||||
| ExchangeMessageKeywords.VariableScope.value.ARGS.value: args, | |||||
| ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights, | |||||
| ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: {} | |||||
| } | |||||
| } | |||||
| outputs_list = [f"opt_{{{variable_slot}}}"] | |||||
| outputs_mapping = ((0, 0),) | |||||
| return template, exchange_msg, outputs_list, outputs_mapping | |||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -13,8 +13,8 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================== | # ============================================================================== | ||||
| """Mapper module.""" | """Mapper module.""" | ||||
| from ...base import ONNXToMindSporeMapper | |||||
| from ...gen_setting import Setting | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | |||||
| class BatchNormMapper(ONNXToMindSporeMapper): | class BatchNormMapper(ONNXToMindSporeMapper): | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -14,9 +14,9 @@ | |||||
| # ============================================================================== | # ============================================================================== | ||||
| """Mapper module.""" | """Mapper module.""" | ||||
| import numpy as np | import numpy as np | ||||
| from ...base import ONNXToMindSporeMapper | |||||
| from ...gen_setting import Setting | |||||
| from ....common import utils | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||||
| from mindinsight.mindconverter.graph_based_converter.common.utils import convert_bytes_string_to_string | |||||
| def _convert_padding(**kwargs): | def _convert_padding(**kwargs): | ||||
| @@ -83,7 +83,7 @@ class ConvMapper(ONNXToMindSporeMapper): | |||||
| auto_pad = None | auto_pad = None | ||||
| if params.get("auto_pad") is not None: | if params.get("auto_pad") is not None: | ||||
| auto_pad = utils.convert_bytes_string_to_string(params.get("auto_pad")) | |||||
| auto_pad = convert_bytes_string_to_string(params.get("auto_pad")) | |||||
| # tmp tf translated ver. mapping | # tmp tf translated ver. mapping | ||||
| if isinstance(params.get('dilations'), list): | if isinstance(params.get('dilations'), list): | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -13,8 +13,8 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================== | # ============================================================================== | ||||
| """Mapper module.""" | """Mapper module.""" | ||||
| from ...base import ONNXToMindSporeMapper | |||||
| from ...gen_setting import Setting | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | |||||
| class DenseMapper(ONNXToMindSporeMapper): | class DenseMapper(ONNXToMindSporeMapper): | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -13,8 +13,8 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================== | # ============================================================================== | ||||
| """Mapper module.""" | """Mapper module.""" | ||||
| from ...base import ONNXToMindSporeMapper | |||||
| from ...gen_setting import Setting | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | |||||
| class FlattenMapper(ONNXToMindSporeMapper): | class FlattenMapper(ONNXToMindSporeMapper): | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -13,8 +13,8 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================== | # ============================================================================== | ||||
| """Mapper module.""" | """Mapper module.""" | ||||
| from ...base import ONNXToMindSporeMapper | |||||
| from ...gen_setting import Setting | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | |||||
| class GlobalPoolMapper(ONNXToMindSporeMapper): | class GlobalPoolMapper(ONNXToMindSporeMapper): | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -13,8 +13,10 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================== | # ============================================================================== | ||||
| """Mapper module.""" | """Mapper module.""" | ||||
| from ...base import ONNXToMindSporeMapper | |||||
| from ...gen_setting import Setting, Tensor, get_dtype | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting, Tensor, get_dtype | |||||
| from mindinsight.mindconverter.graph_based_converter.common.utils import reset_init_or_construct | |||||
| from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords | |||||
| class MatMulMapper(ONNXToMindSporeMapper): | class MatMulMapper(ONNXToMindSporeMapper): | ||||
| @@ -43,3 +45,30 @@ class MatMulMapper(ONNXToMindSporeMapper): | |||||
| ref = t_name | ref = t_name | ||||
| return Setting(op_extra_tensor=Tensor(shape=tensor.shape, | return Setting(op_extra_tensor=Tensor(shape=tensor.shape, | ||||
| dtype=get_dtype(tensor), reference=ref)) | dtype=get_dtype(tensor), reference=ref)) | ||||
| @staticmethod | |||||
| def _generate_snippet_template(**kwargs): | |||||
| template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template( | |||||
| **kwargs) | |||||
| op = kwargs.get("operation") | |||||
| args = kwargs.get("converted_params") | |||||
| weights = kwargs.get("weights") | |||||
| if not weights: | |||||
| return template, exchange_msg, outputs_list, outputs_mapping | |||||
| weight = list(weights.items())[0] | |||||
| _, tensor = weight | |||||
| variable_slot = "var_0" | |||||
| init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})" | |||||
| init_tensor = f"self.{{{variable_slot}}}_w = " \ | |||||
| f"Tensor(np.random.uniform(0, 1, {tensor.shape}).astype(np.{tensor.dtype}))" | |||||
| construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ | |||||
| f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}," \ | |||||
| f"self.{{{variable_slot}}}_w)" | |||||
| template = reset_init_or_construct(template, variable_slot, [init_template, init_tensor], | |||||
| TemplateKeywords.INIT.value) | |||||
| template = reset_init_or_construct(template, variable_slot, [construct_template], | |||||
| TemplateKeywords.CONSTRUCT.value) | |||||
| return template, exchange_msg, outputs_list, outputs_mapping | |||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -13,8 +13,8 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================== | # ============================================================================== | ||||
| """Mapper module.""" | """Mapper module.""" | ||||
| from ...base import ONNXToMindSporeMapper | |||||
| from ...gen_setting import Setting | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | |||||
| def _padding_format_convert(padding: list): | def _padding_format_convert(padding: list): | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -13,8 +13,8 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================== | # ============================================================================== | ||||
| """Mapper module.""" | """Mapper module.""" | ||||
| from ...base import ONNXToMindSporeMapper | |||||
| from ...gen_setting import Setting | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | |||||
| class PoolMapper(ONNXToMindSporeMapper): | class PoolMapper(ONNXToMindSporeMapper): | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -13,8 +13,8 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================== | # ============================================================================== | ||||
| """Mapper module.""" | """Mapper module.""" | ||||
| from ...base import ONNXToMindSporeMapper | |||||
| from ...gen_setting import Setting | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | |||||
| class ReLUMapper(ONNXToMindSporeMapper): | class ReLUMapper(ONNXToMindSporeMapper): | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -13,8 +13,8 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================== | # ============================================================================== | ||||
| """Mapper module.""" | """Mapper module.""" | ||||
| from ...base import ONNXToMindSporeMapper | |||||
| from ...gen_setting import Setting | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | |||||
| class SigmoidMapper(ONNXToMindSporeMapper): | class SigmoidMapper(ONNXToMindSporeMapper): | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -13,8 +13,8 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================== | # ============================================================================== | ||||
| """Mapper module.""" | """Mapper module.""" | ||||
| from ...base import ONNXToMindSporeMapper | |||||
| from ...gen_setting import Setting | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | |||||
| class SoftmaxMapper(ONNXToMindSporeMapper): | class SoftmaxMapper(ONNXToMindSporeMapper): | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -13,8 +13,10 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================== | # ============================================================================== | ||||
| """Mapper module.""" | """Mapper module.""" | ||||
| from ...base import ONNXToMindSporeMapper | |||||
| from ...gen_setting import Setting, Tensor, get_dtype | |||||
| from mindinsight.mindconverter.graph_based_converter.common.utils import reset_init_or_construct | |||||
| from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting, Tensor, get_dtype | |||||
| class AddMapper(ONNXToMindSporeMapper): | class AddMapper(ONNXToMindSporeMapper): | ||||
| @@ -43,3 +45,30 @@ class AddMapper(ONNXToMindSporeMapper): | |||||
| ref = t_name | ref = t_name | ||||
| return Setting(op_extra_tensor=Tensor(shape=tensor.shape, | return Setting(op_extra_tensor=Tensor(shape=tensor.shape, | ||||
| dtype=get_dtype(tensor), reference=ref)) | dtype=get_dtype(tensor), reference=ref)) | ||||
| @staticmethod | |||||
| def _generate_snippet_template(**kwargs): | |||||
| template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template( | |||||
| **kwargs) | |||||
| op = kwargs.get("operation") | |||||
| args = kwargs.get("converted_params") | |||||
| weights = kwargs.get("weights") | |||||
| if not weights: | |||||
| return template, exchange_msg, outputs_list, outputs_mapping | |||||
| bias = list(weights.items())[0] | |||||
| _, tensor = bias | |||||
| variable_slot = "var_0" | |||||
| init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})" | |||||
| init_tensor = f"self.{{{variable_slot}}}_bias = " \ | |||||
| f"Tensor(np.random.uniform(0, 1, {tensor.shape}).astype(np.{tensor.dtype}))" | |||||
| construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ | |||||
| f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}," \ | |||||
| f"self.{{{variable_slot}}}_bias)" | |||||
| template = reset_init_or_construct(template, variable_slot, [init_template, init_tensor], | |||||
| TemplateKeywords.INIT.value) | |||||
| template = reset_init_or_construct(template, variable_slot, [construct_template], | |||||
| TemplateKeywords.CONSTRUCT.value) | |||||
| return template, exchange_msg, outputs_list, outputs_mapping | |||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -14,8 +14,10 @@ | |||||
| # ============================================================================== | # ============================================================================== | ||||
| """Mapper module.""" | """Mapper module.""" | ||||
| from mindinsight.mindconverter.graph_based_converter.constant import InputType | from mindinsight.mindconverter.graph_based_converter.constant import InputType | ||||
| from ...base import ONNXToMindSporeMapper | |||||
| from ...gen_setting import Setting | |||||
| from mindinsight.mindconverter.graph_based_converter.common.utils import reset_init_or_construct | |||||
| from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | |||||
| class ConcatMapper(ONNXToMindSporeMapper): | class ConcatMapper(ONNXToMindSporeMapper): | ||||
| @@ -38,3 +40,14 @@ class ConcatMapper(ONNXToMindSporeMapper): | |||||
| def _convert_settings(**kwargs): | def _convert_settings(**kwargs): | ||||
| input_type = InputType.LIST.value | input_type = InputType.LIST.value | ||||
| return Setting(op_ipt_type=input_type) | return Setting(op_ipt_type=input_type) | ||||
| @staticmethod | |||||
| def _generate_snippet_template(**kwargs): | |||||
| template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template( | |||||
| **kwargs) | |||||
| variable_slot = "var_0" | |||||
| construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ | |||||
| f"(({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}},))" | |||||
| template = reset_init_or_construct(template, variable_slot, [construct_template], | |||||
| TemplateKeywords.CONSTRUCT.value) | |||||
| return template, exchange_msg, outputs_list, outputs_mapping | |||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -13,8 +13,10 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================== | # ============================================================================== | ||||
| """Mapper module.""" | """Mapper module.""" | ||||
| from ...base import ONNXToMindSporeMapper | |||||
| from ...gen_setting import Setting, Tensor, get_dtype | |||||
| from mindinsight.mindconverter.graph_based_converter.common.utils import reset_init_or_construct | |||||
| from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting, Tensor, get_dtype | |||||
| class MulMapper(ONNXToMindSporeMapper): | class MulMapper(ONNXToMindSporeMapper): | ||||
| @@ -40,3 +42,29 @@ class MulMapper(ONNXToMindSporeMapper): | |||||
| ref, tensor = list(weights.items())[0] | ref, tensor = list(weights.items())[0] | ||||
| return Setting(op_extra_tensor=Tensor(shape=tensor.shape, | return Setting(op_extra_tensor=Tensor(shape=tensor.shape, | ||||
| dtype=get_dtype(tensor), reference=ref)) | dtype=get_dtype(tensor), reference=ref)) | ||||
| @staticmethod | |||||
| def _generate_snippet_template(**kwargs): | |||||
| template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template( | |||||
| **kwargs) | |||||
| op = kwargs.get("operation") | |||||
| args = kwargs.get("converted_params") | |||||
| weights = kwargs.get("weights") | |||||
| if not weights: | |||||
| return template, exchange_msg, outputs_list, outputs_mapping | |||||
| weight = list(weights.items())[0] | |||||
| _, tensor = weight | |||||
| variable_slot = "var_0" | |||||
| init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})" | |||||
| init_tensor = f"self.{{{variable_slot}}}_w = Tensor(np.random.uniform(0, 1, {tensor.shape})" \ | |||||
| f".astype(np.{tensor.dtype}))" | |||||
| construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ | |||||
| f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}," \ | |||||
| f"self.{{{variable_slot}}}_w)" | |||||
| template = reset_init_or_construct(template, variable_slot, [init_template, init_tensor], | |||||
| TemplateKeywords.INIT.value) | |||||
| template = reset_init_or_construct(template, variable_slot, [construct_template], | |||||
| TemplateKeywords.CONSTRUCT.value) | |||||
| return template, exchange_msg, outputs_list, outputs_mapping | |||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -13,8 +13,10 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================== | # ============================================================================== | ||||
| """Mapper module.""" | """Mapper module.""" | ||||
| from ...base import ONNXToMindSporeMapper | |||||
| from ...gen_setting import Setting | |||||
| from mindinsight.mindconverter.graph_based_converter.common.utils import reset_init_or_construct | |||||
| from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | |||||
| class ReduceMeanMapper(ONNXToMindSporeMapper): | class ReduceMeanMapper(ONNXToMindSporeMapper): | ||||
| @@ -42,3 +44,20 @@ class ReduceMeanMapper(ONNXToMindSporeMapper): | |||||
| else: | else: | ||||
| axis = tuple() | axis = tuple() | ||||
| return Setting(op_extra_input={'axis': axis}) | return Setting(op_extra_input={'axis': axis}) | ||||
| @staticmethod | |||||
| def _generate_snippet_template(**kwargs): | |||||
| template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template( | |||||
| **kwargs) | |||||
| raw_params = kwargs.get("raw_params") | |||||
| if raw_params.get('axes'): | |||||
| axis = raw_params['axes'][0] if len(raw_params['axes']) == 1 else tuple(raw_params['axes']) | |||||
| else: | |||||
| axis = tuple() | |||||
| variable_slot = "var_0" | |||||
| construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ | |||||
| f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}, {axis})" | |||||
| template = reset_init_or_construct(template, variable_slot, [construct_template], | |||||
| TemplateKeywords.CONSTRUCT.value) | |||||
| return template, exchange_msg, outputs_list, outputs_mapping | |||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -13,8 +13,10 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================== | # ============================================================================== | ||||
| """Mapper module.""" | """Mapper module.""" | ||||
| from ...base import ONNXToMindSporeMapper | |||||
| from ...gen_setting import Setting | |||||
| from mindinsight.mindconverter.graph_based_converter.common.utils import reset_init_or_construct | |||||
| from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | |||||
| class ReshapeMapper(ONNXToMindSporeMapper): | class ReshapeMapper(ONNXToMindSporeMapper): | ||||
| @@ -52,3 +54,20 @@ class ReshapeMapper(ONNXToMindSporeMapper): | |||||
| shape = [-1] | shape = [-1] | ||||
| shape += list(weights.values())[0][1:].tolist() | shape += list(weights.values())[0][1:].tolist() | ||||
| return Setting(op_extra_input={"shape": tuple(shape)}) | return Setting(op_extra_input={"shape": tuple(shape)}) | ||||
| @staticmethod | |||||
| def _generate_snippet_template(**kwargs): | |||||
| template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template( | |||||
| **kwargs) | |||||
| weights = kwargs.get("weights") | |||||
| if len(weights) > 1: | |||||
| raise ValueError("For reshape, `weights` length should equal to 1.") | |||||
| shape = [-1] | |||||
| shape += list(weights.values())[0][1:].tolist() | |||||
| variable_slot = "var_0" | |||||
| construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ | |||||
| f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}, {tuple(shape)})" | |||||
| template = reset_init_or_construct(template, variable_slot, [construct_template], | |||||
| TemplateKeywords.CONSTRUCT.value) | |||||
| return template, exchange_msg, outputs_list, outputs_mapping | |||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -13,9 +13,9 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================== | # ============================================================================== | ||||
| """Mapper module.""" | """Mapper module.""" | ||||
| from ...base import ONNXToMindSporeMapper | |||||
| from ...gen_setting import Setting | |||||
| from ....common import utils | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | |||||
| from mindinsight.mindconverter.graph_based_converter.common.utils import convert_bytes_string_to_string | |||||
| class ResizeMapper(ONNXToMindSporeMapper): | class ResizeMapper(ONNXToMindSporeMapper): | ||||
| @@ -26,11 +26,11 @@ class ResizeMapper(ONNXToMindSporeMapper): | |||||
| params = kwargs.get("params") | params = kwargs.get("params") | ||||
| onnx_coordinate_transform = params.get("coordinate_transformation_mode") | onnx_coordinate_transform = params.get("coordinate_transformation_mode") | ||||
| if onnx_coordinate_transform is not None: | if onnx_coordinate_transform is not None: | ||||
| onnx_coordinate_transform = utils.convert_bytes_string_to_string(onnx_coordinate_transform) | |||||
| onnx_coordinate_transform = convert_bytes_string_to_string(onnx_coordinate_transform) | |||||
| interpolation_mode = params.get("mode") | interpolation_mode = params.get("mode") | ||||
| if interpolation_mode is not None: | if interpolation_mode is not None: | ||||
| interpolation_mode = utils.convert_bytes_string_to_string(interpolation_mode) | |||||
| interpolation_mode = convert_bytes_string_to_string(interpolation_mode) | |||||
| # Define which MindSpore Resize operator to be used | # Define which MindSpore Resize operator to be used | ||||
| if interpolation_mode == "linear": | if interpolation_mode == "linear": | ||||
| @@ -54,7 +54,7 @@ class ResizeMapper(ONNXToMindSporeMapper): | |||||
| onnx_coordinate_transform = params.get("coordinate_transformation_mode") | onnx_coordinate_transform = params.get("coordinate_transformation_mode") | ||||
| if onnx_coordinate_transform is not None: | if onnx_coordinate_transform is not None: | ||||
| onnx_coordinate_transform = utils.convert_bytes_string_to_string(onnx_coordinate_transform) | |||||
| onnx_coordinate_transform = convert_bytes_string_to_string(onnx_coordinate_transform) | |||||
| if onnx_coordinate_transform == "align_corners" or "half_pixel" in onnx_coordinate_transform: | if onnx_coordinate_transform == "align_corners" or "half_pixel" in onnx_coordinate_transform: | ||||
| align_corners = True | align_corners = True | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -13,8 +13,10 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================== | # ============================================================================== | ||||
| """Mapper module.""" | """Mapper module.""" | ||||
| from ...base import ONNXToMindSporeMapper | |||||
| from ...gen_setting import Setting | |||||
| from mindinsight.mindconverter.graph_based_converter.common.utils import reset_init_or_construct | |||||
| from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | |||||
| class SliceMapper(ONNXToMindSporeMapper): | class SliceMapper(ONNXToMindSporeMapper): | ||||
| @@ -41,3 +43,21 @@ class SliceMapper(ONNXToMindSporeMapper): | |||||
| starts = sorted(zip(weights[0].tolist(), weights[2].tolist()), key=lambda x: x[1], reverse=False) | starts = sorted(zip(weights[0].tolist(), weights[2].tolist()), key=lambda x: x[1], reverse=False) | ||||
| return Setting(op_extra_input={"begin": tuple([i[0] for i in starts]), | return Setting(op_extra_input={"begin": tuple([i[0] for i in starts]), | ||||
| "size": tuple(opt_shape)}) | "size": tuple(opt_shape)}) | ||||
| @staticmethod | |||||
| def _generate_snippet_template(**kwargs): | |||||
| template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template( | |||||
| **kwargs) | |||||
| weights = list(kwargs.get("weights").values()) # start, end, axis | |||||
| opt_shape = kwargs["raw_params"]["output_shape"] | |||||
| if not weights: | |||||
| raise ValueError("Cannot get required params from slice.") | |||||
| starts = sorted(zip(weights[0].tolist(), weights[2].tolist()), key=lambda x: x[1], reverse=False) | |||||
| variable_slot = "var_0" | |||||
| construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ | |||||
| f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}, " \ | |||||
| f"{tuple([i[0] for i in starts])}, {tuple(opt_shape)})" | |||||
| template = reset_init_or_construct(template, variable_slot, [construct_template], | |||||
| TemplateKeywords.CONSTRUCT.value) | |||||
| return template, exchange_msg, outputs_list, outputs_mapping | |||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -13,8 +13,8 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================== | # ============================================================================== | ||||
| """Mapper module.""" | """Mapper module.""" | ||||
| from ...base import ONNXToMindSporeMapper | |||||
| from ...gen_setting import Setting | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | |||||
| class SplitMapper(ONNXToMindSporeMapper): | class SplitMapper(ONNXToMindSporeMapper): | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -13,8 +13,10 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================== | # ============================================================================== | ||||
| """Mapper module.""" | """Mapper module.""" | ||||
| from ...base import ONNXToMindSporeMapper | |||||
| from ...gen_setting import Setting | |||||
| from mindinsight.mindconverter.graph_based_converter.common.utils import reset_init_or_construct | |||||
| from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | |||||
| class TransposeMapper(ONNXToMindSporeMapper): | class TransposeMapper(ONNXToMindSporeMapper): | ||||
| @@ -42,3 +44,17 @@ class TransposeMapper(ONNXToMindSporeMapper): | |||||
| converted_params['input_perm'] = perm | converted_params['input_perm'] = perm | ||||
| return Setting(op_extra_input=converted_params) | return Setting(op_extra_input=converted_params) | ||||
| @staticmethod | |||||
| def _generate_snippet_template(**kwargs): | |||||
| template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template( | |||||
| **kwargs) | |||||
| raw_params = kwargs.get("raw_params") | |||||
| perm = raw_params["perm"] | |||||
| variable_slot = "var_0" | |||||
| construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ | |||||
| f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}, {tuple(perm)})" | |||||
| template = reset_init_or_construct(template, variable_slot, [construct_template], | |||||
| TemplateKeywords.CONSTRUCT.value) | |||||
| return template, exchange_msg, outputs_list, outputs_mapping | |||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -17,11 +17,13 @@ import abc | |||||
| from collections import OrderedDict | from collections import OrderedDict | ||||
| from copy import deepcopy | from copy import deepcopy | ||||
| from typing import List | |||||
| from mindinsight.mindconverter.common.log import logger as log | from mindinsight.mindconverter.common.log import logger as log | ||||
| from ..common.code_fragment import CodeFragment | |||||
| from ..constant import NodeType, InputType | |||||
| from ..mapper.base import Mapper | |||||
| from ...common.exceptions import NodeInputTypeNotSupportError | |||||
| from mindinsight.mindconverter.graph_based_converter.common.code_fragment import CodeFragment | |||||
| from mindinsight.mindconverter.graph_based_converter.constant import NodeType, InputType | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import Mapper | |||||
| from mindinsight.mindconverter.common.exceptions import NodeInputTypeNotSupportError | |||||
| class GraphParser(metaclass=abc.ABCMeta): | class GraphParser(metaclass=abc.ABCMeta): | ||||
| @@ -97,7 +99,6 @@ class Graph(BaseGraph, abc.ABC): | |||||
| self.model = model | self.model = model | ||||
| self._raw_input_nodes = kwargs.get("input_nodes") | self._raw_input_nodes = kwargs.get("input_nodes") | ||||
| self._raw_output_nodes = kwargs.get("output_nodes") | self._raw_output_nodes = kwargs.get("output_nodes") | ||||
| self.checkpoint = kwargs.get("checkpoint", None) | |||||
| self._nodes_collection = OrderedDict() | self._nodes_collection = OrderedDict() | ||||
| self._nodes_record = dict() | self._nodes_record = dict() | ||||
| self._shape_dict = dict() | self._shape_dict = dict() | ||||
| @@ -107,6 +108,13 @@ class Graph(BaseGraph, abc.ABC): | |||||
| self._input_shape = dict() | self._input_shape = dict() | ||||
| self._is_multi_opt_graph = False | self._is_multi_opt_graph = False | ||||
| @property | |||||
| def user_provided_input_nodes(self) -> List[str]: | |||||
| """User provided input_nodes in CLI.""" | |||||
| if not isinstance(self._raw_input_nodes, list): | |||||
| return [self._raw_input_nodes] | |||||
| return self._raw_input_nodes | |||||
| def get_input_shape(self, name): | def get_input_shape(self, name): | ||||
| """ | """ | ||||
| Get node input shape. | Get node input shape. | ||||
| @@ -285,9 +293,9 @@ class GraphNode(abc.ABC): | |||||
| self.successor_nodes = [] | self.successor_nodes = [] | ||||
| # Control dependency. | # Control dependency. | ||||
| self._deleted_in_edge = 0 | self._deleted_in_edge = 0 | ||||
| # Source node in pytorch. | |||||
| self._src_node = str(node) if node else None | |||||
| # Original operation name in pytorch. | |||||
| # Source node in ONNX. | |||||
| self._src_node = node if node else None | |||||
| # Original operation name in ONNX. | |||||
| self._op_name = None | self._op_name = None | ||||
| self._op_params = dict() | self._op_params = dict() | ||||
| self._scope_name = None | self._scope_name = None | ||||
| @@ -311,6 +319,40 @@ class GraphNode(abc.ABC): | |||||
| # Is in multi output graph. | # Is in multi output graph. | ||||
| self._is_in_multi_opt_graph = False | self._is_in_multi_opt_graph = False | ||||
| @property | |||||
| def ir_node_name(self): | |||||
| """Getter of ir node's name.""" | |||||
| return self._src_node.name | |||||
| @property | |||||
| def ir_node_operation(self): | |||||
| """Getter of ir node's operation.""" | |||||
| return self._src_node.op_type | |||||
| @property | |||||
| def ir_node_inputs(self): | |||||
| """Getter of ir node's inputs.""" | |||||
| return list(self._src_node.input_name_list) | |||||
| @property | |||||
| def ir_node_outputs(self): | |||||
| """Getter of ir node's outputs.""" | |||||
| return list(self._src_node.output_name_list) | |||||
| @property | |||||
| def ir_node_precursor(self): | |||||
| """Getter of ir node's precursor.""" | |||||
| return [ | |||||
| v.name for _, v in self._src_node.precursor_onnx_node_dict.items() | |||||
| ] | |||||
| @property | |||||
| def ir_node_successor(self): | |||||
| """Getter of ir node's successor.""" | |||||
| return [ | |||||
| v.name for _, v in self._src_node.successor_onnx_node_dict.items() | |||||
| ] | |||||
| @property | @property | ||||
| def weight(self): | def weight(self): | ||||
| return self._weight | return self._weight | ||||
| @@ -0,0 +1,15 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================== | |||||
| """Unit test for mindconverter.graph_based_converter.common interface.""" | |||||
| @@ -0,0 +1,145 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================== | |||||
| """Test fragment.""" | |||||
| from unittest import TestCase | |||||
| from mindinsight.mindconverter.graph_based_converter.common.code_fragment import NewFragment as Fragment | |||||
| class TestFragment(TestCase): | |||||
| """Tester of fragment.""" | |||||
| def test_matmul(self): | |||||
| """Test matmul like operation's template.""" | |||||
| template = { | |||||
| 'var_0': { | |||||
| 'init': [ | |||||
| 'self.{var_0} = nn.MatMul()', | |||||
| 'self.{var_0}_w = Tensor(np.random.rand(*(2048, 1000)).astype(np.float32))' | |||||
| ], | |||||
| 'construct': ['opt_{var_0} = self.{var_0}({inputs},self.{var_0}_w)'] | |||||
| } | |||||
| } | |||||
| rewrite_data = { | |||||
| 'var_0': { | |||||
| 'operation': 'nn.MatMul', | |||||
| 'output_type': 'tensor', | |||||
| 'variable_name': "matmul", 'inputs': ["x"], 'args': {}, | |||||
| 'weights': {}, | |||||
| 'trainable_params': {} | |||||
| }, | |||||
| 'metadata': { | |||||
| 'source': 'probs/MatMul', 'operation': 'MatMul', 'scope': 'Model/MatMul', | |||||
| 'inputs': ['avg_pool/Mean:0', 'probs/MatMul/ReadVariableOp:0'], | |||||
| 'inputs_shape': (1, 2048), 'outputs': ['probs/MatMul:0'], 'outputs_shape': [1, 1000], | |||||
| 'precursor_nodes': ['avg_pool/Mean'], 'successor_nodes': ['probs/BiasAdd'], | |||||
| 'attributes': {} | |||||
| } | |||||
| } | |||||
| fragment = Fragment(data_entity=rewrite_data, code_template=template, outputs=["opt_{var_0}"], | |||||
| outputs_mapping=((0, 0),)) | |||||
| code = fragment() | |||||
| init = code[0] | |||||
| construct = code[1] | |||||
| self.assertEqual(init, ['self.matmul = nn.MatMul()', | |||||
| 'self.matmul_w = Tensor(np.random.rand(*(2048, 1000)).astype(np.float32))']) | |||||
| self.assertEqual(construct, ['opt_matmul = self.matmul(x,self.matmul_w)']) | |||||
| self.assertEqual(fragment.get_outputs_by_idx(0), "opt_matmul") | |||||
| def test_biasadd(self): | |||||
| """Test biasadd like operation's template.""" | |||||
| template = { | |||||
| 'var_0': { | |||||
| 'init': [ | |||||
| 'self.{var_0} = P.TensorAdd()', | |||||
| 'self.{var_0}_bias = Tensor(np.random.rand(*(1000,)).astype(np.float32))' | |||||
| ], | |||||
| 'construct': ['opt_{var_0} = self.{var_0}({inputs},self.{var_0}_bias)'] | |||||
| } | |||||
| } | |||||
| rewrite_data = { | |||||
| 'var_0': { | |||||
| 'operation': 'P.TensorAdd', | |||||
| 'output_type': 'tensor', | |||||
| 'variable_name': "add", 'inputs': ["x"], 'args': {}, 'weights': {}, | |||||
| 'trainable_params': {} | |||||
| }, | |||||
| 'metadata': { | |||||
| 'source': 'probs/BiasAdd', 'operation': 'Add', 'scope': 'Model/Add', | |||||
| 'inputs': ['probs/MatMul:0', 'probs/BiasAdd/ReadVariableOp:0'], 'inputs_shape': (1, 1000), | |||||
| 'outputs': ['probs/BiasAdd:0'], 'outputs_shape': [1, 1000], 'precursor_nodes': ['probs/MatMul'], | |||||
| 'successor_nodes': ['probs/Softmax'], 'attributes': {} | |||||
| } | |||||
| } | |||||
| fragment = Fragment(data_entity=rewrite_data, code_template=template, outputs=["opt_{var_0}"], | |||||
| outputs_mapping=((0, 0),)) | |||||
| code = fragment() | |||||
| init = code[0] | |||||
| construct = code[1] | |||||
| self.assertEqual(init, ['self.add = P.TensorAdd()', | |||||
| 'self.add_bias = Tensor(np.random.rand(*(1000,)).astype(np.float32))']) | |||||
| self.assertEqual(construct, ['opt_add = self.add(x,self.add_bias)']) | |||||
| self.assertEqual(fragment.get_outputs_by_idx(0), "opt_add") | |||||
| def test_transpose(self): | |||||
| """Test transpose like operation's template.""" | |||||
| template = { | |||||
| 'var_0': { | |||||
| 'init': ['self.{var_0} = P.Transpose()'], | |||||
| 'construct': ['opt_{var_0} = self.{var_0}({inputs}, (0, 3, 1, 2))'] | |||||
| } | |||||
| } | |||||
| rewrite_data = { | |||||
| 'var_0': { | |||||
| 'operation': 'P.Transpose', | |||||
| 'output_type': 'tensor', | |||||
| 'variable_name': "transpose", 'inputs': ["x"], 'args': {}, 'weights': {}, | |||||
| 'trainable_params': {} | |||||
| } | |||||
| } | |||||
| fragment = Fragment(data_entity=rewrite_data, code_template=template, outputs=["opt_{var_0}"], | |||||
| outputs_mapping=((0, 0),)) | |||||
| code = fragment() | |||||
| init = code[0] | |||||
| construct = code[1] | |||||
| self.assertEqual(init, ['self.transpose = P.Transpose()']) | |||||
| self.assertEqual(construct, ['opt_transpose = self.transpose(x, (0, 3, 1, 2))']) | |||||
| self.assertEqual(fragment.get_outputs_by_idx(0), "opt_transpose") | |||||
| def test_split(self): | |||||
| """Test split like operation's template.""" | |||||
| template = { | |||||
| 'var_0': { | |||||
| 'init': ['self.{var_0} = P.Split(axis={axis}, output_num={output_num})'], | |||||
| 'construct': ['opt_{var_0} = self.{var_0}({inputs})'] | |||||
| } | |||||
| } | |||||
| rewrite_data = { | |||||
| 'var_0': { | |||||
| 'operation': 'P.Split', | |||||
| 'variable_name': "split", | |||||
| 'output_type': 'array', | |||||
| 'inputs': ["x"], | |||||
| 'args': {"axis": 1, "output_num": 2}, 'weights': {}, | |||||
| 'trainable_params': {} | |||||
| } | |||||
| } | |||||
| fragment = Fragment(data_entity=rewrite_data, code_template=template, outputs=["opt_{var_0}"], | |||||
| outputs_mapping=((0, 0),)) | |||||
| code = fragment() | |||||
| init = code[0] | |||||
| construct = code[1] | |||||
| self.assertEqual(init, ['self.split = P.Split(axis=1, output_num=2)']) | |||||
| self.assertEqual(construct, ['opt_split = self.split(x)']) | |||||
| self.assertEqual(fragment.get_outputs_by_idx(0, 1), 'opt_split[1]') | |||||