| @@ -22,8 +22,6 @@ from typing import List, Tuple, Mapping | |||
| import numpy as np | |||
| from mindspore.train.serialization import save_checkpoint | |||
| from mindinsight.mindconverter.common.exceptions import ScriptGenerationError, ReportGenerationError, \ | |||
| UnknownModelError, CheckPointGenerationError, WeightMapGenerationError | |||
| from mindinsight.mindconverter.common.log import logger as log | |||
| @@ -163,6 +161,7 @@ def save_code_file_and_report(model_name: str, code_lines: Mapping[str, Tuple], | |||
| except (IOError, FileExistsError) as error: | |||
| raise ReportGenerationError(str(error)) | |||
| save_checkpoint = getattr(import_module("mindspore.train.serialization"), "save_checkpoint") | |||
| ckpt_file_path = os.path.realpath(os.path.join(out_folder, f"{model_name}.ckpt")) | |||
| try: | |||
| if os.path.exists(ckpt_file_path): | |||
| @@ -15,6 +15,8 @@ | |||
| """Constant definition.""" | |||
| from enum import Enum, unique | |||
| import numpy as np | |||
| SEPARATOR_IN_ONNX_OP = "::" | |||
| SEPARATOR_IN_SCOPE = "/" | |||
| SEPARATOR_BTW_NAME_AND_ID = "_" | |||
| @@ -47,6 +49,25 @@ ONNXOPTIMIZER_MAX_VER = "0.1.2" | |||
| TORCH_MIN_VER = "1.5.0" | |||
| DTYPE_MAP = { | |||
| 1: np.float32, | |||
| 2: np.uint8, | |||
| 3: np.int8, | |||
| 4: np.uint16, | |||
| 5: np.int16, | |||
| 6: np.int32, | |||
| 7: np.int64, | |||
| 8: str, | |||
| 9: bool, | |||
| 10: np.float16, | |||
| 11: np.double, | |||
| 12: np.uint32, | |||
| 13: np.uint64, | |||
| 14: np.complex64, | |||
| 15: np.complex128, | |||
| 16: None | |||
| } | |||
| @unique | |||
| class TemplateKeywords(Enum): | |||
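The keys of `DTYPE_MAP` above follow ONNX's `TensorProto.DataType` enum (1 = FLOAT, …, 16 = BFLOAT16, which has no numpy equivalent, hence `None`). A minimal sketch of how such a table might be consumed; the helper name `onnx_dtype_to_numpy` is illustrative, not part of the converter:

```python
import numpy as np

# Excerpt of the table above; keys are ONNX TensorProto.DataType codes.
DTYPE_MAP = {1: np.float32, 7: np.int64, 9: bool, 10: np.float16, 16: None}

def onnx_dtype_to_numpy(elem_type):
    """Return the numpy dtype for an ONNX elem_type code, raising if unmapped."""
    dtype = DTYPE_MAP.get(elem_type)
    if dtype is None:
        raise ValueError(f"ONNX elem_type {elem_type} has no numpy equivalent.")
    return dtype

assert onnx_dtype_to_numpy(1) is np.float32
```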
| @@ -99,11 +99,13 @@ def torch_installation_validation(func): | |||
| f"are required when using graph based scripts converter, and PyTorch version must " \ | |||
| f"be consisted with model generation runtime." | |||
| output_queue = mp.Queue() | |||
| process = mp.Process(target=torch_version_satisfied, args=(output_queue,)) | |||
| process.start() | |||
| torch_version_validation = output_queue.get() | |||
| process.join() | |||
| if not error_info: | |||
| output_queue = mp.Queue() | |||
| process = mp.Process(target=torch_version_satisfied, args=(output_queue,)) | |||
| process.start() | |||
| torch_version_validation = output_queue.get() | |||
| process.join() | |||
| if error_info: | |||
| _print_error(RuntimeIntegrityError(error_info)) | |||
| sys.exit(0) | |||
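The hunk above moves the torch probe behind `if not error_info`; the probe itself runs in a child process so a heavy or broken torch import cannot pollute the parent interpreter. A minimal sketch of that pattern, assuming a simplified `torch_version_satisfied` (the real check also validates the installed version):

```python
import multiprocessing as mp

def torch_version_satisfied(output_queue):
    """Probe the torch install inside a child process, reporting via a queue."""
    try:
        import torch  # noqa: F401  # heavy, optional import isolated here
        output_queue.put(True)
    except ImportError:
        output_queue.put(False)

if __name__ == "__main__":
    queue = mp.Queue()
    process = mp.Process(target=torch_version_satisfied, args=(queue,))
    process.start()
    satisfied = queue.get()  # read before join so a full pipe cannot deadlock
    process.join()
    print(satisfied)
```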
| @@ -268,6 +270,9 @@ def main_graph_base_converter(file_config): | |||
| if graph_path.endswith("pth") and not file_config.get("input_nodes", []) and \ | |||
| file_config.get("shape") and len(file_config.get("shape", ())) == 1: | |||
| file_config['input_nodes'] = ["input.1"] | |||
| else: | |||
| check_params = ['input_nodes', 'output_nodes'] | |||
| check_params_exist(check_params, file_config) | |||
| if len(file_config['shape']) != len(file_config.get("input_nodes", [])) != len( | |||
| set(file_config.get("input_nodes", []))): | |||
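One subtlety in the check above: Python's chained `a != b != c` expands to `a != b and b != c`, so it never compares the first and last operands directly. A quick sketch of those semantics:

```python
# Chained comparisons evaluate pairwise, left to right, and short-circuit.
assert (1 != 2 != 1) is True    # 1 != 2 and 2 != 1; outer operands are equal
assert (1 != 1 != 2) is False   # short-circuits on 1 != 1
```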
| @@ -280,8 +285,6 @@ def main_graph_base_converter(file_config): | |||
| if frame_type == FrameworkType.PYTORCH.value: | |||
| if graph_path.endswith('.onnx'): | |||
| check_params = ['input_nodes', 'output_nodes'] | |||
| check_params_exist(check_params, file_config) | |||
| graph_based_converter_pytorch_to_ms(graph_path=graph_path, | |||
| input_nodes=input_nodes, | |||
| output_nodes=file_config['output_nodes'], | |||
| @@ -294,8 +297,6 @@ def main_graph_base_converter(file_config): | |||
| output_folder=file_config['outfile_dir'], | |||
| report_folder=file_config['report_dir']) | |||
| elif frame_type == FrameworkType.TENSORFLOW.value: | |||
| check_params = ['input_nodes', 'output_nodes'] | |||
| check_params_exist(check_params, file_config) | |||
| graph_based_converter_tf_to_ms(graph_path=graph_path, | |||
| input_nodes=input_nodes, | |||
| output_nodes=file_config['output_nodes'], | |||
| @@ -192,3 +192,17 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC): | |||
| result = weight.value | |||
| break | |||
| return result | |||
| @staticmethod | |||
| def _find_location_by_index(loc_index, weights_list): | |||
| """Find weight location in inputs of Node.""" | |||
| result = -1 | |||
| if loc_index < 0: | |||
| return weights_list[loc_index].location | |||
| for idx, weight in enumerate(weights_list): | |||
| if idx == loc_index: | |||
| result = weight.location | |||
| break | |||
| return result | |||
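A self-contained sketch of the `_find_location_by_index` semantics above, using a namedtuple as a stand-in for the converter's `NodeWeight` class:

```python
from collections import namedtuple

# Stand-in for the converter's NodeWeight (name, value, location).
NodeWeight = namedtuple("NodeWeight", ["name", "value", "location"])

def find_location_by_index(loc_index, weights_list):
    """Mirror of the method above: -1 signals 'no weight at that index'."""
    if loc_index < 0:
        return weights_list[loc_index].location  # negative: index from the end
    for idx, weight in enumerate(weights_list):
        if idx == loc_index:
            return weight.location
    return -1

weights = [NodeWeight("w", None, 1), NodeWeight("b", None, 2)]
assert find_location_by_index(0, weights) == 1
assert find_location_by_index(-1, weights) == 2
assert find_location_by_index(5, weights) == -1
```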
| @@ -63,8 +63,6 @@ class PoolMapper(ONNXToMindSporeMapper): | |||
| @staticmethod | |||
| def _generate_snippet_template(**kwargs): | |||
| template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template( | |||
| **kwargs) | |||
| op = kwargs.get("operation") | |||
| args = kwargs.get("converted_params", dict()) | |||
| @@ -74,10 +72,9 @@ class PoolMapper(ONNXToMindSporeMapper): | |||
| kernel_shape = kwargs['raw_params']['kernel_shape'] | |||
| dilations = kwargs['raw_params'].get('dilations', (1, 1)) | |||
| strides = kwargs['raw_params']['strides'] | |||
| onnx_opt_shape = tensor_opt_shape[-len(ms_opt_shape):] | |||
| if np.all(np.array(ms_opt_shape) == np.array(onnx_opt_shape)): | |||
| return template, exchange_msg, outputs_list, outputs_mapping | |||
| if not op: | |||
| raise ValueError("Can not get MindSpore operation name.") | |||
| variable_slot = "var_0" | |||
| init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})" | |||
| @@ -109,7 +106,8 @@ class PoolMapper(ONNXToMindSporeMapper): | |||
| ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: dict() | |||
| } | |||
| } | |||
| outputs_list = [f"opt_{{{variable_slot}}}"] | |||
| outputs_mapping = ((0, 0),) | |||
| return template, exchange_msg, outputs_list, outputs_mapping | |||
| @staticmethod | |||
| @@ -122,9 +120,12 @@ class PoolMapper(ONNXToMindSporeMapper): | |||
| if np.any(np.array(ms_opt_shape) > np.array(onnx_opt_shape)): | |||
| raise ValueError(f"ms_opt_shape[{ms_opt_shape}] should be no larger than onnx_opt_shape[{onnx_opt_shape}].") | |||
| shape_diff = np.subtract((np.array(onnx_opt_shape) - 1) * np.array(strides), | |||
| np.subtract(np.array(onnx_ipt_shape), | |||
| (np.array(kernel_shape) - 1) * np.array(dilations) + 1)).tolist() | |||
| if np.all(np.array(ms_opt_shape) == np.array(onnx_opt_shape)): | |||
| shape_diff = np.zeros(len(ms_opt_shape)).astype(int).tolist() | |||
| else: | |||
| shape_diff = np.subtract((np.array(onnx_opt_shape) - 1) * np.array(strides), | |||
| np.subtract(np.array(onnx_ipt_shape), | |||
| (np.array(kernel_shape) - 1) * np.array(dilations) + 1)).tolist() | |||
| zero_pad_single = (0, 0) | |||
| paddings = [zero_pad_single] | |||
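A worked numpy example of the padding arithmetic above (shapes are assumed): `shape_diff` is the total extra padding each spatial dimension needs so the strided window lands exactly on the ONNX output size.

```python
import numpy as np

onnx_ipt_shape = (8, 8)    # spatial input H x W
onnx_opt_shape = (4, 4)    # spatial output reported by ONNX
kernel_shape = (3, 3)
strides = (2, 2)
dilations = (1, 1)

shape_diff = np.subtract(
    (np.array(onnx_opt_shape) - 1) * np.array(strides),
    np.subtract(np.array(onnx_ipt_shape),
                (np.array(kernel_shape) - 1) * np.array(dilations) + 1)).tolist()
print(shape_diff)  # [1, 1]: one extra padded row/column per dimension
```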
| @@ -0,0 +1,32 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| class TanhMapper(ONNXToMindSporeMapper): | |||
| """Tanh mapper.""" | |||
| @staticmethod | |||
| def _operation_name_in_ms(*args, **kwargs): | |||
| return "nn.Tanh" | |||
| @staticmethod | |||
| def _convert_params(**kwargs): | |||
| return dict() | |||
| @staticmethod | |||
| def _convert_trained_weights(**kwargs): | |||
| return dict() | |||
| @@ -13,6 +13,8 @@ | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| import numpy as np | |||
| from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords, \ | |||
| WeightType | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| @@ -23,7 +25,7 @@ class AddMapper(ONNXToMindSporeMapper): | |||
| @staticmethod | |||
| def _operation_name_in_ms(*args, **kwargs): | |||
| return "P.TensorAdd" | |||
| return "P.Add" | |||
| @staticmethod | |||
| def _convert_params(**kwargs): | |||
| @@ -31,9 +33,11 @@ class AddMapper(ONNXToMindSporeMapper): | |||
| @staticmethod | |||
| def _convert_trained_weights(**kwargs): | |||
| weights = kwargs['weights'] | |||
| bias = AddMapper._find_val_by_index(0, weights) | |||
| return {'bias': {'data': bias, 'type': WeightType.PARAMETER.value}} | |||
| weights = kwargs.get('weights', list()) | |||
| tensor = AddMapper._find_val_by_index(0, weights) | |||
| if isinstance(tensor, np.ndarray) and tensor.shape: | |||
| return {'bias': {'data': tensor, 'type': WeightType.PARAMETER.value}} | |||
| return dict() | |||
| @staticmethod | |||
| def _generate_snippet_template(**kwargs): | |||
| @@ -49,17 +53,28 @@ class AddMapper(ONNXToMindSporeMapper): | |||
| return template, exchange_msg, outputs_list, outputs_mapping | |||
| tensor = AddMapper._find_val_by_index(0, weights) | |||
| bias_shape = tensor.shape | |||
| bias_dtype = tensor.dtype | |||
| bias_location = AddMapper._find_location_by_index(0, weights) | |||
| variable_slot = "var_0" | |||
| init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})" | |||
| args["bias_shape"] = tensor.shape | |||
| args["bias_dtype"] = tensor.dtype | |||
| init_tensor = f"self.{{{variable_slot}}}_bias = " \ | |||
| f"Parameter(Tensor(np.random.uniform(0, 1, {{bias_shape}}).astype(np.{{bias_dtype}})), " \ | |||
| f"name=None)" | |||
| inputs_in_construct = [f"{{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}"] | |||
| if bias_location != -1: | |||
| inputs_in_construct.insert(bias_location, f"self.{{{variable_slot}}}_bias") | |||
| if bias_shape: | |||
| args["bias_shape"] = bias_shape | |||
| args["bias_dtype"] = bias_dtype | |||
| init_tensor = f"self.{{{variable_slot}}}_bias = " \ | |||
| f"Parameter(Tensor(np.random.uniform(0, 1, {{bias_shape}}).astype(np.{{bias_dtype}})), " \ | |||
| f"name=None)" | |||
| else: | |||
| args["bias_value"] = tensor.tolist() | |||
| init_tensor = f"self.{{{variable_slot}}}_bias = {{bias_value}}" | |||
| construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ | |||
| f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}," \ | |||
| f"self.{{{variable_slot}}}_bias)" | |||
| f"({', '.join(inputs_in_construct)})" | |||
| template = { | |||
| variable_slot: { | |||
| TemplateKeywords.INIT.value: [init_template, init_tensor], | |||
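The branch above distinguishes shaped tensors from 0-d scalars via the truthiness of `tensor.shape`; a small sketch of the two init forms it can emit (the `var_0` slot name is a placeholder, as in the template):

```python
import numpy as np

def bias_init_line(tensor):
    """Emit a Parameter for shaped tensors, an inline literal for 0-d scalars."""
    if tensor.shape:  # () is falsy, so scalars fall through to the literal form
        return ("self.var_0_bias = Parameter(Tensor(np.random.uniform("
                f"0, 1, {tensor.shape}).astype(np.{tensor.dtype})), name=None)")
    return f"self.var_0_bias = {tensor.tolist()}"

print(bias_init_line(np.zeros((3, 1))))  # Parameter(...) form
print(bias_init_line(np.array(0.5)))     # self.var_0_bias = 0.5
```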
| @@ -0,0 +1,88 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| class CastMapper(ONNXToMindSporeMapper): | |||
| """Cast mapper.""" | |||
| @staticmethod | |||
| def _operation_name_in_ms(*args, **kwargs): | |||
| return "P.Cast" | |||
| @staticmethod | |||
| def _convert_params(**kwargs): | |||
| return dict() | |||
| @staticmethod | |||
| def _convert_trained_weights(**kwargs): | |||
| return dict() | |||
| @staticmethod | |||
| def _generate_snippet_template(**kwargs): | |||
| op = kwargs.get("operation") | |||
| args = kwargs.get("converted_params", dict()) | |||
| params = kwargs['raw_params'] | |||
| weights = kwargs.get("weights") | |||
| to = params["to"] | |||
| type_dict = {1: 'mindspore.float32', | |||
| 2: 'mindspore.uint8', | |||
| 3: 'mindspore.int8', | |||
| 4: 'mindspore.uint16', | |||
| 5: 'mindspore.int16', | |||
| 6: 'mindspore.int32', | |||
| 7: 'mindspore.int64', | |||
| 8: 'mindspore.string', | |||
| 9: 'mindspore.bool_', | |||
| 10: 'mindspore.float16', | |||
| 11: 'mindspore.double', | |||
| 12: 'mindspore.uint32', | |||
| 13: 'mindspore.uint64', | |||
| 14: 'UNSUPPORTED', | |||
| 15: 'UNSUPPORTED', | |||
| 16: 'UNSUPPORTED'} | |||
| if not op: | |||
| raise ValueError("Can not get MindSpore operation name.") | |||
| args["to"] = type_dict[to] | |||
| variable_slot = "var_0" | |||
| init_template = f"self.{{{variable_slot}}} = {op}()" | |||
| init_to = f"self.{{{variable_slot}}}_to = {{to}}" | |||
| construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ | |||
| f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}, " \ | |||
| f"self.{{{variable_slot}}}_to)" | |||
| template = { | |||
| variable_slot: { | |||
| TemplateKeywords.INIT.value: [init_template, init_to], | |||
| TemplateKeywords.CONSTRUCT.value: [construct_template] | |||
| } | |||
| } | |||
| exchange_msg = { | |||
| variable_slot: { | |||
| ExchangeMessageKeywords.VariableScope.value.OPERATION.value: op, | |||
| ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value: None, | |||
| ExchangeMessageKeywords.VariableScope.value.OUTPUT_TYPE.value: | |||
| ExchangeMessageKeywords.VariableScope.value.TSR_TYPE.value, | |||
| ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [], | |||
| ExchangeMessageKeywords.VariableScope.value.ARGS.value: args, | |||
| ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights, | |||
| ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: dict() | |||
| } | |||
| } | |||
| outputs_list = [f"opt_{{{variable_slot}}}"] | |||
| outputs_mapping = ((0, 0),) | |||
| return template, exchange_msg, outputs_list, outputs_mapping | |||
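A sketch of how the Cast template's placeholders resolve; the converter's real rendering machinery is more involved, `str.format` here merely illustrates the substitution:

```python
type_dict = {1: 'mindspore.float32', 7: 'mindspore.int64'}  # excerpt

variable_slot = "var_0"
init_template = f"self.{{{variable_slot}}} = P.Cast()"
init_to = f"self.{{{variable_slot}}}_to = {{to}}"

print(init_template.format(var_0="var_0"))             # self.var_0 = P.Cast()
print(init_to.format(var_0="var_0", to=type_dict[7]))  # self.var_0_to = mindspore.int64
```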
| @@ -0,0 +1,95 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| import numpy as np | |||
| from mindinsight.mindconverter.graph_based_converter.common.utils import reset_init_or_construct | |||
| from mindinsight.mindconverter.graph_based_converter.constant import WeightType, ExchangeMessageKeywords, \ | |||
| TemplateKeywords | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| class DivMapper(ONNXToMindSporeMapper): | |||
| """Div mapper.""" | |||
| @staticmethod | |||
| def _operation_name_in_ms(*args, **kwargs): | |||
| return "P.Div" | |||
| @staticmethod | |||
| def _convert_params(**kwargs): | |||
| return dict() | |||
| @staticmethod | |||
| def _convert_trained_weights(**kwargs): | |||
| weights = kwargs.get('weights', list()) | |||
| tensor = DivMapper._find_val_by_index(0, weights) | |||
| if isinstance(tensor, np.ndarray) and tensor.shape: | |||
| return {'w': {'data': tensor, 'type': WeightType.PARAMETER.value}} | |||
| return dict() | |||
| @staticmethod | |||
| def _generate_snippet_template(**kwargs): | |||
| template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template( | |||
| **kwargs) | |||
| op = kwargs.get("operation") | |||
| args = kwargs.get("converted_params") | |||
| weights = kwargs.get("weights") | |||
| trainable_params = kwargs.get("trainable_params", dict()) | |||
| if not weights: | |||
| return template, exchange_msg, outputs_list, outputs_mapping | |||
| tensor = DivMapper._find_val_by_index(0, weights) | |||
| w_shape = tensor.shape | |||
| w_dtype = tensor.dtype | |||
| w_location = DivMapper._find_location_by_index(0, weights) | |||
| variable_slot = "var_0" | |||
| init_template = f"self.{{{variable_slot}}} = {op}()" | |||
| inputs_in_construct = [f"{{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}"] | |||
| if w_location != -1: | |||
| inputs_in_construct.insert(w_location, f"self.{{{variable_slot}}}_w") | |||
| if w_shape: | |||
| args["w_shape"] = w_shape | |||
| args["w_dtype"] = w_dtype | |||
| init_tensor = f"self.{{{variable_slot}}}_w = " \ | |||
| f"Parameter(Tensor(np.random.uniform(0, 1, {{w_shape}}).astype(np.{{w_dtype}})), " \ | |||
| f"name=None)" | |||
| else: | |||
| args["w_value"] = tensor.tolist() | |||
| init_tensor = f"self.{{{variable_slot}}}_w = {{w_value}}" | |||
| construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ | |||
| f"({', '.join(inputs_in_construct)})" | |||
| template = reset_init_or_construct(template, variable_slot, [init_template, init_tensor], | |||
| TemplateKeywords.INIT.value) | |||
| template = reset_init_or_construct(template, variable_slot, [construct_template], | |||
| TemplateKeywords.CONSTRUCT.value) | |||
| exchange_msg = { | |||
| variable_slot: { | |||
| ExchangeMessageKeywords.VariableScope.value.OPERATION.value: op, | |||
| ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value: None, | |||
| ExchangeMessageKeywords.VariableScope.value.OUTPUT_TYPE.value: | |||
| ExchangeMessageKeywords.VariableScope.value.TSR_TYPE.value, | |||
| ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [], | |||
| ExchangeMessageKeywords.VariableScope.value.ARGS.value: args, | |||
| ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights, | |||
| ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: trainable_params | |||
| } | |||
| } | |||
| outputs_list = [f"opt_{{{variable_slot}}}"] | |||
| outputs_mapping = ((0, 0),) | |||
| return template, exchange_msg, outputs_list, outputs_mapping | |||
| @@ -0,0 +1,132 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| import re | |||
| import numpy as np | |||
| from mindinsight.mindconverter.graph_based_converter.common.utils import convert_bytes_string_to_string | |||
| from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords, \ | |||
| WeightType, BLANK_SYM | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| class EinSumMapper(ONNXToMindSporeMapper): | |||
| """EinSum mapper.""" | |||
| @staticmethod | |||
| def _operation_name_in_ms(*args, **kwargs): | |||
| equation = kwargs['params']['equation'] | |||
| if not EinSumMapper.equation_check(equation): | |||
| return None | |||
| return "P.MatMul" | |||
| @staticmethod | |||
| def equation_check(equation): | |||
| """ | |||
| Check equation validity. | |||
| Only equations of the form `bfxxx, xxxh -> bfh` are supported. | |||
| Args: | |||
| equation (Union[str, bytes]): Equation of EinSum. | |||
| Returns: | |||
| bool, True if the equation matches the format `bfxxx, xxxh -> bfh`, otherwise False. | |||
| """ | |||
| equation = convert_bytes_string_to_string(equation) | |||
| equation = equation.replace(BLANK_SYM, '').split('->') | |||
| if len(equation) != 2: | |||
| return False | |||
| equation_left_list = equation[0].split(',') | |||
| if len(equation_left_list) != 2: | |||
| return False | |||
| equation_right = equation[1] | |||
| pattern = ''.join([s for s in equation_left_list[0] if s in equation_left_list[1]]) | |||
| output_first = re.sub(pattern, '', equation_left_list[0]) | |||
| output_second = re.sub(pattern, '', equation_left_list[1]) | |||
| output = ''.join((output_first, output_second)).replace(BLANK_SYM, '') | |||
| return output == equation_right | |||
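A standalone re-run of the check above (logic transcribed, plain-`str` input assumed) showing which equations the mapper accepts:

```python
import re

def equation_check(equation: str) -> bool:
    """True only for two-operand equations of the form `bfxxx, xxxh -> bfh`."""
    parts = equation.replace(' ', '').split('->')
    if len(parts) != 2:
        return False
    left = parts[0].split(',')
    if len(left) != 2:
        return False
    shared = ''.join(s for s in left[0] if s in left[1])
    output = re.sub(shared, '', left[0]) + re.sub(shared, '', left[1])
    return output == parts[1]

assert equation_check("bfnd,ndh->bfh")     # contraction over `nd`, matmul-like
assert not equation_check("ij,jk,kl->il")  # three operands are rejected
```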
| @staticmethod | |||
| def _convert_params(**kwargs): | |||
| return dict() | |||
| @staticmethod | |||
| def _convert_trained_weights(**kwargs): | |||
| weights = kwargs.get('weights', list()) | |||
| tensor = EinSumMapper._find_val_by_index(0, weights) | |||
| if isinstance(tensor, np.ndarray) and tensor.shape: | |||
| return {'weight': {'data': tensor, 'type': WeightType.PARAMETER.value}} | |||
| return dict() | |||
| @staticmethod | |||
| def _generate_snippet_template(**kwargs): | |||
| op = kwargs.get("operation") | |||
| args = kwargs.get("converted_params") | |||
| weights = kwargs.get("weights") | |||
| input_shape = kwargs["raw_params"]["input_shape"] | |||
| trainable_params = kwargs.get("trainable_params", dict()) | |||
| if not op: | |||
| raise ValueError("Can not get MindSpore operation name.") | |||
| variable_slot = "var_0" | |||
| init_template_list = [f"self.{{{variable_slot}}} = {op}()"] | |||
| default_shape = input_shape[:2] | |||
| inputs_in_construct = [ | |||
| f"{{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}" | |||
| f".view({default_shape[0] * default_shape[1]}, -1)"] | |||
| if weights: | |||
| tensor = EinSumMapper._find_val_by_index(0, weights) | |||
| args["weight_shape"] = tensor.shape | |||
| args["weight_dtype"] = tensor.dtype | |||
| weight_location = EinSumMapper._find_location_by_index(0, weights) | |||
| init_template_list.append( | |||
| f"self.{{{variable_slot}}}_weight = " | |||
| f"Parameter(Tensor(np.random.uniform(0, 1, {{weight_shape}}).astype(np.{{weight_dtype}})), name=None)") | |||
| default_shape += (tensor.shape[-1],) | |||
| inputs_in_construct.insert(weight_location, f"self.{{{variable_slot}}}_weight.view(-1, {tensor.shape[-1]})") | |||
| construct_template = f"opt_{{{variable_slot}}} = " \ | |||
| f"self.{{{variable_slot}}}({', '.join(inputs_in_construct)}).view{default_shape}" | |||
| template = { | |||
| variable_slot: { | |||
| TemplateKeywords.INIT.value: init_template_list, | |||
| TemplateKeywords.CONSTRUCT.value: [construct_template] | |||
| } | |||
| } | |||
| exchange_msg = { | |||
| variable_slot: { | |||
| ExchangeMessageKeywords.VariableScope.value.OPERATION.value: op, | |||
| ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value: None, | |||
| ExchangeMessageKeywords.VariableScope.value.OUTPUT_TYPE.value: | |||
| ExchangeMessageKeywords.VariableScope.value.TSR_TYPE.value, | |||
| ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [], | |||
| ExchangeMessageKeywords.VariableScope.value.ARGS.value: args, | |||
| ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights, | |||
| ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: trainable_params | |||
| } | |||
| } | |||
| outputs_list = [f"opt_{{{variable_slot}}}"] | |||
| outputs_mapping = ((0, 0),) | |||
| return template, exchange_msg, outputs_list, outputs_mapping | |||
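The view/MatMul/view chain the template emits is equivalent to the supported einsum form; a numpy check with assumed shapes:

```python
import numpy as np

b, f, n, d, h = 2, 3, 4, 5, 6
x = np.random.rand(b, f, n, d)   # einsum input
w = np.random.rand(n, d, h)      # einsum weight

expected = np.einsum('bfnd,ndh->bfh', x, w)
# Flatten batch dims on the input and contraction dims on the weight, as the
# generated .view(...) calls do, then undo the flattening on the result.
rewritten = (x.reshape(b * f, -1) @ w.reshape(-1, h)).reshape(b, f, h)
assert np.allclose(expected, rewritten)
```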
| @@ -0,0 +1,32 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| class ErfMapper(ONNXToMindSporeMapper): | |||
| """Erf mapper.""" | |||
| @staticmethod | |||
| def _operation_name_in_ms(*args, **kwargs): | |||
| return "P.Erf" | |||
| @staticmethod | |||
| def _convert_params(**kwargs): | |||
| return dict() | |||
| @staticmethod | |||
| def _convert_trained_weights(**kwargs): | |||
| return dict() | |||
| @@ -0,0 +1,107 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| class ExpandDimsMapper(ONNXToMindSporeMapper): | |||
| """Expand_dims mapper.""" | |||
| @staticmethod | |||
| def _operation_name_in_ms(*args, **kwargs): | |||
| return "P.ExpandDims" | |||
| @staticmethod | |||
| def _convert_params(**kwargs): | |||
| return dict() | |||
| @staticmethod | |||
| def _convert_trained_weights(**kwargs): | |||
| return dict() | |||
| @staticmethod | |||
| def _generate_snippet_template(**kwargs): | |||
| op = kwargs.get('operation') | |||
| args = kwargs.get('converted_params', dict()) | |||
| params = kwargs['raw_params'] # opset 11, axes is in attributes, and is a list. | |||
| weights = kwargs.get('weights') # opset 12, axes is in inputs and is a tensor. | |||
| if not op: | |||
| raise ValueError("Can not get MindSpore operation name.") | |||
| if weights: | |||
| axes = ExpandDimsMapper._find_val_by_index(0, weights).tolist() | |||
| else: | |||
| axes = params['axes'] | |||
| variable_slot = 'var_0' | |||
| init_template_list = [f"self.{{{variable_slot}}} = {op}()"] | |||
| construct_template_list = list() | |||
| # The `axes` attribute of the torch operator `Unsqueeze` is a list, | |||
| # while the `axis` argument of the MindSpore operator `ExpandDims` is an int. | |||
| # As a result, if `axes` has length 1, a single MindSpore operator can replace the torch operator; | |||
| # otherwise, a chain of MindSpore operators is required to replace it. | |||
| if len(axes) == 1: | |||
| args["axis"] = axes[0] | |||
| init_template_list.append(f"self.{{{variable_slot}}}_axis = {{axis}}") | |||
| construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ | |||
| f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}, " \ | |||
| f"self.{{{variable_slot}}}_axis)" | |||
| construct_template_list.append(construct_template) | |||
| else: | |||
| for idx, axis in enumerate(axes): | |||
| if not construct_template_list: | |||
| args[f"axis_{idx}"] = axis | |||
| init_template = f"self.{{{variable_slot}}}_{idx}_axis = {{axis_{idx}}}" | |||
| construct_template = f"opt_{{{variable_slot}}}_{idx} = self.{{{variable_slot}}}" \ | |||
| f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}, " \ | |||
| f"self.{{{variable_slot}}}_{idx}_axis)" | |||
| elif idx == len(axes) - 1: | |||
| args["axis"] = axis | |||
| init_template = f"self.{{{variable_slot}}}_axis = {{axis}}" | |||
| construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ | |||
| f"({{{variable_slot}}}_{idx - 1}, self.{{{variable_slot}}}_axis)" | |||
| else: | |||
| args[f"axis_{idx}"] = axis | |||
| init_template = f"self.{{{variable_slot}}}_{idx}_axis = {{axis_{idx}}}" | |||
| construct_template = f"opt_{{{variable_slot}}}_{idx} = self.{{{variable_slot}}}" \ | |||
| f"({{{variable_slot}}}_{idx - 1}, self.{{{variable_slot}}}_{idx}_axis)" | |||
| init_template_list.append(init_template) | |||
| construct_template_list.append(construct_template) | |||
| template = { | |||
| variable_slot: { | |||
| TemplateKeywords.INIT.value: init_template_list, | |||
| TemplateKeywords.CONSTRUCT.value: construct_template_list | |||
| } | |||
| } | |||
| exchange_msg = { | |||
| variable_slot: { | |||
| ExchangeMessageKeywords.VariableScope.value.OPERATION.value: op, | |||
| ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value: None, | |||
| ExchangeMessageKeywords.VariableScope.value.OUTPUT_TYPE.value: | |||
| ExchangeMessageKeywords.VariableScope.value.TSR_TYPE.value, | |||
| ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [], | |||
| ExchangeMessageKeywords.VariableScope.value.ARGS.value: args, | |||
| ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights, | |||
| ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: dict() | |||
| } | |||
| } | |||
| outputs_list = [f"opt_{{{variable_slot}}}"] | |||
| outputs_mapping = ((0, 0),) | |||
| return template, exchange_msg, outputs_list, outputs_mapping | |||
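A numpy sketch of the comment's point: a multi-axis `Unsqueeze` becomes a chain of single-axis `ExpandDims` calls, each fed the previous output:

```python
import numpy as np

x = np.zeros((3, 4))
axes = [0, 3]  # ONNX Unsqueeze: insert size-1 dims at output positions 0 and 3

out = x
for axis in axes:
    out = np.expand_dims(out, axis)  # one single-axis op per entry, chained
assert out.shape == (1, 3, 4, 1)
```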
| @@ -0,0 +1,99 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| import numpy as np | |||
| from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords, \ | |||
| WeightType | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| class GatherMapper(ONNXToMindSporeMapper): | |||
| """Gather mapper""" | |||
| @staticmethod | |||
| def _operation_name_in_ms(*args, **kwargs): | |||
| return "P.Gather" | |||
| @staticmethod | |||
| def _convert_params(**kwargs): | |||
| params = kwargs['params'] | |||
| axis = params.get('axis', 0) | |||
| return {'axis': axis} | |||
| @staticmethod | |||
| def _convert_trained_weights(**kwargs): | |||
| weights = kwargs.get('weights', list()) | |||
| tensor = GatherMapper._find_val_by_index(0, weights) | |||
| if isinstance(tensor, np.ndarray) and tensor.shape: | |||
| return {'input_weight': {'data': tensor, 'type': WeightType.PARAMETER.value}} | |||
| return dict() | |||
| @staticmethod | |||
| def _generate_snippet_template(**kwargs): | |||
| op = kwargs.get('operation') | |||
| args = kwargs.get('converted_params') | |||
| weights = kwargs.get('weights') | |||
| trainable_params = kwargs.get('trainable_params', dict()) | |||
| if not op: | |||
| raise ValueError('Can not get MindSpore operation name.') | |||
| tensor = GatherMapper._find_val_by_index(0, weights) | |||
| weight_shape = tensor.shape | |||
| weight_dtype = tensor.dtype | |||
| weight_location = GatherMapper._find_location_by_index(0, weights) | |||
| variable_slot = "var_0" | |||
| init_template = f"self.{{{variable_slot}}} = {op}()" | |||
| inputs_in_construct = [f"{{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}"] | |||
| if weight_location != -1: | |||
| inputs_in_construct.insert(weight_location, f"self.{{{variable_slot}}}_input_weight") | |||
| if weight_shape: | |||
| args['weight_shape'] = weight_shape | |||
| args['weight_dtype'] = weight_dtype | |||
| init_tensor = f"self.{{{variable_slot}}}_input_weight = " \ | |||
| f"Parameter(Tensor(np.random.uniform(0, 1, {{weight_shape}}).astype(np.{{weight_dtype}})), " \ | |||
| f"name=None)" | |||
| else: | |||
| args['weight_value'] = tensor.tolist() | |||
| init_tensor = f"self.{{{variable_slot}}}_input_weight = Tensor(np.array({{weight_value}}))" | |||
| init_axis = f"self.{{{variable_slot}}}_axis = {{axis}}" | |||
| construct_axis = f"opt_{{{variable_slot}}}_axis = self.{{{variable_slot}}}_axis" | |||
| construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ | |||
| f"({', '.join(inputs_in_construct)}, opt_{{{variable_slot}}}_axis)" | |||
| template = { | |||
| variable_slot: { | |||
| TemplateKeywords.INIT.value: [init_tensor, init_axis, init_template], | |||
| TemplateKeywords.CONSTRUCT.value: [construct_axis, construct_template] | |||
| } | |||
| } | |||
| exchange_msg = { | |||
| variable_slot: { | |||
| ExchangeMessageKeywords.VariableScope.value.OPERATION.value: op, | |||
| ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value: None, | |||
| ExchangeMessageKeywords.VariableScope.value.OUTPUT_TYPE.value: | |||
| ExchangeMessageKeywords.VariableScope.value.TSR_TYPE.value, | |||
| ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [], | |||
| ExchangeMessageKeywords.VariableScope.value.ARGS.value: args, | |||
| ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights, | |||
| ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: trainable_params | |||
| } | |||
| } | |||
| outputs_list = [f"opt_{{{variable_slot}}}"] | |||
| outputs_mapping = ((0, 0),) | |||
| return template, exchange_msg, outputs_list, outputs_mapping | |||
| @@ -13,6 +13,8 @@ | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| import numpy as np | |||
| from mindinsight.mindconverter.graph_based_converter.common.utils import reset_init_or_construct | |||
| from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords, \ | |||
| WeightType | |||
| @@ -32,9 +34,9 @@ class MulMapper(ONNXToMindSporeMapper): | |||
| @staticmethod | |||
| def _convert_trained_weights(**kwargs): | |||
| weights = kwargs['weights'] | |||
| if weights: | |||
| tensor = MulMapper._find_val_by_index(0, weights) | |||
| weights = kwargs.get('weights', list()) | |||
| tensor = MulMapper._find_val_by_index(0, weights) | |||
| if isinstance(tensor, np.ndarray) and tensor.shape: | |||
| return {'w': {'data': tensor, 'type': WeightType.PARAMETER.value}} | |||
| return dict() | |||
| @@ -45,23 +47,49 @@ class MulMapper(ONNXToMindSporeMapper): | |||
| op = kwargs.get("operation") | |||
| args = kwargs.get("converted_params") | |||
| weights = kwargs.get("weights") | |||
| trainable_params = kwargs.get("trainable_params", dict()) | |||
| if not weights: | |||
| return template, exchange_msg, outputs_list, outputs_mapping | |||
| tensor = MulMapper._find_val_by_index(0, weights) | |||
| w_shape = tensor.shape | |||
| w_dtype = tensor.dtype | |||
| w_location = MulMapper._find_location_by_index(0, weights) | |||
| variable_slot = "var_0" | |||
| init_template = f"self.{{{variable_slot}}} = {op}()" | |||
| args["w_shape"] = tensor.shape | |||
| args["w_dtype"] = tensor.dtype | |||
| init_tensor = f"self.{{{variable_slot}}}_w = " \ | |||
| f"Parameter(Tensor(np.random.uniform(0, 1, {{w_shape}}).astype(np.{{w_dtype}})), " \ | |||
| f"name=None)" | |||
| inputs_in_construct = [f"{{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}"] | |||
| if w_location != -1: | |||
| inputs_in_construct.insert(w_location, f"self.{{{variable_slot}}}_w") | |||
| if w_shape: | |||
| args["w_shape"] = w_shape | |||
| args["w_dtype"] = w_dtype | |||
| init_tensor = f"self.{{{variable_slot}}}_w = " \ | |||
| f"Parameter(Tensor(np.random.uniform(0, 1, {{w_shape}}).astype(np.{{w_dtype}})), " \ | |||
| f"name=None)" | |||
| else: | |||
| args["w_value"] = tensor.tolist() | |||
| init_tensor = f"self.{{{variable_slot}}}_w = {{w_value}}" | |||
| construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ | |||
| f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}," \ | |||
| f"self.{{{variable_slot}}}_w)" | |||
| f"({', '.join(inputs_in_construct)})" | |||
| template = reset_init_or_construct(template, variable_slot, [init_template, init_tensor], | |||
| TemplateKeywords.INIT.value) | |||
| template = reset_init_or_construct(template, variable_slot, [construct_template], | |||
| TemplateKeywords.CONSTRUCT.value) | |||
| exchange_msg = { | |||
| variable_slot: { | |||
| ExchangeMessageKeywords.VariableScope.value.OPERATION.value: op, | |||
| ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value: None, | |||
| ExchangeMessageKeywords.VariableScope.value.OUTPUT_TYPE.value: | |||
| ExchangeMessageKeywords.VariableScope.value.TSR_TYPE.value, | |||
| ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [], | |||
| ExchangeMessageKeywords.VariableScope.value.ARGS.value: args, | |||
| ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights, | |||
| ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: trainable_params | |||
| } | |||
| } | |||
| outputs_list = [f"opt_{{{variable_slot}}}"] | |||
| outputs_mapping = ((0, 0),) | |||
| return template, exchange_msg, outputs_list, outputs_mapping | |||
| @@ -0,0 +1,95 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| import numpy as np | |||
| from mindinsight.mindconverter.graph_based_converter.common.utils import reset_init_or_construct | |||
| from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords, \ | |||
| WeightType | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| class PowMapper(ONNXToMindSporeMapper): | |||
| """Pow mapper.""" | |||
| @staticmethod | |||
| def _operation_name_in_ms(*args, **kwargs): | |||
| return "P.Pow" | |||
| @staticmethod | |||
| def _convert_params(**kwargs): | |||
| return dict() | |||
| @staticmethod | |||
| def _convert_trained_weights(**kwargs): | |||
| weights = kwargs.get('weights', list()) | |||
| tensor = PowMapper._find_val_by_index(0, weights) | |||
| if isinstance(tensor, np.ndarray) and tensor.shape: | |||
| return {'input_weight': {'data': tensor, 'type': WeightType.PARAMETER.value}} | |||
| return dict() | |||
| @staticmethod | |||
| def _generate_snippet_template(**kwargs): | |||
| template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template( | |||
| **kwargs) | |||
| op = kwargs.get("operation") | |||
| args = kwargs.get("converted_params") | |||
| weights = kwargs.get("weights") | |||
| trainable_params = kwargs.get("trainable_params", dict()) | |||
| if not weights: | |||
| return template, exchange_msg, outputs_list, outputs_mapping | |||
| tensor = PowMapper._find_val_by_index(0, weights) | |||
| input_shape = tensor.shape | |||
| input_dtype = tensor.dtype | |||
| input_location = PowMapper._find_location_by_index(0, weights) | |||
| variable_slot = "var_0" | |||
| init_template = f"self.{{{variable_slot}}} = {op}()" | |||
| inputs_in_construct = [f"{{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}"] | |||
| if input_location != -1: | |||
| inputs_in_construct.insert(input_location, f"self.{{{variable_slot}}}_input_weight") | |||
| if input_shape: | |||
| args["input_shape"] = input_shape | |||
| args["input_dtype"] = input_dtype | |||
| init_tensor = f"self.{{{variable_slot}}}_input_weight = " \ | |||
| f"Parameter(Tensor(np.random.uniform(0, 1, {{input_shape}}).astype(np.{{input_dtype}})), " \ | |||
| f"name=None)" | |||
| else: | |||
| args["input_value"] = tensor.tolist() | |||
| init_tensor = f"self.{{{variable_slot}}}_input_weight = {{input_value}}" | |||
| construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ | |||
| f"({', '.join(inputs_in_construct)})" | |||
| template = reset_init_or_construct(template, variable_slot, [init_template, init_tensor], | |||
| TemplateKeywords.INIT.value) | |||
| template = reset_init_or_construct(template, variable_slot, [construct_template], | |||
| TemplateKeywords.CONSTRUCT.value) | |||
| exchange_msg = { | |||
| variable_slot: { | |||
| ExchangeMessageKeywords.VariableScope.value.OPERATION.value: op, | |||
| ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value: None, | |||
| ExchangeMessageKeywords.VariableScope.value.OUTPUT_TYPE.value: | |||
| ExchangeMessageKeywords.VariableScope.value.TSR_TYPE.value, | |||
| ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [], | |||
| ExchangeMessageKeywords.VariableScope.value.ARGS.value: args, | |||
| ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights, | |||
| ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: trainable_params | |||
| } | |||
| } | |||
| outputs_list = [f"opt_{{{variable_slot}}}"] | |||
| outputs_mapping = ((0, 0),) | |||
| return template, exchange_msg, outputs_list, outputs_mapping | |||
| @@ -28,7 +28,7 @@ class ReduceMeanMapper(ONNXToMindSporeMapper): | |||
| @staticmethod | |||
| def _convert_params(**kwargs): | |||
| params = kwargs['params'] | |||
| keep_dims = not params['keepdims'] == 0 | |||
| keep_dims = params.get('keepdims', 1) != 0 | |||
| return {'keep_dims': keep_dims} | |||
| @staticmethod | |||
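The change above matters because ONNX `ReduceMean` defaults `keepdims` to 1 when the attribute is absent; a one-line check of both cases:

```python
params = {}  # 'keepdims' attribute absent from the ONNX node
assert (params.get('keepdims', 1) != 0) is True          # default: keep dims
assert ({'keepdims': 0}.get('keepdims', 1) != 0) is False
```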
| @@ -23,9 +23,6 @@ class ResizeMapper(ONNXToMindSporeMapper): | |||
| @staticmethod | |||
| def _operation_name_in_ms(*args, **kwargs): | |||
| params = kwargs.get("params") | |||
| onnx_coordinate_transform = params.get("coordinate_transformation_mode") | |||
| if onnx_coordinate_transform is not None: | |||
| onnx_coordinate_transform = convert_bytes_string_to_string(onnx_coordinate_transform) | |||
| interpolation_mode = params.get("mode") | |||
| if interpolation_mode is not None: | |||
| @@ -45,24 +42,14 @@ class ResizeMapper(ONNXToMindSporeMapper): | |||
| def _convert_params(**kwargs): | |||
| weights = kwargs.get("weights") | |||
| params = kwargs.get("params") | |||
| # Set default params | |||
| align_corners = False | |||
| if len(weights) > 3: | |||
| raise ValueError("For resize, `weights` length less or equal to 3.") | |||
| onnx_coordinate_transform = params.get("coordinate_transformation_mode") | |||
| if onnx_coordinate_transform is not None: | |||
| onnx_coordinate_transform = convert_bytes_string_to_string(onnx_coordinate_transform) | |||
| if onnx_coordinate_transform == "align_corners" or "half_pixel" in onnx_coordinate_transform: | |||
| align_corners = True | |||
| # Get requested size for resize | |||
| size = ResizeMapper._find_val_by_index(-1, weights)[-2:].tolist() | |||
| size = params["output_shape"][-2:] | |||
| return {"size": tuple(size), | |||
| "align_corners": align_corners} | |||
| return {"size": tuple(size)} | |||
| @staticmethod | |||
| def _convert_trained_weights(**kwargs): | |||
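The rewrite above takes the resize target from the node's inferred output shape rather than from the optional size tensor in `weights`; a sketch assuming an NCHW output shape:

```python
params = {"output_shape": [1, 3, 224, 224]}  # assumed NCHW inferred shape
size = params["output_shape"][-2:]           # spatial H, W only
assert tuple(size) == (224, 224)
```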
| @@ -13,6 +13,8 @@ | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| import numpy as np | |||
| from mindinsight.mindconverter.graph_based_converter.common.utils import reset_init_or_construct | |||
| from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| @@ -23,7 +25,7 @@ class SliceMapper(ONNXToMindSporeMapper): | |||
| @staticmethod | |||
| def _operation_name_in_ms(*args, **kwargs): | |||
| return "P.Slice" | |||
| return "P.StridedSlice" | |||
| @staticmethod | |||
| def _convert_params(**kwargs): | |||
| @@ -37,15 +39,53 @@ class SliceMapper(ONNXToMindSporeMapper): | |||
| def _generate_snippet_template(**kwargs): | |||
| template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template( | |||
| **kwargs) | |||
| weights = [weight.value for weight in kwargs.get('weights')] # start, end, axis | |||
| opt_shape = kwargs["raw_params"]["output_shape"] | |||
| op = kwargs.get("operation") | |||
| args = kwargs.get("converted_params", dict()) | |||
| weights = kwargs.get("weights") | |||
| params = kwargs["raw_params"] | |||
| ipt_shape = params["input_shape"] | |||
| starts = SliceMapper._find_val_by_index(0, weights) | |||
| ends = SliceMapper._find_val_by_index(1, weights) | |||
| axes = SliceMapper._find_val_by_index(2, weights, np.array([i for i in range(len(ipt_shape))])) | |||
| steps = SliceMapper._find_val_by_index(3, weights, np.array([1 for _ in range(len(ipt_shape))])) | |||
| if not op: | |||
| raise ValueError("Can not get MindSpore operation name.") | |||
| if not weights: | |||
| raise ValueError("Cannot get required params from slice.") | |||
| starts = sorted(zip(weights[0].tolist(), weights[2].tolist()), key=lambda x: x[1], reverse=False) | |||
| if axes.shape != (1,): | |||
| ordered_begin = sorted(zip(starts.tolist(), axes.tolist()), key=lambda x: x[1], reverse=False) | |||
| ordered_end = sorted(zip(ends.tolist(), axes.tolist()), key=lambda x: x[1], reverse=False) | |||
| ordered_strides = sorted(zip(steps.tolist(), axes.tolist()), key=lambda x: x[1], reverse=False) | |||
| begin = [i[0] for i in ordered_begin] | |||
| end = [min(i[0], ipt_shape[i[1]]) for i in ordered_end] | |||
| strides = [i[0] for i in ordered_strides] | |||
| else: | |||
| axis = axes.tolist()[0] | |||
| begin = [0 for _ in range(len(ipt_shape))] | |||
| end = list(ipt_shape) | |||
| strides = [1 for _ in range(len(ipt_shape))] | |||
| begin[axis] = starts.tolist()[0] | |||
| end[axis] = min(ends.tolist()[0], end[axis]) | |||
| strides[axis] = steps.tolist()[0] | |||
| args["begin"] = tuple(begin) | |||
| args["end"] = tuple(end) | |||
| args["strides"] = tuple(strides) | |||
| variable_slot = "var_0" | |||
| init_template = f"self.{{{variable_slot}}} = {op}()" | |||
| init_begin = f"self.{{{variable_slot}}}_begin = {{begin}}" | |||
| init_end = f"self.{{{variable_slot}}}_end = {{end}}" | |||
| init_strides = f"self.{{{variable_slot}}}_strides = {{strides}}" | |||
| construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ | |||
| f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}, " \ | |||
| f"{tuple([i[0] for i in starts])}, {tuple(opt_shape)})" | |||
| f"self.{{{variable_slot}}}_begin, self.{{{variable_slot}}}_end, " \ | |||
| f"self.{{{variable_slot}}}_strides)" | |||
| template = reset_init_or_construct(template, variable_slot, [init_template, init_begin, init_end, init_strides], | |||
| TemplateKeywords.INIT.value) | |||
| template = reset_init_or_construct(template, variable_slot, [construct_template], | |||
| TemplateKeywords.CONSTRUCT.value) | |||
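A numpy sketch checking the single-axis branch above: the begin/end/strides derived from the ONNX Slice inputs reproduce the slice via strided indexing (values assumed):

```python
import numpy as np

ipt_shape = (4, 10)
starts, ends, axes, steps = [2], [8], [1], [2]  # slice only along axis 1

begin = [0] * len(ipt_shape)
end = list(ipt_shape)
strides = [1] * len(ipt_shape)
begin[axes[0]] = starts[0]
end[axes[0]] = min(ends[0], end[axes[0]])  # clamp to the input extent, as the mapper does
strides[axes[0]] = steps[0]

x = np.arange(40).reshape(ipt_shape)
sliced = x[tuple(slice(b, e, s) for b, e, s in zip(begin, end, strides))]
assert sliced.shape == (4, 3)  # columns 2, 4, 6
```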
| @@ -0,0 +1,32 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| class SqrtMapper(ONNXToMindSporeMapper): | |||
| """Sqrt mapper.""" | |||
| @staticmethod | |||
| def _operation_name_in_ms(*args, **kwargs): | |||
| return "P.Sqrt" | |||
| @staticmethod | |||
| def _convert_params(**kwargs): | |||
| return dict() | |||
| @staticmethod | |||
| def _convert_trained_weights(**kwargs): | |||
| return dict() | |||
| @@ -0,0 +1,98 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| import numpy as np | |||
| from mindinsight.mindconverter.graph_based_converter.constant import WeightType, ExchangeMessageKeywords, \ | |||
| TemplateKeywords | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| class SubMapper(ONNXToMindSporeMapper): | |||
| """Sub mapper.""" | |||
| @staticmethod | |||
| def _operation_name_in_ms(*args, **kwargs): | |||
| return "P.Sub" | |||
| @staticmethod | |||
| def _convert_params(**kwargs): | |||
| return dict() | |||
| @staticmethod | |||
| def _convert_trained_weights(**kwargs): | |||
| weights = kwargs.get('weights', list()) | |||
| tensor = SubMapper._find_val_by_index(0, weights) | |||
| if isinstance(tensor, np.ndarray) and tensor.shape: | |||
| return {'bias': {'data': tensor, 'type': WeightType.PARAMETER.value}} | |||
| return dict() | |||
| @staticmethod | |||
| def _generate_snippet_template(**kwargs): | |||
| template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template( | |||
| **kwargs) | |||
| op = kwargs.get("operation") | |||
| args = kwargs.get("converted_params") | |||
| weights = kwargs.get("weights") | |||
| trainable_params = kwargs.get('trainable_params', dict()) | |||
| if not op: | |||
| raise ValueError("Can not get MindSpore operation name.") | |||
| if not weights: | |||
| return template, exchange_msg, outputs_list, outputs_mapping | |||
| tensor = SubMapper._find_val_by_index(0, weights) | |||
| bias_shape = tensor.shape | |||
| bias_dtype = tensor.dtype | |||
| bias_location = SubMapper._find_location_by_index(0, weights) | |||
| variable_slot = "var_0" | |||
| init_template = f"self.{{{variable_slot}}} = {op}()" | |||
| inputs_in_construct = [f"{{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}"] | |||
| if bias_location != -1: | |||
| inputs_in_construct.insert(bias_location, f"self.{{{variable_slot}}}_bias") | |||
| if bias_shape: | |||
| args["bias_shape"] = bias_shape | |||
| args["bias_dtype"] = bias_dtype | |||
| init_tensor = f"self.{{{variable_slot}}}_bias = " \ | |||
| f"Parameter(Tensor(np.random.uniform(0, 1, {{bias_shape}}).astype(np.{{bias_dtype}})), " \ | |||
| f"name=None)" | |||
| else: | |||
| args["bias_value"] = tensor.tolist() | |||
| init_tensor = f"self.{{{variable_slot}}}_bias = {{bias_value}}" | |||
| construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ | |||
| f"({', '.join(inputs_in_construct)})" | |||
| template = { | |||
| variable_slot: { | |||
| TemplateKeywords.INIT.value: [init_template, init_tensor], | |||
| TemplateKeywords.CONSTRUCT.value: [construct_template] | |||
| } | |||
| } | |||
| exchange_msg = { | |||
| variable_slot: { | |||
| ExchangeMessageKeywords.VariableScope.value.OPERATION.value: op, | |||
| ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value: None, | |||
| ExchangeMessageKeywords.VariableScope.value.OUTPUT_TYPE.value: | |||
| ExchangeMessageKeywords.VariableScope.value.TSR_TYPE.value, | |||
| ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [], | |||
| ExchangeMessageKeywords.VariableScope.value.ARGS.value: args, | |||
| ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights, | |||
| ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: trainable_params | |||
| } | |||
| } | |||
| outputs_list = [f"opt_{{{variable_slot}}}"] | |||
| outputs_mapping = ((0, 0),) | |||
| return template, exchange_msg, outputs_list, outputs_mapping | |||
| @@ -20,5 +20,15 @@ | |||
| "onnx::Mul": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.mul_mapper.MulMapper", | |||
| "onnx::Sigmoid": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.sigmoid_mapper.SigmoidMapper", | |||
| "onnx::Split": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.split_mapper.SplitMapper", | |||
| "onnx::Resize": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.resize_mapper.ResizeMapper" | |||
| "onnx::Resize": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.resize_mapper.ResizeMapper", | |||
| "onnx::Gather": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.gather_mapper.GatherMapper", | |||
| "onnx::Sub": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.sub_mapper.SubMapper", | |||
| "onnx::Sqrt": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.sqrt_mapper.SqrtMapper", | |||
| "onnx::Div": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.div_mapper.DivMapper", | |||
| "onnx::Unsqueeze": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.expand_dims_mapper.ExpandDimsMapper", | |||
| "onnx::Cast": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.cast_mapper.CastMapper", | |||
| "onnx::Erf": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.erf_mapper.ErfMapper", | |||
| "onnx::Pow": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.pow_mapper.PowMapper", | |||
| "onnx::Einsum": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.einsum_mapper.EinSumMapper", | |||
| "onnx::Tanh": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.tanh_mapper.TanhMapper" | |||
| } | |||
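Each entry maps an ONNX op type to the dotted path of its mapper class. Presumably the converter resolves these lazily via importlib; a minimal sketch of that lookup (function name is hypothetical):

```python
from importlib import import_module

def resolve_mapper(table: dict, onnx_op: str):
    """Load the mapper class registered for an ONNX op type."""
    dotted_path = table[onnx_op]              # e.g. the GatherMapper path above
    module_path, class_name = dotted_path.rsplit(".", 1)
    return getattr(import_module(module_path), class_name)
```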
| @@ -125,12 +125,12 @@ class OnnxGraph(Graph): | |||
| node.scope_name = scope_name_list[ind] | |||
| inputs = node.input_name_list | |||
| # Check whether each input comes from another node or from the initializer tensors. | |||
| for i in inputs: | |||
| for idx, i in enumerate(inputs): | |||
| if i in model_data.tensors_dict: | |||
| tensor = model_data.tensors_dict[i] | |||
| t_name = tensor.name | |||
| t_value = tensor.to_array() | |||
| node_weights.append(NodeWeight(t_name, t_value)) | |||
| node_weights.append(NodeWeight(t_name, t_value, idx)) | |||
| self._nodes_collection[node_name] = OnnxGraphNode(node, node_weights) | |||
| self._nodes_record[node_name] = node_name | |||
| @@ -202,8 +202,12 @@ class OnnxGraph(Graph): | |||
| else: | |||
| onnx_model = PyTorchGraphParser.parse(graph_path, **kwargs) | |||
| onnx_inputs = [onnx_input.name for onnx_input in onnx_model.graph.input] | |||
| for ipt in input_nodes: | |||
| if ipt not in onnx_inputs: | |||
| raise ModelLoadingError(f"input nodes({input_nodes}) is not " | |||
| f"in model inputs ({onnx_inputs}).") | |||
| invalid_input_node_name = list() | |||
| for node_name in input_nodes.keys(): | |||
| if node_name not in onnx_inputs: | |||
| invalid_input_node_name.append(node_name) | |||
| if invalid_input_node_name: | |||
| raise ModelLoadingError( | |||
| f"input nodes({invalid_input_node_name}) is not in model inputs ({onnx_inputs}).") | |||
| return onnx_model | |||
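The rewritten check gathers every unknown input name before raising, so a single error reports all offenders instead of aborting on the first. The same pattern in isolation (names are made up):

```python
input_nodes = {"input.1": [1, 3, 224, 224], "mask": [1, 128]}  # hypothetical
onnx_inputs = ["input.1"]

invalid = [name for name in input_nodes if name not in onnx_inputs]
if invalid:
    raise ValueError(f"input nodes ({invalid}) are not in model inputs ({onnx_inputs}).")
```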
| @@ -13,6 +13,7 @@ | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Define ONNX related operations.""" | |||
| import itertools | |||
| import re | |||
| import abc | |||
| from importlib import import_module | |||
| @@ -26,7 +27,7 @@ from mindinsight.mindconverter.graph_based_converter.common.global_context impor | |||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph.optimizer import OnnxSimplify | |||
| from mindinsight.mindconverter.graph_based_converter.constant import ONNX_TYPE_INT, ONNX_TYPE_INTS, ONNX_TYPE_STRING, \ | |||
| ONNX_TYPE_FLOATS, ONNX_TYPE_FLOAT, SCALAR_WITHOUT_SHAPE, DYNAMIC_SHAPE, UNKNOWN_DIM_VAL | |||
| ONNX_TYPE_FLOATS, ONNX_TYPE_FLOAT, SCALAR_WITHOUT_SHAPE, DYNAMIC_SHAPE, UNKNOWN_DIM_VAL, DTYPE_MAP | |||
| from mindinsight.mindconverter.common.exceptions import GraphInitError | |||
| @@ -343,22 +344,52 @@ class OnnxDataLoader: | |||
| i_type = group_match.group('type') | |||
| i_dim_str = group_match.group('dim_str') | |||
| for dim in i_dim_str.split('x'): | |||
| if not dim.isdigit(): | |||
| raise ValueError("Unknown output shape.") | |||
| return i_name, i_type, i_dim_str | |||
| if not self.inferred_model: | |||
| return | |||
| value_info = self.inferred_model.graph.value_info | |||
| node_without_output_shape = dict() | |||
| for v in value_info: | |||
| for v in itertools.chain(value_info, self.inferred_model.graph.output): | |||
| try: | |||
| readable_info = onnx.helper.printable_value_info(v) | |||
| (node_name, node_type, node_dim) = _parse_value_info_re(readable_info) | |||
| except (AssertionError, ValueError, AttributeError) as _: | |||
| node_name, node_type, node_dim = self._parse_value_info_manually(v) | |||
| # `node_dim` could be "" or "scalar"; such nodes need their shapes inferred at runtime. | |||
| if not node_dim or node_dim == SCALAR_WITHOUT_SHAPE: | |||
| node_without_output_shape[node_name] = {'node_type': node_type, 'output_name': v.name} | |||
| self.value_info_dict[node_name] = (node_type, node_dim) | |||
| inferred_outputs_name = [node_inst['output_name'] for node_inst in list(node_without_output_shape.values())] | |||
| inferred_outputs = dict() | |||
| if inferred_outputs_name: | |||
| inferred_outputs = self._get_outputs_using_onnxruntime(inferred_outputs_name) | |||
| for node_name, node_inst in node_without_output_shape.items(): | |||
| node_type = node_inst['node_type'] | |||
| output_name = node_inst['output_name'] | |||
| node_dim = 'x'.join(str(shape_axis) for shape_axis in inferred_outputs[output_name].shape) | |||
| self.value_info_dict[node_name] = (node_type, node_dim) | |||
| def _get_outputs_using_onnxruntime(self, output_nodes_name): | |||
| """Get outputs using onnxruntime.""" | |||
| onnx_inputs = self.inferred_model.graph.input | |||
| dtype_dict = dict() | |||
| for onnx_input in onnx_inputs: | |||
| dtype_dict[onnx_input.name] = DTYPE_MAP[onnx_input.type.tensor_type.elem_type] | |||
| feed_dict = build_feed_dict(self.inferred_model, self.input_nodes) | |||
| outputs_infer = fetch_output_from_onnx_model(self.model, feed_dict, output_nodes_name) | |||
| return outputs_infer | |||
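`fetch_output_from_onnx_model` is defined elsewhere in the converter; the usual way to read intermediate tensors with onnxruntime is to promote them to graph outputs first, since a session only returns declared outputs. A rough sketch under that assumption (not the converter's actual helper):

```python
import onnx
import onnxruntime as ort

def fetch_intermediate_outputs(model, feed_dict, output_names):
    """Run one forward pass and return the requested intermediate tensors."""
    patched = onnx.ModelProto()
    patched.CopyFrom(model)
    # Declare each requested tensor as a graph output so onnxruntime emits it.
    for name in output_names:
        patched.graph.output.append(onnx.helper.make_empty_tensor_value_info(name))
    session = ort.InferenceSession(patched.SerializeToString())
    results = session.run(output_names, feed_dict)
    return dict(zip(output_names, results))
```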
| def _parse_nodes(self): | |||
| """Parse each onnx nodes in the model.""" | |||
| nodes_topo_idx = [] | |||
| @@ -569,7 +600,7 @@ class OnnxDataLoader: | |||
| if nd_inst.op_type == "Resize": | |||
| # Find the size parameter. | |||
| to_shape = nd_inst.input_name_list[3] | |||
| to_shape = nd_inst.input_name_list[-1] | |||
| if to_shape in self.tensors_dict: | |||
| return | |||
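For context on why the fixed index was fragile: the ONNX Resize operator gained inputs across opset versions, so the size operand is not always at position 3. Taking the last input covers both layouts:

```python
# ONNX Resize inputs by opset (names per the ONNX operator spec):
#   opset 10: (X, scales)
#   opset 11: (X, roi, scales, sizes)
input_name_list = ["x", "roi", "scales", "sizes"]
to_shape = input_name_list[-1]  # "sizes"; a hard-coded index 3 fails for opset 10
```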
| @@ -581,9 +612,10 @@ class OnnxDataLoader: | |||
| class NodeWeight: | |||
| """Node weight struct.""" | |||
| def __init__(self, weight_name, weight_value): | |||
| def __init__(self, weight_name, weight_value, weight_location): | |||
| self._weight_name = weight_name | |||
| self._weight_value = weight_value | |||
| self._weight_location = weight_location | |||
| @property | |||
| def name(self): | |||
| @@ -592,3 +624,7 @@ class NodeWeight: | |||
| @property | |||
| def value(self): | |||
| return self._weight_value | |||
| @property | |||
| def location(self): | |||
| return self._weight_location | |||
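The new `location` field records which input slot of the ONNX node the weight occupied, letting downstream code splice generated parameters back into the correct argument position, mirroring the `inputs_in_construct.insert(...)` pattern at the top of this change set. A self-contained sketch:

```python
class NodeWeight:
    """Trimmed-down version of the struct above."""
    def __init__(self, name, value, location):
        self.name, self.value, self.location = name, value, location

weights = [NodeWeight("conv_weight", None, 1), NodeWeight("conv_bias", None, 2)]
inputs_in_construct = ["x"]  # the node's only non-weight input
for w in sorted(weights, key=lambda w: w.location):
    inputs_in_construct.insert(w.location, f"self.{w.name}")
assert inputs_in_construct == ["x", "self.conv_weight", "self.conv_bias"]
```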
| @@ -47,9 +47,8 @@ class PyTorchGraphParser(GraphParser): | |||
| raise error | |||
| try: | |||
| sample_shape = list(kwargs.get("input_nodes").values())[0] | |||
| onnx_model_sim = cls._convert_pytorch_graph_to_onnx( | |||
| model_path, sample_shape, opset_version=11) | |||
| model_path, kwargs['input_nodes'], opset_version=11) | |||
| return onnx_model_sim | |||
| except ModuleNotFoundError: | |||
| error_msg = "Cannot find model scripts in system path, " \ | |||
| @@ -58,19 +57,19 @@ class PyTorchGraphParser(GraphParser): | |||
| raise error | |||
| @staticmethod | |||
| def _convert_pytorch_graph_to_onnx(model_path, sample_shape, opset_version=None): | |||
| def _convert_pytorch_graph_to_onnx(model_path, input_nodes, opset_version=None): | |||
| """ | |||
| Convert PyTorch model to ONNX model. | |||
| Args: | |||
| model_path (str): Path to the PyTorch model. | |||
| sample_shape (tuple): Input shape to generate onnx model. | |||
| input_nodes (dict): Input node names mapped to their shapes, used to generate the ONNX model. | |||
| opset_version (int): Opset version of ONNX. | |||
| """ | |||
| output_queue = mp.Queue() | |||
| process = mp.Process(target=PyTorchGraphParser._pytorch_graph_to_proto, | |||
| args=(output_queue, model_path, sample_shape, opset_version)) | |||
| args=(output_queue, model_path, input_nodes, opset_version)) | |||
| process.start() | |||
| proto = output_queue.get() | |||
| process.join() | |||
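Running the torch import and export inside an `mp.Process` keeps CUDA and torch state out of the parent process. Note the ordering: the queue is drained before `join()`, because joining a child that still holds a large unread queue item can deadlock. The pattern in isolation:

```python
import multiprocessing as mp

def _worker(queue):
    queue.put("proto-bytes")  # stand-in for the real export result

if __name__ == "__main__":
    q = mp.Queue()
    p = mp.Process(target=_worker, args=(q,))
    p.start()
    result = q.get()  # read first ...
    p.join()          # ... then join, avoiding the queue-flush deadlock
```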
| @@ -81,26 +80,29 @@ class PyTorchGraphParser(GraphParser): | |||
| return onnx_model | |||
| @staticmethod | |||
| def _pytorch_graph_to_proto(output_queue, model_path, sample_shape, opset_version): | |||
| def _pytorch_graph_to_proto(output_queue, model_path, input_nodes, opset_version): | |||
| """ | |||
| Convert PyTorch graph to PyTorch proto. | |||
| Args: | |||
| output_queue (Queue): Queue used to return the result from the child process. | |||
| model_path (str): Path to the PyTorch model. | |||
| sample_shape (tuple): Input shape to generate onnx model. | |||
| input_nodes (dict): Input node names mapped to their shapes, used to generate the ONNX model. | |||
| opset_version (int): Opset version of ONNX. | |||
| """ | |||
| try: | |||
| torch = import_module('torch') | |||
| has_cuda = torch.cuda.is_available() | |||
| dump_inputs = dict() | |||
| if has_cuda: | |||
| model = torch.load(f=model_path).cuda() | |||
| dump_input = torch.randn(*sample_shape, device='cuda') | |||
| for node_name, node_shape in input_nodes.items(): | |||
| dump_inputs[node_name] = torch.randn(*node_shape, device='cuda') | |||
| else: | |||
| model = torch.load(f=model_path, map_location="cpu") | |||
| dump_input = torch.randn(*sample_shape, device='cpu') | |||
| for node_name, node_shape in input_nodes.items(): | |||
| dump_inputs[node_name] = torch.randn(*node_shape, device='cpu') | |||
| if isinstance(model, torch.nn.DataParallel): | |||
| raise ValueError('torch.nn.DataParallel is not supported by ONNX exporter.') | |||
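With `dump_inputs` keyed by node name, the exporter receives several named inputs at once. For comparison, the same idea through the public `torch.onnx.export` API instead of the internal `model_to_graph` (paths and shapes are hypothetical):

```python
import torch

input_nodes = {"input_ids": (1, 128), "attention_mask": (1, 128)}  # hypothetical
dump_inputs = {k: torch.randn(*shape) for k, shape in input_nodes.items()}

model = torch.load("model.pth", map_location="cpu")  # hypothetical checkpoint
torch.onnx.export(model, tuple(dump_inputs.values()), "model.onnx",
                  input_names=list(dump_inputs.keys()), opset_version=11)
```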
| @@ -121,7 +123,8 @@ class PyTorchGraphParser(GraphParser): | |||
| set_opset_version(opset_version) | |||
| set_operator_export_type(operator_export_type) | |||
| graph, params_dict, _ = model_to_graph(model, dump_input, _retain_param_name=True) | |||
| graph, params_dict, _ = model_to_graph(model, args=tuple(dump_inputs.values()), | |||
| input_names=list(dump_inputs.keys()), _retain_param_name=True) | |||
| export_onnx = getattr(graph, '_export_onnx') | |||
| proto, _ = export_onnx( | |||
| params_dict, opset_version, dict(), False, | |||