Browse Source

Add code template generation.

tags/v1.2.0-rc1
liuchongming 4 years ago
parent
commit
37c3a1b7e0
28 changed files with 739 additions and 106 deletions
  1. +145
    -20
      mindinsight/mindconverter/graph_based_converter/common/code_fragment.py
  2. +8
    -1
      mindinsight/mindconverter/graph_based_converter/common/utils.py
  3. +42
    -1
      mindinsight/mindconverter/graph_based_converter/constant.py
  4. +52
    -7
      mindinsight/mindconverter/graph_based_converter/generator/__init__.py
  5. +42
    -2
      mindinsight/mindconverter/graph_based_converter/mapper/base.py
  6. +3
    -3
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/batch_norm_mapper.py
  7. +5
    -5
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py
  8. +3
    -3
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/dense_mapper.py
  9. +3
    -3
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/flatten_mapper.py
  10. +3
    -3
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/global_pool_mapper.py
  11. +32
    -3
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/mat_mul_mapper.py
  12. +3
    -3
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pad_mapper.py
  13. +3
    -3
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pool_mapper.py
  14. +3
    -3
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/relu_mapper.py
  15. +3
    -3
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/sigmoid_mapper.py
  16. +3
    -3
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/softmax_mapper.py
  17. +32
    -3
      mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/add_mapper.py
  18. +16
    -3
      mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/concat_mapper.py
  19. +31
    -3
      mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/mul_mapper.py
  20. +22
    -3
      mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/reduce_mean_mapper.py
  21. +22
    -3
      mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/reshape_mapper.py
  22. +7
    -7
      mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/resize_mapper.py
  23. +23
    -3
      mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/slice_mapper.py
  24. +3
    -3
      mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/split_mapper.py
  25. +19
    -3
      mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/transpose_mapper.py
  26. +51
    -9
      mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py
  27. +15
    -0
      tests/ut/mindconverter/graph_based_converter/common/__init__.py
  28. +145
    -0
      tests/ut/mindconverter/graph_based_converter/common/test_fragment.py

+ 145
- 20
mindinsight/mindconverter/graph_based_converter/common/code_fragment.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -14,26 +14,10 @@
# ============================================================================== # ==============================================================================
"""Define CodeLine object.""" """Define CodeLine object."""
import abc import abc
import re
from typing import List, Tuple



class TrainableParams:
"""Trainable parameters."""

def __init__(self, shape, dtype, reference):
self.param_name = None
self.shape = shape
self.dtype = dtype
self.reference = reference # Weight name in global npy.


class CodeSetting:
"""Code generation settings."""

def __init__(self):
self.output_vars_suffix = []
self.operation_input_type = None # Construct input type, tensor or list.
self.operation_extra_input = dict() # `values` in original setting dict.
self.operation_extra_tensor = None # For `MatMul`, `BiasAdd` op, need a tensor
from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords




class Fragment(abc.ABC): class Fragment(abc.ABC):
@@ -222,3 +206,144 @@ class ModuleFragment(Fragment):
super(ModuleFragment, self).__init__(operation=operation, actual_args=actual_args, super(ModuleFragment, self).__init__(operation=operation, actual_args=actual_args,
input_shape=input_shape, output_shape=output_shape, input_shape=input_shape, output_shape=output_shape,
settings=settings) settings=settings)


class NewFragment:
"""
Fragment definition for MindSpore code generation.

Args:
data_entity (dict): Required data by operations. The format of `data_entity` is as follow:
{
"var1": {
"metadata": { # ONNX Metadata
"operation": "Conv2d",
"source": "conv_pw_13/Conv2D",
"attributes": {
# Put original onnx attributes here.
}
},
"variable_name": None,
"inputs": [],
"output_type": "tensor" | "array",
"args": {"in_channels": 768, "out_channels": 1024},
"trainable_params": {"weight": "Parameter(Tensor(GLOBAL_W[NAME]))"}
},
"var2": {
"variable_name": "pad",
"args": {"padding": [0, 1, 1, 0], "mode": "SAME"}
}
}
code_template (dict): Code template generated by mapper. The format of `code_template` is as follow:
{
"var1": {
"init": [
"self.{var1} = nn.Conv2d(in_channels={in_channels})",
"self.{var1}.weight = {weight}"
],
"construct": [
"opt_{var1} = self.{var1}({inputs}[, extra])"
]
},
"var2": {
"init": [
"self.{var2} = nn.Pad(padding={padding}, mode={mode})"
],
"construct": [
"opt_{var2} = self.{var2}(opt_{var1}[, extra])"
]
}
}
outputs (list[str]): Outputs name slot list.
outputs_mapping (tuple): Outputs index mapping between ir node and MindSpore operation.
"""

def __init__(self, data_entity: dict, code_template: dict, outputs: List[str], outputs_mapping):
self.exchange_msg = data_entity
self._code_template = code_template
self.inputs = []
self._outputs = outputs
self.outputs_mapping = outputs_mapping
self.format_args = dict()

def _get_outputs(self):
"""
Get outputs of the code snippet.

Returns:
list[str], outputs of current code block.
"""
outputs = []
variables = {
k: self.exchange_msg[k][ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value]
for k in self.exchange_msg if k != ExchangeMessageKeywords.METADATA.value
}
for o in self._outputs:
extractor = r".*\{(?P<var>.+)\}.*"
var_def = re.match(extractor, o)
if not var_def:
raise ValueError(f"Output variable name {o} is illegal.")
outputs.append(
(
o.format(**variables),
self.exchange_msg[var_def.group("var")][
ExchangeMessageKeywords.VariableScope.value.OUTPUT_TYPE.value]
)
)
return outputs

def get_outputs_by_idx(self, idx, inner_idx=-1):
"""Get outputs by idx."""
outputs = self._get_outputs()
opt, opt_type = outputs[idx]
if opt_type == ExchangeMessageKeywords.VariableScope.value.ARR_TYPE.value:
return f"{opt}[{inner_idx}]"
return opt

def __call__(self) -> Tuple[List[str], List[str]]:
"""
Define parameter rewrite function.

Returns:
tuple[list[str], list[str]], init statement and construct statement.
"""
init_stats, call_stats = [], []
precursor_node_var = [None, None]
for op_var, template in self._code_template.items():
if ExchangeMessageKeywords.VariableScope.value.INPUTS.value not in self.exchange_msg[op_var]:
# It's possible inputs and precursor node both exists.
self.exchange_msg[op_var][ExchangeMessageKeywords.VariableScope.value.ARGS.value][
precursor_node_var[0]] = precursor_node_var[1]
for tpl in template[TemplateKeywords.INIT.value]:
init_stat = self._rewrite(op_var, self.exchange_msg[op_var], tpl)
init_stats.append(init_stat)
for tpl in template[TemplateKeywords.CONSTRUCT.value]:
call_stat = self._rewrite(op_var, self.exchange_msg[op_var], tpl)
call_stats.append(call_stat)
precursor_node_var = op_var, self.exchange_msg[op_var].get(
ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value)
return init_stats, call_stats

@staticmethod
def _rewrite(var, data, template: str) -> str:
"""
Backfill data into code template.

Args:
var (str): Current operation variable name.
data (dict): Data to be written.
template (str): Code template.

Returns:
str, single code line.
"""
rewrite_data = {var: data[ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value]}
if ExchangeMessageKeywords.VariableScope.value.INPUTS.value in data:
rewrite_data[ExchangeMessageKeywords.VariableScope.value.INPUTS.value] = ", ".join(
data[ExchangeMessageKeywords.VariableScope.value.INPUTS.value])
if ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value in data:
rewrite_data.update(data[ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value])
rewrite_data.update(data[ExchangeMessageKeywords.VariableScope.value.ARGS.value])
return template.format(**{
k: str(rewrite_data[k]) for k in rewrite_data
})

+ 8
- 1
mindinsight/mindconverter/graph_based_converter/common/utils.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -209,3 +209,10 @@ def get_framework_type(model_path):
raise error raise error


return framework_type return framework_type


def reset_init_or_construct(template, variable_slot, new_data, scope):
"""Reset init statement."""
template[variable_slot][scope].clear()
template[variable_slot][scope] += new_data
return template

+ 42
- 1
mindinsight/mindconverter/graph_based_converter/constant.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -42,6 +42,47 @@ ONNX_MIN_VER = "1.8.0"
TF2ONNX_MIN_VER = "1.7.1" TF2ONNX_MIN_VER = "1.7.1"
ONNXRUNTIME_MIN_VER = "1.5.2" ONNXRUNTIME_MIN_VER = "1.5.2"



@unique
class TemplateKeywords(Enum):
"""Define keywords in template message."""
INIT = "init"
CONSTRUCT = "construct"


@unique
class ExchangeMessageKeywords(Enum):
"""Define keywords in exchange message."""
METADATA = "metadata"

@unique
class MetadataScope(Enum):
"""Define metadata scope keywords in exchange message."""
SOURCE = "source"
OPERATION = "operation"
INPUTS = "inputs"
INPUTS_SHAPE = "inputs_shape"
OUTPUTS = "outputs"
OUTPUTS_SHAPE = "outputs_shape"
PRECURSOR = "precursor_nodes"
SUCCESSOR = "successor_nodes"
ATTRS = "attributes"
SCOPE = "scope"

@unique
class VariableScope(Enum):
"""Define variable scope keywords in exchange message."""
OPERATION = "operation"
VARIABLE_NAME = "variable_name"
OUTPUT_TYPE = "output_type"
TSR_TYPE = "tensor"
ARR_TYPE = "array"
INPUTS = "inputs"
ARGS = "args"
WEIGHTS = "weights"
TRAINABLE_PARAMS = "trainable_params"


BINARY_HEADER_PYTORCH_FILE = \ BINARY_HEADER_PYTORCH_FILE = \
b'\x80\x02\x8a\nl\xfc\x9cF\xf9 j\xa8P\x19.\x80\x02M\xe9\x03.\x80\x02}q\x00(X\x10\x00\x00\x00' b'\x80\x02\x8a\nl\xfc\x9cF\xf9 j\xa8P\x19.\x80\x02M\xe9\x03.\x80\x02}q\x00(X\x10\x00\x00\x00'
TENSORFLOW_MODEL_SUFFIX = "pb" TENSORFLOW_MODEL_SUFFIX = "pb"


+ 52
- 7
mindinsight/mindconverter/graph_based_converter/generator/__init__.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -18,8 +18,9 @@ __all__ = ["batch_add_nodes"]
import re import re
import copy import copy


from mindinsight.mindconverter.graph_based_converter.common.code_fragment import CodeFragment
from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords
from .generator import Generator, CodeStruct from .generator import Generator, CodeStruct
from ..common.code_fragment import CodeFragment




def _tf_model_node_name_reformat(node, node_name): def _tf_model_node_name_reformat(node, node_name):
@@ -34,7 +35,6 @@ def _tf_model_node_name_reformat(node, node_name):
str, re-formatted node name. str, re-formatted node name.
""" """
scope_name = node.scope_name scope_name = node.scope_name
new_name = None
regex = r"(?P<parent>.+/)(?P<op>\w+)" regex = r"(?P<parent>.+/)(?P<op>\w+)"
match = re.match(regex, scope_name) match = re.match(regex, scope_name)
parent = match.group("parent") parent = match.group("parent")
@@ -79,6 +79,32 @@ def batch_add_nodes(graph_obj, mapper) -> Generator:
return generator_inst return generator_inst




def _supply_graph_info(node, external_inputs):
"""
Supply IR graph node info into metadata.

Args:
node (GraphNode): Graph node instance.
external_inputs (list[str]): External inputs in ONNX ir.

Returns:
dict, metadata.
"""
precursors = _combine_external_inputs_with_precursor_nodes(node, external_inputs)
return {
ExchangeMessageKeywords.MetadataScope.value.SOURCE.value: node.ir_node_name,
ExchangeMessageKeywords.MetadataScope.value.OPERATION.value: node.ir_node_operation,
ExchangeMessageKeywords.MetadataScope.value.SCOPE.value: node.scope_name,
ExchangeMessageKeywords.MetadataScope.value.INPUTS.value: node.ir_node_inputs,
ExchangeMessageKeywords.MetadataScope.value.INPUTS_SHAPE.value: node.input_shape,
ExchangeMessageKeywords.MetadataScope.value.OUTPUTS.value: node.ir_node_outputs,
ExchangeMessageKeywords.MetadataScope.value.OUTPUTS_SHAPE.value: node.output_shape,
ExchangeMessageKeywords.MetadataScope.value.PRECURSOR.value: precursors,
ExchangeMessageKeywords.MetadataScope.value.SUCCESSOR.value: node.ir_node_successor,
ExchangeMessageKeywords.MetadataScope.value.ATTRS.value: node.node_params,
}


def _convert_params(node, mapper): def _convert_params(node, mapper):
""" """
Call mapper to convert node's params from ONNX to MindSpore. Call mapper to convert node's params from ONNX to MindSpore.
@@ -88,10 +114,8 @@ def _convert_params(node, mapper):
mapper (Mapper): The mapper instance which indicating conversion method. mapper (Mapper): The mapper instance which indicating conversion method.


Returns: Returns:
str, op name in MindSpore
dict, MindSpore parameters
dict, MindSpore settings
dict, weights of the node
tuple[str, dict, dict, dict], op name in MindSpore, MindSpore parameters,
MindSpore settings and weights of the node.
""" """
params = copy.deepcopy(node.node_params) params = copy.deepcopy(node.node_params)
params.update({"input_shape": node.input_shape, params.update({"input_shape": node.input_shape,
@@ -109,3 +133,24 @@ def _convert_params(node, mapper):
return op_in_ms, ms_params, ms_settings, weights return op_in_ms, ms_params, ms_settings, weights


return node.op_name, node.node_params, dict(), dict() return node.op_name, node.node_params, dict(), dict()


def _combine_external_inputs_with_precursor_nodes(node, external_inputs):
"""
User_provided_input_nodes.

Args:
node (OnnxGraphNode): Node instance.
external_inputs (list[str]): Inputs in onnx ir.

Returns:
list[str], precursor nodes list.
"""
inputs = set(node.ir_node_inputs)
to_be_added = list(inputs & set(external_inputs))
precursor = node.ir_node_precursor
# Add external inputs to precursor as the order of its inputs.
for item in to_be_added:
node_idx = node.ir_node_inputs.index(item)
precursor.insert(node_idx, item)
return precursor

+ 42
- 2
mindinsight/mindconverter/graph_based_converter/mapper/base.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -19,6 +19,7 @@ import json
import os import os
from typing import Dict from typing import Dict
from mindinsight.mindconverter.common.log import logger as log from mindinsight.mindconverter.common.log import logger as log
from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords


CONFIG_JSON = "onnx_to_ms.json" CONFIG_JSON = "onnx_to_ms.json"
OPERATION_TABLE = os.path.join( OPERATION_TABLE = os.path.join(
@@ -36,6 +37,7 @@ GET_OP_NAME = "_operation_name_in_ms"
GET_OP_PARAMS = "_convert_params" GET_OP_PARAMS = "_convert_params"
GET_OP_WEIGHTS = "_convert_trained_weights" GET_OP_WEIGHTS = "_convert_trained_weights"
GET_OP_SETTINGS = "_convert_settings" GET_OP_SETTINGS = "_convert_settings"
GET_OP_TEMPLATE = "_generate_snippet_template"




class Mapper(metaclass=abc.ABCMeta): class Mapper(metaclass=abc.ABCMeta):
@@ -44,7 +46,7 @@ class Mapper(metaclass=abc.ABCMeta):
@staticmethod @staticmethod
@abc.abstractmethod @abc.abstractmethod
def _operation_name_in_ms(*args, **kwargs): def _operation_name_in_ms(*args, **kwargs):
"""Corresponding operation name in mindspore."""
"""Corresponding operation name in MindSpore."""


@staticmethod @staticmethod
@abc.abstractmethod @abc.abstractmethod
@@ -66,6 +68,11 @@ class Mapper(metaclass=abc.ABCMeta):
def convert(cls, op_name: str, params: Dict, weights: Dict = None): def convert(cls, op_name: str, params: Dict, weights: Dict = None):
"""Convert third party operation's param into MindSpore operation.""" """Convert third party operation's param into MindSpore operation."""


@staticmethod
@abc.abstractmethod
def _generate_snippet_template(**kwargs):
"""Generate code template according to node info."""



class ONNXToMindSporeMapper(Mapper, abc.ABC): class ONNXToMindSporeMapper(Mapper, abc.ABC):
"""ONNX operation to MindSpore.""" """ONNX operation to MindSpore."""
@@ -131,3 +138,36 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC):
@staticmethod @staticmethod
def _convert_settings(**kwargs): def _convert_settings(**kwargs):
raise NotImplementedError raise NotImplementedError

@staticmethod
def _generate_snippet_template(**kwargs):
op = kwargs.get("operation")
args = kwargs.get("converted_params")
weights = kwargs.get("weights")
if not op:
raise ValueError("Can not get MindSpore operation name.")
variable_slot = "var_0"
init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})"
construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \
f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}})"
template = {
variable_slot: {
TemplateKeywords.INIT.value: [init_template],
TemplateKeywords.CONSTRUCT.value: [construct_template]
}
}
exchange_msg = {
variable_slot: {
ExchangeMessageKeywords.VariableScope.value.OPERATION.value: op,
ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value: None,
ExchangeMessageKeywords.VariableScope.value.OUTPUT_TYPE.value:
ExchangeMessageKeywords.VariableScope.value.TSR_TYPE.value,
ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [],
ExchangeMessageKeywords.VariableScope.value.ARGS.value: args,
ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights,
ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: {}
}
}
outputs_list = [f"opt_{{{variable_slot}}}"]
outputs_mapping = ((0, 0),)
return template, exchange_msg, outputs_list, outputs_mapping

+ 3
- 3
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/batch_norm_mapper.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -13,8 +13,8 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Mapper module.""" """Mapper module."""
from ...base import ONNXToMindSporeMapper
from ...gen_setting import Setting
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper
from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting




class BatchNormMapper(ONNXToMindSporeMapper): class BatchNormMapper(ONNXToMindSporeMapper):


+ 5
- 5
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -14,9 +14,9 @@
# ============================================================================== # ==============================================================================
"""Mapper module.""" """Mapper module."""
import numpy as np import numpy as np
from ...base import ONNXToMindSporeMapper
from ...gen_setting import Setting
from ....common import utils
from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper
from mindinsight.mindconverter.graph_based_converter.common.utils import convert_bytes_string_to_string




def _convert_padding(**kwargs): def _convert_padding(**kwargs):
@@ -83,7 +83,7 @@ class ConvMapper(ONNXToMindSporeMapper):


auto_pad = None auto_pad = None
if params.get("auto_pad") is not None: if params.get("auto_pad") is not None:
auto_pad = utils.convert_bytes_string_to_string(params.get("auto_pad"))
auto_pad = convert_bytes_string_to_string(params.get("auto_pad"))


# tmp tf translated ver. mapping # tmp tf translated ver. mapping
if isinstance(params.get('dilations'), list): if isinstance(params.get('dilations'), list):


+ 3
- 3
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/dense_mapper.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -13,8 +13,8 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Mapper module.""" """Mapper module."""
from ...base import ONNXToMindSporeMapper
from ...gen_setting import Setting
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper
from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting




class DenseMapper(ONNXToMindSporeMapper): class DenseMapper(ONNXToMindSporeMapper):


+ 3
- 3
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/flatten_mapper.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -13,8 +13,8 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Mapper module.""" """Mapper module."""
from ...base import ONNXToMindSporeMapper
from ...gen_setting import Setting
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper
from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting




class FlattenMapper(ONNXToMindSporeMapper): class FlattenMapper(ONNXToMindSporeMapper):


+ 3
- 3
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/global_pool_mapper.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -13,8 +13,8 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Mapper module.""" """Mapper module."""
from ...base import ONNXToMindSporeMapper
from ...gen_setting import Setting
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper
from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting




class GlobalPoolMapper(ONNXToMindSporeMapper): class GlobalPoolMapper(ONNXToMindSporeMapper):


+ 32
- 3
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/mat_mul_mapper.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -13,8 +13,10 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Mapper module.""" """Mapper module."""
from ...base import ONNXToMindSporeMapper
from ...gen_setting import Setting, Tensor, get_dtype
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper
from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting, Tensor, get_dtype
from mindinsight.mindconverter.graph_based_converter.common.utils import reset_init_or_construct
from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords




class MatMulMapper(ONNXToMindSporeMapper): class MatMulMapper(ONNXToMindSporeMapper):
@@ -43,3 +45,30 @@ class MatMulMapper(ONNXToMindSporeMapper):
ref = t_name ref = t_name
return Setting(op_extra_tensor=Tensor(shape=tensor.shape, return Setting(op_extra_tensor=Tensor(shape=tensor.shape,
dtype=get_dtype(tensor), reference=ref)) dtype=get_dtype(tensor), reference=ref))

@staticmethod
def _generate_snippet_template(**kwargs):
template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template(
**kwargs)
op = kwargs.get("operation")
args = kwargs.get("converted_params")
weights = kwargs.get("weights")
if not weights:
return template, exchange_msg, outputs_list, outputs_mapping

weight = list(weights.items())[0]
_, tensor = weight

variable_slot = "var_0"
init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})"
init_tensor = f"self.{{{variable_slot}}}_w = " \
f"Tensor(np.random.uniform(0, 1, {tensor.shape}).astype(np.{tensor.dtype}))"
construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \
f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}," \
f"self.{{{variable_slot}}}_w)"
template = reset_init_or_construct(template, variable_slot, [init_template, init_tensor],
TemplateKeywords.INIT.value)
template = reset_init_or_construct(template, variable_slot, [construct_template],
TemplateKeywords.CONSTRUCT.value)

return template, exchange_msg, outputs_list, outputs_mapping

+ 3
- 3
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pad_mapper.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -13,8 +13,8 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Mapper module.""" """Mapper module."""
from ...base import ONNXToMindSporeMapper
from ...gen_setting import Setting
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper
from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting




def _padding_format_convert(padding: list): def _padding_format_convert(padding: list):


+ 3
- 3
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pool_mapper.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -13,8 +13,8 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Mapper module.""" """Mapper module."""
from ...base import ONNXToMindSporeMapper
from ...gen_setting import Setting
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper
from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting




class PoolMapper(ONNXToMindSporeMapper): class PoolMapper(ONNXToMindSporeMapper):


+ 3
- 3
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/relu_mapper.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -13,8 +13,8 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Mapper module.""" """Mapper module."""
from ...base import ONNXToMindSporeMapper
from ...gen_setting import Setting
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper
from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting




class ReLUMapper(ONNXToMindSporeMapper): class ReLUMapper(ONNXToMindSporeMapper):


+ 3
- 3
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/sigmoid_mapper.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -13,8 +13,8 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Mapper module.""" """Mapper module."""
from ...base import ONNXToMindSporeMapper
from ...gen_setting import Setting
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper
from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting




class SigmoidMapper(ONNXToMindSporeMapper): class SigmoidMapper(ONNXToMindSporeMapper):


+ 3
- 3
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/softmax_mapper.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -13,8 +13,8 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Mapper module.""" """Mapper module."""
from ...base import ONNXToMindSporeMapper
from ...gen_setting import Setting
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper
from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting




class SoftmaxMapper(ONNXToMindSporeMapper): class SoftmaxMapper(ONNXToMindSporeMapper):


+ 32
- 3
mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/add_mapper.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -13,8 +13,10 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Mapper module.""" """Mapper module."""
from ...base import ONNXToMindSporeMapper
from ...gen_setting import Setting, Tensor, get_dtype
from mindinsight.mindconverter.graph_based_converter.common.utils import reset_init_or_construct
from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper
from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting, Tensor, get_dtype




class AddMapper(ONNXToMindSporeMapper): class AddMapper(ONNXToMindSporeMapper):
@@ -43,3 +45,30 @@ class AddMapper(ONNXToMindSporeMapper):
ref = t_name ref = t_name
return Setting(op_extra_tensor=Tensor(shape=tensor.shape, return Setting(op_extra_tensor=Tensor(shape=tensor.shape,
dtype=get_dtype(tensor), reference=ref)) dtype=get_dtype(tensor), reference=ref))

@staticmethod
def _generate_snippet_template(**kwargs):
template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template(
**kwargs)
op = kwargs.get("operation")
args = kwargs.get("converted_params")
weights = kwargs.get("weights")
if not weights:
return template, exchange_msg, outputs_list, outputs_mapping

bias = list(weights.items())[0]
_, tensor = bias

variable_slot = "var_0"
init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})"
init_tensor = f"self.{{{variable_slot}}}_bias = " \
f"Tensor(np.random.uniform(0, 1, {tensor.shape}).astype(np.{tensor.dtype}))"
construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \
f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}," \
f"self.{{{variable_slot}}}_bias)"
template = reset_init_or_construct(template, variable_slot, [init_template, init_tensor],
TemplateKeywords.INIT.value)
template = reset_init_or_construct(template, variable_slot, [construct_template],
TemplateKeywords.CONSTRUCT.value)

return template, exchange_msg, outputs_list, outputs_mapping

+ 16
- 3
mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/concat_mapper.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -14,8 +14,10 @@
# ============================================================================== # ==============================================================================
"""Mapper module.""" """Mapper module."""
from mindinsight.mindconverter.graph_based_converter.constant import InputType from mindinsight.mindconverter.graph_based_converter.constant import InputType
from ...base import ONNXToMindSporeMapper
from ...gen_setting import Setting
from mindinsight.mindconverter.graph_based_converter.common.utils import reset_init_or_construct
from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper
from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting




class ConcatMapper(ONNXToMindSporeMapper): class ConcatMapper(ONNXToMindSporeMapper):
@@ -38,3 +40,14 @@ class ConcatMapper(ONNXToMindSporeMapper):
def _convert_settings(**kwargs): def _convert_settings(**kwargs):
input_type = InputType.LIST.value input_type = InputType.LIST.value
return Setting(op_ipt_type=input_type) return Setting(op_ipt_type=input_type)

@staticmethod
def _generate_snippet_template(**kwargs):
template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template(
**kwargs)
variable_slot = "var_0"
construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \
f"(({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}},))"
template = reset_init_or_construct(template, variable_slot, [construct_template],
TemplateKeywords.CONSTRUCT.value)
return template, exchange_msg, outputs_list, outputs_mapping

+ 31
- 3
mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/mul_mapper.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -13,8 +13,10 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Mapper module.""" """Mapper module."""
from ...base import ONNXToMindSporeMapper
from ...gen_setting import Setting, Tensor, get_dtype
from mindinsight.mindconverter.graph_based_converter.common.utils import reset_init_or_construct
from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper
from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting, Tensor, get_dtype




class MulMapper(ONNXToMindSporeMapper): class MulMapper(ONNXToMindSporeMapper):
@@ -40,3 +42,29 @@ class MulMapper(ONNXToMindSporeMapper):
ref, tensor = list(weights.items())[0] ref, tensor = list(weights.items())[0]
return Setting(op_extra_tensor=Tensor(shape=tensor.shape, return Setting(op_extra_tensor=Tensor(shape=tensor.shape,
dtype=get_dtype(tensor), reference=ref)) dtype=get_dtype(tensor), reference=ref))

@staticmethod
def _generate_snippet_template(**kwargs):
template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template(
**kwargs)
op = kwargs.get("operation")
args = kwargs.get("converted_params")
weights = kwargs.get("weights")
if not weights:
return template, exchange_msg, outputs_list, outputs_mapping

weight = list(weights.items())[0]
_, tensor = weight

variable_slot = "var_0"
init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})"
init_tensor = f"self.{{{variable_slot}}}_w = Tensor(np.random.uniform(0, 1, {tensor.shape})" \
f".astype(np.{tensor.dtype}))"
construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \
f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}," \
f"self.{{{variable_slot}}}_w)"
template = reset_init_or_construct(template, variable_slot, [init_template, init_tensor],
TemplateKeywords.INIT.value)
template = reset_init_or_construct(template, variable_slot, [construct_template],
TemplateKeywords.CONSTRUCT.value)
return template, exchange_msg, outputs_list, outputs_mapping

+ 22
- 3
mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/reduce_mean_mapper.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -13,8 +13,10 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Mapper module.""" """Mapper module."""
from ...base import ONNXToMindSporeMapper
from ...gen_setting import Setting
from mindinsight.mindconverter.graph_based_converter.common.utils import reset_init_or_construct
from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper
from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting




class ReduceMeanMapper(ONNXToMindSporeMapper): class ReduceMeanMapper(ONNXToMindSporeMapper):
@@ -42,3 +44,20 @@ class ReduceMeanMapper(ONNXToMindSporeMapper):
else: else:
axis = tuple() axis = tuple()
return Setting(op_extra_input={'axis': axis}) return Setting(op_extra_input={'axis': axis})

@staticmethod
def _generate_snippet_template(**kwargs):
template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template(
**kwargs)
raw_params = kwargs.get("raw_params")
if raw_params.get('axes'):
axis = raw_params['axes'][0] if len(raw_params['axes']) == 1 else tuple(raw_params['axes'])
else:
axis = tuple()
variable_slot = "var_0"
construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \
f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}, {axis})"
template = reset_init_or_construct(template, variable_slot, [construct_template],
TemplateKeywords.CONSTRUCT.value)

return template, exchange_msg, outputs_list, outputs_mapping

+ 22
- 3
mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/reshape_mapper.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -13,8 +13,10 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Mapper module.""" """Mapper module."""
from ...base import ONNXToMindSporeMapper
from ...gen_setting import Setting
from mindinsight.mindconverter.graph_based_converter.common.utils import reset_init_or_construct
from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper
from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting




class ReshapeMapper(ONNXToMindSporeMapper): class ReshapeMapper(ONNXToMindSporeMapper):
@@ -52,3 +54,20 @@ class ReshapeMapper(ONNXToMindSporeMapper):
shape = [-1] shape = [-1]
shape += list(weights.values())[0][1:].tolist() shape += list(weights.values())[0][1:].tolist()
return Setting(op_extra_input={"shape": tuple(shape)}) return Setting(op_extra_input={"shape": tuple(shape)})

@staticmethod
def _generate_snippet_template(**kwargs):
template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template(
**kwargs)
weights = kwargs.get("weights")
if len(weights) > 1:
raise ValueError("For reshape, `weights` length should equal to 1.")
shape = [-1]
shape += list(weights.values())[0][1:].tolist()
variable_slot = "var_0"
construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \
f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}, {tuple(shape)})"
template = reset_init_or_construct(template, variable_slot, [construct_template],
TemplateKeywords.CONSTRUCT.value)

return template, exchange_msg, outputs_list, outputs_mapping

+ 7
- 7
mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/resize_mapper.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -13,9 +13,9 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Mapper module.""" """Mapper module."""
from ...base import ONNXToMindSporeMapper
from ...gen_setting import Setting
from ....common import utils
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper
from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting
from mindinsight.mindconverter.graph_based_converter.common.utils import convert_bytes_string_to_string




class ResizeMapper(ONNXToMindSporeMapper): class ResizeMapper(ONNXToMindSporeMapper):
@@ -26,11 +26,11 @@ class ResizeMapper(ONNXToMindSporeMapper):
params = kwargs.get("params") params = kwargs.get("params")
onnx_coordinate_transform = params.get("coordinate_transformation_mode") onnx_coordinate_transform = params.get("coordinate_transformation_mode")
if onnx_coordinate_transform is not None: if onnx_coordinate_transform is not None:
onnx_coordinate_transform = utils.convert_bytes_string_to_string(onnx_coordinate_transform)
onnx_coordinate_transform = convert_bytes_string_to_string(onnx_coordinate_transform)


interpolation_mode = params.get("mode") interpolation_mode = params.get("mode")
if interpolation_mode is not None: if interpolation_mode is not None:
interpolation_mode = utils.convert_bytes_string_to_string(interpolation_mode)
interpolation_mode = convert_bytes_string_to_string(interpolation_mode)


# Define which MindSpore Resize operator to be used # Define which MindSpore Resize operator to be used
if interpolation_mode == "linear": if interpolation_mode == "linear":
@@ -54,7 +54,7 @@ class ResizeMapper(ONNXToMindSporeMapper):


onnx_coordinate_transform = params.get("coordinate_transformation_mode") onnx_coordinate_transform = params.get("coordinate_transformation_mode")
if onnx_coordinate_transform is not None: if onnx_coordinate_transform is not None:
onnx_coordinate_transform = utils.convert_bytes_string_to_string(onnx_coordinate_transform)
onnx_coordinate_transform = convert_bytes_string_to_string(onnx_coordinate_transform)


if onnx_coordinate_transform == "align_corners" or "half_pixel" in onnx_coordinate_transform: if onnx_coordinate_transform == "align_corners" or "half_pixel" in onnx_coordinate_transform:
align_corners = True align_corners = True


+ 23
- 3
mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/slice_mapper.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -13,8 +13,10 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Mapper module.""" """Mapper module."""
from ...base import ONNXToMindSporeMapper
from ...gen_setting import Setting
from mindinsight.mindconverter.graph_based_converter.common.utils import reset_init_or_construct
from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper
from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting




class SliceMapper(ONNXToMindSporeMapper): class SliceMapper(ONNXToMindSporeMapper):
@@ -41,3 +43,21 @@ class SliceMapper(ONNXToMindSporeMapper):
starts = sorted(zip(weights[0].tolist(), weights[2].tolist()), key=lambda x: x[1], reverse=False) starts = sorted(zip(weights[0].tolist(), weights[2].tolist()), key=lambda x: x[1], reverse=False)
return Setting(op_extra_input={"begin": tuple([i[0] for i in starts]), return Setting(op_extra_input={"begin": tuple([i[0] for i in starts]),
"size": tuple(opt_shape)}) "size": tuple(opt_shape)})

@staticmethod
def _generate_snippet_template(**kwargs):
template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template(
**kwargs)
weights = list(kwargs.get("weights").values()) # start, end, axis
opt_shape = kwargs["raw_params"]["output_shape"]
if not weights:
raise ValueError("Cannot get required params from slice.")
starts = sorted(zip(weights[0].tolist(), weights[2].tolist()), key=lambda x: x[1], reverse=False)
variable_slot = "var_0"
construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \
f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}, " \
f"{tuple([i[0] for i in starts])}, {tuple(opt_shape)})"
template = reset_init_or_construct(template, variable_slot, [construct_template],
TemplateKeywords.CONSTRUCT.value)

return template, exchange_msg, outputs_list, outputs_mapping

+ 3
- 3
mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/split_mapper.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -13,8 +13,8 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Mapper module.""" """Mapper module."""
from ...base import ONNXToMindSporeMapper
from ...gen_setting import Setting
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper
from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting




class SplitMapper(ONNXToMindSporeMapper): class SplitMapper(ONNXToMindSporeMapper):


+ 19
- 3
mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/transpose_mapper.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -13,8 +13,10 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Mapper module.""" """Mapper module."""
from ...base import ONNXToMindSporeMapper
from ...gen_setting import Setting
from mindinsight.mindconverter.graph_based_converter.common.utils import reset_init_or_construct
from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper
from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting




class TransposeMapper(ONNXToMindSporeMapper): class TransposeMapper(ONNXToMindSporeMapper):
@@ -42,3 +44,17 @@ class TransposeMapper(ONNXToMindSporeMapper):
converted_params['input_perm'] = perm converted_params['input_perm'] = perm


return Setting(op_extra_input=converted_params) return Setting(op_extra_input=converted_params)

@staticmethod
def _generate_snippet_template(**kwargs):
template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template(
**kwargs)
raw_params = kwargs.get("raw_params")
perm = raw_params["perm"]
variable_slot = "var_0"
construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \
f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}, {tuple(perm)})"
template = reset_init_or_construct(template, variable_slot, [construct_template],
TemplateKeywords.CONSTRUCT.value)

return template, exchange_msg, outputs_list, outputs_mapping

+ 51
- 9
mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -17,11 +17,13 @@ import abc
from collections import OrderedDict from collections import OrderedDict
from copy import deepcopy from copy import deepcopy


from typing import List

from mindinsight.mindconverter.common.log import logger as log from mindinsight.mindconverter.common.log import logger as log
from ..common.code_fragment import CodeFragment
from ..constant import NodeType, InputType
from ..mapper.base import Mapper
from ...common.exceptions import NodeInputTypeNotSupportError
from mindinsight.mindconverter.graph_based_converter.common.code_fragment import CodeFragment
from mindinsight.mindconverter.graph_based_converter.constant import NodeType, InputType
from mindinsight.mindconverter.graph_based_converter.mapper.base import Mapper
from mindinsight.mindconverter.common.exceptions import NodeInputTypeNotSupportError




class GraphParser(metaclass=abc.ABCMeta): class GraphParser(metaclass=abc.ABCMeta):
@@ -97,7 +99,6 @@ class Graph(BaseGraph, abc.ABC):
self.model = model self.model = model
self._raw_input_nodes = kwargs.get("input_nodes") self._raw_input_nodes = kwargs.get("input_nodes")
self._raw_output_nodes = kwargs.get("output_nodes") self._raw_output_nodes = kwargs.get("output_nodes")
self.checkpoint = kwargs.get("checkpoint", None)
self._nodes_collection = OrderedDict() self._nodes_collection = OrderedDict()
self._nodes_record = dict() self._nodes_record = dict()
self._shape_dict = dict() self._shape_dict = dict()
@@ -107,6 +108,13 @@ class Graph(BaseGraph, abc.ABC):
self._input_shape = dict() self._input_shape = dict()
self._is_multi_opt_graph = False self._is_multi_opt_graph = False


@property
def user_provided_input_nodes(self) -> List[str]:
"""User provided input_nodes in CLI."""
if not isinstance(self._raw_input_nodes, list):
return [self._raw_input_nodes]
return self._raw_input_nodes

def get_input_shape(self, name): def get_input_shape(self, name):
""" """
Get node input shape. Get node input shape.
@@ -285,9 +293,9 @@ class GraphNode(abc.ABC):
self.successor_nodes = [] self.successor_nodes = []
# Control dependency. # Control dependency.
self._deleted_in_edge = 0 self._deleted_in_edge = 0
# Source node in pytorch.
self._src_node = str(node) if node else None
# Original operation name in pytorch.
# Source node in ONNX.
self._src_node = node if node else None
# Original operation name in ONNX.
self._op_name = None self._op_name = None
self._op_params = dict() self._op_params = dict()
self._scope_name = None self._scope_name = None
@@ -311,6 +319,40 @@ class GraphNode(abc.ABC):
# Is in multi output graph. # Is in multi output graph.
self._is_in_multi_opt_graph = False self._is_in_multi_opt_graph = False


@property
def ir_node_name(self):
"""Getter of ir node's name."""
return self._src_node.name

@property
def ir_node_operation(self):
"""Getter of ir node's operation."""
return self._src_node.op_type

@property
def ir_node_inputs(self):
"""Getter of ir node's inputs."""
return list(self._src_node.input_name_list)

@property
def ir_node_outputs(self):
"""Getter of ir node's outputs."""
return list(self._src_node.output_name_list)

@property
def ir_node_precursor(self):
"""Getter of ir node's precursor."""
return [
v.name for _, v in self._src_node.precursor_onnx_node_dict.items()
]

@property
def ir_node_successor(self):
"""Getter of ir node's successor."""
return [
v.name for _, v in self._src_node.successor_onnx_node_dict.items()
]

@property @property
def weight(self): def weight(self):
return self._weight return self._weight


+ 15
- 0
tests/ut/mindconverter/graph_based_converter/common/__init__.py View File

@@ -0,0 +1,15 @@
# Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unit test for mindconverter.graph_based_converter.common interface."""

+ 145
- 0
tests/ut/mindconverter/graph_based_converter/common/test_fragment.py View File

@@ -0,0 +1,145 @@
# Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test fragment."""
from unittest import TestCase
from mindinsight.mindconverter.graph_based_converter.common.code_fragment import NewFragment as Fragment


class TestFragment(TestCase):
"""Tester of fragment."""

def test_matmul(self):
"""Test matmul like operation's template."""
template = {
'var_0': {
'init': [
'self.{var_0} = nn.MatMul()',
'self.{var_0}_w = Tensor(np.random.rand(*(2048, 1000)).astype(np.float32))'
],
'construct': ['opt_{var_0} = self.{var_0}({inputs},self.{var_0}_w)']
}
}
rewrite_data = {
'var_0': {
'operation': 'nn.MatMul',
'output_type': 'tensor',
'variable_name': "matmul", 'inputs': ["x"], 'args': {},
'weights': {},
'trainable_params': {}
},
'metadata': {
'source': 'probs/MatMul', 'operation': 'MatMul', 'scope': 'Model/MatMul',
'inputs': ['avg_pool/Mean:0', 'probs/MatMul/ReadVariableOp:0'],
'inputs_shape': (1, 2048), 'outputs': ['probs/MatMul:0'], 'outputs_shape': [1, 1000],
'precursor_nodes': ['avg_pool/Mean'], 'successor_nodes': ['probs/BiasAdd'],
'attributes': {}
}
}
fragment = Fragment(data_entity=rewrite_data, code_template=template, outputs=["opt_{var_0}"],
outputs_mapping=((0, 0),))
code = fragment()
init = code[0]
construct = code[1]
self.assertEqual(init, ['self.matmul = nn.MatMul()',
'self.matmul_w = Tensor(np.random.rand(*(2048, 1000)).astype(np.float32))'])
self.assertEqual(construct, ['opt_matmul = self.matmul(x,self.matmul_w)'])
self.assertEqual(fragment.get_outputs_by_idx(0), "opt_matmul")

def test_biasadd(self):
"""Test biasadd like operation's template."""
template = {
'var_0': {
'init': [
'self.{var_0} = P.TensorAdd()',
'self.{var_0}_bias = Tensor(np.random.rand(*(1000,)).astype(np.float32))'
],
'construct': ['opt_{var_0} = self.{var_0}({inputs},self.{var_0}_bias)']
}
}
rewrite_data = {
'var_0': {
'operation': 'P.TensorAdd',
'output_type': 'tensor',
'variable_name': "add", 'inputs': ["x"], 'args': {}, 'weights': {},
'trainable_params': {}
},
'metadata': {
'source': 'probs/BiasAdd', 'operation': 'Add', 'scope': 'Model/Add',
'inputs': ['probs/MatMul:0', 'probs/BiasAdd/ReadVariableOp:0'], 'inputs_shape': (1, 1000),
'outputs': ['probs/BiasAdd:0'], 'outputs_shape': [1, 1000], 'precursor_nodes': ['probs/MatMul'],
'successor_nodes': ['probs/Softmax'], 'attributes': {}
}
}
fragment = Fragment(data_entity=rewrite_data, code_template=template, outputs=["opt_{var_0}"],
outputs_mapping=((0, 0),))
code = fragment()
init = code[0]
construct = code[1]
self.assertEqual(init, ['self.add = P.TensorAdd()',
'self.add_bias = Tensor(np.random.rand(*(1000,)).astype(np.float32))'])
self.assertEqual(construct, ['opt_add = self.add(x,self.add_bias)'])
self.assertEqual(fragment.get_outputs_by_idx(0), "opt_add")

def test_transpose(self):
"""Test transpose like operation's template."""
template = {
'var_0': {
'init': ['self.{var_0} = P.Transpose()'],
'construct': ['opt_{var_0} = self.{var_0}({inputs}, (0, 3, 1, 2))']
}
}
rewrite_data = {
'var_0': {
'operation': 'P.Transpose',
'output_type': 'tensor',
'variable_name': "transpose", 'inputs': ["x"], 'args': {}, 'weights': {},
'trainable_params': {}
}
}
fragment = Fragment(data_entity=rewrite_data, code_template=template, outputs=["opt_{var_0}"],
outputs_mapping=((0, 0),))
code = fragment()
init = code[0]
construct = code[1]
self.assertEqual(init, ['self.transpose = P.Transpose()'])
self.assertEqual(construct, ['opt_transpose = self.transpose(x, (0, 3, 1, 2))'])
self.assertEqual(fragment.get_outputs_by_idx(0), "opt_transpose")

def test_split(self):
"""Test split like operation's template."""
template = {
'var_0': {
'init': ['self.{var_0} = P.Split(axis={axis}, output_num={output_num})'],
'construct': ['opt_{var_0} = self.{var_0}({inputs})']
}
}
rewrite_data = {
'var_0': {
'operation': 'P.Split',
'variable_name': "split",
'output_type': 'array',
'inputs': ["x"],
'args': {"axis": 1, "output_num": 2}, 'weights': {},
'trainable_params': {}
}
}
fragment = Fragment(data_entity=rewrite_data, code_template=template, outputs=["opt_{var_0}"],
outputs_mapping=((0, 0),))
code = fragment()
init = code[0]
construct = code[1]
self.assertEqual(init, ['self.split = P.Split(axis=1, output_num=2)'])
self.assertEqual(construct, ['opt_split = self.split(x)'])
self.assertEqual(fragment.get_outputs_by_idx(0, 1), 'opt_split[1]')

Loading…
Cancel
Save