| @@ -349,9 +349,8 @@ class NodeStruct: | |||
| self.fragment.default_var["parameters"][trainable_param_postfix] = declare_statement | |||
| continue # not a shared weight, skip the rest | |||
| if onnx_name in self._global_context.repeated_weights_declaration.keys(): | |||
| continue # already declared, skip | |||
| self._global_context.repeated_weights_declaration[onnx_name] = declare_statement | |||
| if onnx_name not in self._global_context.repeated_weights_declaration.keys(): | |||
| self._global_context.repeated_weights_declaration[onnx_name] = declare_statement | |||
| # set template to mapper parameter rewritten. | |||
| shared_w_var_in_parent = self._get_shared_weight_var_names_from_parent(onnx_name=onnx_name) | |||
| @@ -13,11 +13,11 @@ | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Module rocessing for shared weights.""" | |||
| from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords | |||
| from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext | |||
| from mindinsight.mindconverter.graph_based_converter.generator.node_struct import NodeStruct | |||
| from mindinsight.mindconverter.graph_based_converter.generator.module_struct import ModuleStruct | |||
| class SharedWeightHelper: | |||
| """Helper function to process shared weights.""" | |||
| @@ -61,6 +61,9 @@ class SharedWeightHelper: | |||
| share_weight_name (str): The onnx name of the shared weights. | |||
| pub_module_identifier (list): The identifier of the public module the shared weight in. | |||
| """ | |||
| if not node.fragment.default_var.get(ExchangeMessageKeywords.VariableScope.value.PARAMETERS_DECLARED.value): | |||
| # No weight shared operator, skip | |||
| return | |||
| parent_module = node.parent_module_struct | |||
| exit_flag = False | |||
| while True: | |||
| @@ -0,0 +1,50 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| import numpy as np | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| class OneHotMapper(ONNXToMindSporeMapper): | |||
| """OneHot mapper.""" | |||
| @staticmethod | |||
| def _operation_name_in_ms(*args, **kwargs): | |||
| return "nn.OneHot" | |||
| @staticmethod | |||
| def _convert_params(**kwargs): | |||
| params = kwargs.get('params') | |||
| converted_params = {} | |||
| if params.get('axis'): | |||
| converted_params['axis'] = params.get('axis') | |||
| if kwargs.get('weights'): | |||
| weights = kwargs.get('weights') | |||
| depth = weights[0] | |||
| val = weights[1] | |||
| if depth and isinstance(depth.value, np.ndarray): | |||
| ms_depth = depth.value[0] | |||
| converted_params['depth'] = ms_depth | |||
| if val and isinstance(val.value, np.ndarray): | |||
| ms_off_val = val.value[0] | |||
| ms_on_val = val.value[1] | |||
| converted_params['off_value'] = ms_off_val | |||
| converted_params['on_value'] = ms_on_val | |||
| return converted_params | |||
| @staticmethod | |||
| def _convert_trained_weights(**kwargs): | |||
| return dict() | |||
| @@ -0,0 +1,32 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| class NegMapper(ONNXToMindSporeMapper): | |||
| """Neg mapper.""" | |||
| @staticmethod | |||
| def _operation_name_in_ms(*args, **kwargs): | |||
| return "P.Neg" | |||
| @staticmethod | |||
| def _convert_params(**kwargs): | |||
| return dict() | |||
| @staticmethod | |||
| def _convert_trained_weights(**kwargs): | |||
| return dict() | |||
| @@ -0,0 +1,32 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| class ReciprocalMapper(ONNXToMindSporeMapper): | |||
| """Reciprocal mapper.""" | |||
| @staticmethod | |||
| def _operation_name_in_ms(*args, **kwargs): | |||
| return "P.Reciprocal" | |||
| @staticmethod | |||
| def _convert_params(**kwargs): | |||
| return dict() | |||
| @staticmethod | |||
| def _convert_trained_weights(**kwargs): | |||
| return dict() | |||
| @@ -13,7 +13,6 @@ | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from mindinsight.mindconverter.graph_based_converter.common.utils import reset_init_or_construct | |||
| from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| @@ -37,17 +36,41 @@ class ReduceMeanMapper(ONNXToMindSporeMapper): | |||
| @staticmethod | |||
| def _generate_snippet_template(**kwargs): | |||
| template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template( | |||
| **kwargs) | |||
| op = kwargs.get("operation") | |||
| args = kwargs.get("converted_params") | |||
| raw_params = kwargs.get("raw_params") | |||
| if raw_params.get('axes'): | |||
| axis = raw_params['axes'][0] if len(raw_params['axes']) == 1 else tuple(raw_params['axes']) | |||
| else: | |||
| axis = tuple() | |||
| variable_slot = "var_0" | |||
| init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})" | |||
| args["axis"] = axis | |||
| init_tensor = f"self.{{{variable_slot}}}_axis = {{axis}}" | |||
| construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ | |||
| f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}, {axis})" | |||
| template = reset_init_or_construct(template, variable_slot, [construct_template], | |||
| TemplateKeywords.CONSTRUCT.value) | |||
| f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}, " \ | |||
| f"self.{{{variable_slot}}}_axis)" | |||
| template = { | |||
| variable_slot: { | |||
| TemplateKeywords.INIT.value: [init_template, init_tensor], | |||
| TemplateKeywords.CONSTRUCT.value: [construct_template] | |||
| } | |||
| } | |||
| exchange_msg = { | |||
| variable_slot: { | |||
| ExchangeMessageKeywords.VariableScope.value.OPERATION.value: op, | |||
| ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value: None, | |||
| ExchangeMessageKeywords.VariableScope.value.OUTPUT_TYPE.value: | |||
| ExchangeMessageKeywords.VariableScope.value.TSR_TYPE.value, | |||
| ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [], | |||
| ExchangeMessageKeywords.VariableScope.value.ARGS.value: args, | |||
| ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: [], | |||
| ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: {} | |||
| } | |||
| } | |||
| outputs_list = [f"opt_{{{variable_slot}}}"] | |||
| outputs_mapping = ((0, 0),) | |||
| return template, exchange_msg, outputs_list, outputs_mapping | |||
| @@ -0,0 +1,32 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| class RsqrtMapper(ONNXToMindSporeMapper): | |||
| """Rsqart mapper.""" | |||
| @staticmethod | |||
| def _operation_name_in_ms(*args, **kwargs): | |||
| return "P.Rsqrt" | |||
| @staticmethod | |||
| def _convert_params(**kwargs): | |||
| return dict() | |||
| @staticmethod | |||
| def _convert_trained_weights(**kwargs): | |||
| return dict() | |||
| @@ -15,6 +15,10 @@ | |||
| "onnx::Transpose": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.transpose_mapper.TransposeMapper", | |||
| "onnx::MatMul": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.mat_mul_mapper.MatMulMapper", | |||
| "onnx::Softmax": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.softmax_mapper.SoftmaxMapper", | |||
| "onnx::OneHot": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.one_hot_mapper.OneHotMapper", | |||
| "onnx::Neg": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.neg_mapper.NegMapper", | |||
| "onnx::Reciprocal": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.reciprocal_mapper.ReciprocalMapper", | |||
| "onnx::Rsqrt": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.rsqrt_mapper.RsqrtMapper", | |||
| "onnx::Reshape": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.reshape_mapper.ReshapeMapper", | |||
| "onnx::Slice": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.slice_mapper.SliceMapper", | |||
| "onnx::Mul": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.mul_mapper.MulMapper", | |||