Browse Source

add missing mapper and minor logic optimization

tags/v1.2.0-rc1
liangtianshu 4 years ago
parent
commit
803b9689f1
8 changed files with 185 additions and 10 deletions
  1. +2
    -3
      mindinsight/mindconverter/graph_based_converter/generator/node_struct.py
  2. +4
    -1
      mindinsight/mindconverter/graph_based_converter/generator/shared_weights.py
  3. +50
    -0
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/one_hot_mapper.py
  4. +32
    -0
      mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/neg_mapper.py
  5. +32
    -0
      mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/reciprocal_mapper.py
  6. +29
    -6
      mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/reduce_mean_mapper.py
  7. +32
    -0
      mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/rsqrt_mapper.py
  8. +4
    -0
      mindinsight/mindconverter/graph_based_converter/mapper/onnx_to_ms.json

+ 2
- 3
mindinsight/mindconverter/graph_based_converter/generator/node_struct.py View File

@@ -349,9 +349,8 @@ class NodeStruct:
self.fragment.default_var["parameters"][trainable_param_postfix] = declare_statement
continue # not a shared weight, skip the rest

if onnx_name in self._global_context.repeated_weights_declaration.keys():
continue # already declared, skip
self._global_context.repeated_weights_declaration[onnx_name] = declare_statement
if onnx_name not in self._global_context.repeated_weights_declaration.keys():
self._global_context.repeated_weights_declaration[onnx_name] = declare_statement

# set template to mapper parameter rewritten.
shared_w_var_in_parent = self._get_shared_weight_var_names_from_parent(onnx_name=onnx_name)


+ 4
- 1
mindinsight/mindconverter/graph_based_converter/generator/shared_weights.py View File

@@ -13,11 +13,11 @@
# limitations under the License.
# ==============================================================================
"""Module rocessing for shared weights."""
from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords
from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext
from mindinsight.mindconverter.graph_based_converter.generator.node_struct import NodeStruct
from mindinsight.mindconverter.graph_based_converter.generator.module_struct import ModuleStruct


class SharedWeightHelper:
"""Helper function to process shared weights."""

@@ -61,6 +61,9 @@ class SharedWeightHelper:
share_weight_name (str): The onnx name of the shared weights.
pub_module_identifier (list): The identifier of the public module the shared weight in.
"""
if not node.fragment.default_var.get(ExchangeMessageKeywords.VariableScope.value.PARAMETERS_DECLARED.value):
# No weight shared operator, skip
return
parent_module = node.parent_module_struct
exit_flag = False
while True:


+ 50
- 0
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/one_hot_mapper.py View File

@@ -0,0 +1,50 @@
# Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Mapper module."""
import numpy as np

from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper


class OneHotMapper(ONNXToMindSporeMapper):
"""OneHot mapper."""

@staticmethod
def _operation_name_in_ms(*args, **kwargs):
return "nn.OneHot"

@staticmethod
def _convert_params(**kwargs):
params = kwargs.get('params')
converted_params = {}
if params.get('axis'):
converted_params['axis'] = params.get('axis')
if kwargs.get('weights'):
weights = kwargs.get('weights')
depth = weights[0]
val = weights[1]
if depth and isinstance(depth.value, np.ndarray):
ms_depth = depth.value[0]
converted_params['depth'] = ms_depth
if val and isinstance(val.value, np.ndarray):
ms_off_val = val.value[0]
ms_on_val = val.value[1]
converted_params['off_value'] = ms_off_val
converted_params['on_value'] = ms_on_val
return converted_params

@staticmethod
def _convert_trained_weights(**kwargs):
return dict()

+ 32
- 0
mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/neg_mapper.py View File

@@ -0,0 +1,32 @@
# Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Mapper module."""
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper


class NegMapper(ONNXToMindSporeMapper):
"""Neg mapper."""

@staticmethod
def _operation_name_in_ms(*args, **kwargs):
return "P.Neg"

@staticmethod
def _convert_params(**kwargs):
return dict()

@staticmethod
def _convert_trained_weights(**kwargs):
return dict()

+ 32
- 0
mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/reciprocal_mapper.py View File

@@ -0,0 +1,32 @@
# Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Mapper module."""
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper


class ReciprocalMapper(ONNXToMindSporeMapper):
"""Reciprocal mapper."""

@staticmethod
def _operation_name_in_ms(*args, **kwargs):
return "P.Reciprocal"

@staticmethod
def _convert_params(**kwargs):
return dict()

@staticmethod
def _convert_trained_weights(**kwargs):
return dict()

+ 29
- 6
mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/reduce_mean_mapper.py View File

@@ -13,7 +13,6 @@
# limitations under the License.
# ==============================================================================
"""Mapper module."""
from mindinsight.mindconverter.graph_based_converter.common.utils import reset_init_or_construct
from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper

@@ -37,17 +36,41 @@ class ReduceMeanMapper(ONNXToMindSporeMapper):

@staticmethod
def _generate_snippet_template(**kwargs):
template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template(
**kwargs)
op = kwargs.get("operation")
args = kwargs.get("converted_params")
raw_params = kwargs.get("raw_params")
if raw_params.get('axes'):
axis = raw_params['axes'][0] if len(raw_params['axes']) == 1 else tuple(raw_params['axes'])
else:
axis = tuple()
variable_slot = "var_0"
init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})"

args["axis"] = axis
init_tensor = f"self.{{{variable_slot}}}_axis = {{axis}}"

construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \
f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}, {axis})"
template = reset_init_or_construct(template, variable_slot, [construct_template],
TemplateKeywords.CONSTRUCT.value)
f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}, " \
f"self.{{{variable_slot}}}_axis)"

template = {
variable_slot: {
TemplateKeywords.INIT.value: [init_template, init_tensor],
TemplateKeywords.CONSTRUCT.value: [construct_template]
}
}
exchange_msg = {
variable_slot: {
ExchangeMessageKeywords.VariableScope.value.OPERATION.value: op,
ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value: None,
ExchangeMessageKeywords.VariableScope.value.OUTPUT_TYPE.value:
ExchangeMessageKeywords.VariableScope.value.TSR_TYPE.value,
ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [],
ExchangeMessageKeywords.VariableScope.value.ARGS.value: args,
ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: [],
ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: {}
}
}
outputs_list = [f"opt_{{{variable_slot}}}"]
outputs_mapping = ((0, 0),)
return template, exchange_msg, outputs_list, outputs_mapping

+ 32
- 0
mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/rsqrt_mapper.py View File

@@ -0,0 +1,32 @@
# Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Mapper module."""
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper


class RsqrtMapper(ONNXToMindSporeMapper):
"""Rsqart mapper."""

@staticmethod
def _operation_name_in_ms(*args, **kwargs):
return "P.Rsqrt"

@staticmethod
def _convert_params(**kwargs):
return dict()

@staticmethod
def _convert_trained_weights(**kwargs):
return dict()

+ 4
- 0
mindinsight/mindconverter/graph_based_converter/mapper/onnx_to_ms.json View File

@@ -15,6 +15,10 @@
"onnx::Transpose": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.transpose_mapper.TransposeMapper",
"onnx::MatMul": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.mat_mul_mapper.MatMulMapper",
"onnx::Softmax": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.softmax_mapper.SoftmaxMapper",
"onnx::OneHot": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.one_hot_mapper.OneHotMapper",
"onnx::Neg": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.neg_mapper.NegMapper",
"onnx::Reciprocal": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.reciprocal_mapper.ReciprocalMapper",
"onnx::Rsqrt": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.rsqrt_mapper.RsqrtMapper",
"onnx::Reshape": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.reshape_mapper.ReshapeMapper",
"onnx::Slice": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.slice_mapper.SliceMapper",
"onnx::Mul": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.mul_mapper.MulMapper",


Loading…
Cancel
Save