Merge pull request !769 from moran/network_validationtags/v1.1.0
| @@ -25,6 +25,7 @@ class ConverterErrors(ScriptConverterErrors): | |||
| SCRIPT_NOT_SUPPORT = 1 | |||
| NODE_TYPE_NOT_SUPPORT = 2 | |||
| CODE_SYNTAX_ERROR = 3 | |||
| NODE_INPUT_TYPE_NOT_SUPPORT = 4 | |||
| class ScriptNotSupport(MindInsightException): | |||
| @@ -52,3 +53,12 @@ class CodeSyntaxError(MindInsightException): | |||
| super(CodeSyntaxError, self).__init__(ConverterErrors.CODE_SYNTAX_ERROR, | |||
| msg, | |||
| http_code=400) | |||
| class NodeInputTypeNotSupport(MindInsightException): | |||
| """The node input type NOT support error.""" | |||
| def __init__(self, msg): | |||
| super(NodeInputTypeNotSupport, self).__init__(ConverterErrors.NODE_INPUT_TYPE_NOT_SUPPORT, | |||
| msg, | |||
| http_code=400) | |||
| @@ -40,3 +40,9 @@ class NodeType(Enum): | |||
| CLASS = "class" | |||
| FUNC = "func" | |||
| INPUTS = "DataInput" | |||
| @unique | |||
| class InputType(Enum): | |||
| TENSOR = "tensor" | |||
| LIST = "list" | |||
| @@ -633,6 +633,7 @@ class HierarchicalTree(Tree): | |||
| """ | |||
| # All args and value pair in current node module. | |||
| module_args = dict() | |||
| module_settings = dict() | |||
| module_key = self.hash_key(node) | |||
| created = False | |||
| @@ -658,6 +659,7 @@ class HierarchicalTree(Tree): | |||
| nd_inst.data.param_transform(mapper) | |||
| module_args.update(nd_inst.data.args_in_code) | |||
| module_settings.update(nd_inst.data.settings_in_code) | |||
| if not created: | |||
| self._module_vars[module_key].append(nd_inst.data.variable_name) | |||
| @@ -35,6 +35,7 @@ with open(OPERATION_TABLE) as file: | |||
| GET_OP_NAME = "_operation_name_in_ms" | |||
| GET_OP_PARAMS = "_convert_params" | |||
| GET_OP_WEIGHTS = "_convert_trained_weights" | |||
| GET_OP_SETTINGS = "_convert_settings" | |||
| class Mapper(metaclass=abc.ABCMeta): | |||
| @@ -47,14 +48,19 @@ class Mapper(metaclass=abc.ABCMeta): | |||
| @staticmethod | |||
| @abc.abstractmethod | |||
| def _convert_params(params, weights): | |||
| def _convert_params(**kwargs): | |||
| """Convert third party operation's param into MindSpore operation.""" | |||
| @staticmethod | |||
| @abc.abstractmethod | |||
| def _convert_trained_weights(weights): | |||
| def _convert_trained_weights(**kwargs): | |||
| """Convert third party operation's weights into MindSpore operation.""" | |||
| @staticmethod | |||
| @abc.abstractmethod | |||
| def _convert_settings(**kwargs): | |||
| """Convert third party operation's params into MindSpore OP operator.""" | |||
| @classmethod | |||
| @abc.abstractmethod | |||
| def convert(cls, op_name: str, params: Dict, weights: Dict = None): | |||
| @@ -75,13 +81,13 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC): | |||
| weights (dict): Weights in onnx. | |||
| Returns: | |||
| Tuple[str, dict], operation name and params. | |||
| Tuple[str, dict, dict], operation name and params and settings. | |||
| """ | |||
| global TABLE | |||
| module_name = TABLE.get(op_name) | |||
| if not module_name: | |||
| return None, dict() | |||
| return None, dict(), dict() | |||
| pos = module_name.rfind(".") | |||
| try: | |||
| @@ -90,32 +96,38 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC): | |||
| op_name_converter = getattr(converter, GET_OP_NAME) | |||
| params_converter = getattr(converter, GET_OP_PARAMS) | |||
| weights_converter = getattr(converter, GET_OP_WEIGHTS) | |||
| settings_converter = getattr(converter, GET_OP_SETTINGS) | |||
| except (ModuleNotFoundError,) as e: | |||
| # If mapper can not be found, then skip it. | |||
| err_msg = f"Converting {op_name} failed, see {str(e)}" | |||
| log.error(err_msg) | |||
| return None, dict() | |||
| return None, dict(), dict() | |||
| try: | |||
| converter_name = op_name_converter(params=params, weights=weights, op_name=op_name) | |||
| converted_params = params_converter(params, weights) | |||
| converted_weights = weights_converter(weights) if weights else dict() | |||
| converted_params = params_converter(params=params, weights=weights) | |||
| converted_weights = weights_converter(weights=weights) if weights else dict() | |||
| converted_params.update(converted_weights) | |||
| converted_settings = settings_converter(params=params) | |||
| except (AttributeError, KeyError, ValueError, TypeError, IndexError) as e: | |||
| err_msg = f"Converting {op_name} failed, see {str(e)}" | |||
| log.error(err_msg) | |||
| return None, dict() | |||
| return None, dict(), dict() | |||
| return converter_name, converted_params | |||
| return converter_name, converted_params, converted_settings | |||
| @staticmethod | |||
| def _operation_name_in_ms(*args, **kwargs): | |||
| raise NotImplementedError | |||
| @staticmethod | |||
| def _convert_params(params, weights): | |||
| def _convert_params(**kwargs): | |||
| raise NotImplementedError | |||
| @staticmethod | |||
| def _convert_trained_weights(**kwargs): | |||
| raise NotImplementedError | |||
| @staticmethod | |||
| def _convert_trained_weights(weights): | |||
| def _convert_settings(**kwargs): | |||
| raise NotImplementedError | |||
| @@ -25,7 +25,8 @@ class BatchNormMapper(ONNXToMindSporeMapper): | |||
| return f"nn.BatchNorm{dim}d" | |||
| @staticmethod | |||
| def _convert_params(params, weights): | |||
| def _convert_params(**kwargs): | |||
| params = kwargs['params'] | |||
| return { | |||
| 'num_features': params['output_shape'][1], | |||
| 'eps': params['epsilon'], | |||
| @@ -33,7 +34,9 @@ class BatchNormMapper(ONNXToMindSporeMapper): | |||
| } | |||
| @staticmethod | |||
| def _convert_trained_weights(weights): | |||
| if weights: | |||
| pass | |||
| def _convert_trained_weights(**kwargs): | |||
| return dict() | |||
| @staticmethod | |||
| def _convert_settings(**kwargs): | |||
| return dict() | |||
| @@ -27,7 +27,10 @@ class ConvMapper(ONNXToMindSporeMapper): | |||
| return f"nn.Conv{dim}d" | |||
| @staticmethod | |||
| def _convert_params(params, weights): | |||
| def _convert_params(**kwargs): | |||
| weights = kwargs['weights'] | |||
| params = kwargs['params'] | |||
| weight = weights['weight'].numpy() | |||
| weight = np.transpose(weight, list(range(2, weight.ndim)) + [1, 0]) | |||
| if isinstance(params['dilations'], list): | |||
| @@ -46,7 +49,7 @@ class ConvMapper(ONNXToMindSporeMapper): | |||
| kernel_size = kernel_size[0] | |||
| else: | |||
| kernel_size = tuple(kernel_size) | |||
| pad_mode, padding = ConvMapper._convert_padding(params) | |||
| pad_mode, padding = ConvMapper._convert_padding(params=params) | |||
| return { | |||
| 'in_channels': in_channels, | |||
| 'out_channels': out_channels, | |||
| @@ -58,13 +61,13 @@ class ConvMapper(ONNXToMindSporeMapper): | |||
| 'group': params['group']} | |||
| @staticmethod | |||
| def _convert_trained_weights(weights): | |||
| if weights: | |||
| pass | |||
| def _convert_trained_weights(**kwargs): | |||
| return dict() | |||
| @staticmethod | |||
| def _convert_padding(params): | |||
| def _convert_padding(**kwargs): | |||
| """Convert padding.""" | |||
| params = kwargs['params'] | |||
| if sum(params['pads']) == 0: | |||
| return '\"valid\"', 0 | |||
| pads_onnx = params['pads'] | |||
| @@ -73,3 +76,7 @@ class ConvMapper(ONNXToMindSporeMapper): | |||
| for num_begin, num_end in zip(pads_onnx[:half_index], pads_onnx[half_index:]): | |||
| padding += [num_begin, num_end] | |||
| return '\"pad\"', tuple(padding) | |||
| @staticmethod | |||
| def _convert_settings(**kwargs): | |||
| return dict() | |||
| @@ -24,7 +24,8 @@ class DenseMapper(ONNXToMindSporeMapper): | |||
| return "nn.Dense" | |||
| @staticmethod | |||
| def _convert_params(params, weights): | |||
| def _convert_params(**kwargs): | |||
| weights = kwargs['weights'] | |||
| has_bias = bool('bias' in weights) | |||
| weight = weights['weight'].numpy().transpose() | |||
| in_channels, out_channels = weight.shape | |||
| @@ -35,7 +36,9 @@ class DenseMapper(ONNXToMindSporeMapper): | |||
| } | |||
| @staticmethod | |||
| def _convert_trained_weights(weights): | |||
| if weights: | |||
| pass | |||
| def _convert_trained_weights(**kwargs): | |||
| return dict() | |||
| @staticmethod | |||
| def _convert_settings(**kwargs): | |||
| return dict() | |||
| @@ -24,13 +24,13 @@ class FlattenMapper(ONNXToMindSporeMapper): | |||
| return "nn.Flatten" | |||
| @staticmethod | |||
| def _convert_params(params, weights): | |||
| if params: | |||
| pass | |||
| def _convert_params(**kwargs): | |||
| return dict() | |||
| @staticmethod | |||
| def _convert_trained_weights(weights): | |||
| if weights: | |||
| pass | |||
| def _convert_trained_weights(**kwargs): | |||
| return dict() | |||
| @staticmethod | |||
| def _convert_settings(**kwargs): | |||
| return dict() | |||
| @@ -30,7 +30,8 @@ class GlobalPoolMapper(ONNXToMindSporeMapper): | |||
| return op_name.format(dim) | |||
| @staticmethod | |||
| def _convert_params(params, weights): | |||
| def _convert_params(**kwargs): | |||
| params = kwargs['params'] | |||
| dim = 1 if len(params['input_shape']) == 3 else 2 | |||
| if dim == 1: | |||
| kernel_size = params['input_shape'][-1] // params['output_shape'][-1] | |||
| @@ -43,7 +44,9 @@ class GlobalPoolMapper(ONNXToMindSporeMapper): | |||
| } | |||
| @staticmethod | |||
| def _convert_trained_weights(weights): | |||
| if weights: | |||
| pass | |||
| def _convert_trained_weights(**kwargs): | |||
| return dict() | |||
| @staticmethod | |||
| def _convert_settings(**kwargs): | |||
| return dict() | |||
| @@ -24,17 +24,18 @@ class PadMapper(ONNXToMindSporeMapper): | |||
| return "nn.Pad" | |||
| @staticmethod | |||
| def _convert_params(params, weights): | |||
| def _convert_params(**kwargs): | |||
| params = kwargs['params'] | |||
| if params['mode'] == 'constant': | |||
| if params['value'] == 0: | |||
| mode = '\"CONSTANT\"' | |||
| else: | |||
| msg = f"[NOT support value is NOT 0]\"CONSTANT\"" | |||
| msg = "{UNSUPPORTED: value is NOT 0}\"CONSTANT\"" | |||
| mode = msg | |||
| elif params['mode'] == 'reflect': | |||
| mode = '\"REFLECT\"' | |||
| else: | |||
| msg = f"[NOT support {params['mode']}]\"UNKNOWN\"" | |||
| msg = f"{{UNSUPPORTED: \"{params['mode']}\"}}\"UNKNOWN\"" | |||
| mode = msg | |||
| pads_onnx = params['pads'] | |||
| half_index = len(pads_onnx) // 2 | |||
| @@ -44,7 +45,9 @@ class PadMapper(ONNXToMindSporeMapper): | |||
| 'mode': mode} | |||
| @staticmethod | |||
| def _convert_trained_weights(weights): | |||
| if weights: | |||
| pass | |||
| def _convert_trained_weights(**kwargs): | |||
| return dict() | |||
| @staticmethod | |||
| def _convert_settings(**kwargs): | |||
| return dict() | |||
| @@ -29,7 +29,8 @@ class PoolMapper(ONNXToMindSporeMapper): | |||
| return op_name.format(dim) | |||
| @staticmethod | |||
| def _convert_params(params, weights): | |||
| def _convert_params(**kwargs): | |||
| params = kwargs['params'] | |||
| transformed_params = dict() | |||
| transformed_params["kernel_size"] = tuple(params['kernel_shape']) | |||
| transformed_params["stride"] = tuple(params['strides']) | |||
| @@ -43,7 +44,9 @@ class PoolMapper(ONNXToMindSporeMapper): | |||
| return transformed_params | |||
| @staticmethod | |||
| def _convert_trained_weights(weights): | |||
| if weights: | |||
| pass | |||
| def _convert_trained_weights(**kwargs): | |||
| return dict() | |||
| @staticmethod | |||
| def _convert_settings(**kwargs): | |||
| return dict() | |||
| @@ -21,16 +21,28 @@ class ReLUMapper(ONNXToMindSporeMapper): | |||
| @staticmethod | |||
| def _operation_name_in_ms(*args, **kwargs): | |||
| return "nn.ReLU" | |||
| if not kwargs.get('params'): | |||
| name = "nn.ReLU" | |||
| else: | |||
| params = kwargs['params'] | |||
| max_clip = params['max'] if params.get('max') else 0 | |||
| min_clip = params['min'] if params.get('min') else 0 | |||
| if max_clip == 6 and min_clip == 0: | |||
| name = "nn.ReLU6" | |||
| elif max_clip == min_clip == 0: | |||
| name = "nn.ReLU" | |||
| else: | |||
| name = None | |||
| return name | |||
| @staticmethod | |||
| def _convert_params(params, weights): | |||
| if params: | |||
| pass | |||
| def _convert_params(**kwargs): | |||
| return dict() | |||
| @staticmethod | |||
| def _convert_trained_weights(weights): | |||
| if weights: | |||
| pass | |||
| def _convert_trained_weights(**kwargs): | |||
| return dict() | |||
| @staticmethod | |||
| def _convert_settings(**kwargs): | |||
| return dict() | |||
| @@ -24,13 +24,13 @@ class AddMapper(ONNXToMindSporeMapper): | |||
| return "P.TensorAdd" | |||
| @staticmethod | |||
| def _convert_params(params, weights): | |||
| if params: | |||
| pass | |||
| def _convert_params(**kwargs): | |||
| return dict() | |||
| @staticmethod | |||
| def _convert_trained_weights(weights): | |||
| if weights: | |||
| pass | |||
| def _convert_trained_weights(**kwargs): | |||
| return dict() | |||
| @staticmethod | |||
| def _convert_settings(**kwargs): | |||
| return dict() | |||
| @@ -0,0 +1,39 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from mindinsight.mindconverter.graph_based_converter.constant import InputType | |||
| from ...base import ONNXToMindSporeMapper | |||
| class ConcatMapper(ONNXToMindSporeMapper): | |||
| """Concat mapper.""" | |||
| @staticmethod | |||
| def _operation_name_in_ms(*args, **kwargs): | |||
| return "P.Concat" | |||
| @staticmethod | |||
| def _convert_params(**kwargs): | |||
| params = kwargs['params'] | |||
| return {'axis': params['axis']} | |||
| @staticmethod | |||
| def _convert_trained_weights(**kwargs): | |||
| return dict() | |||
| @staticmethod | |||
| def _convert_settings(**kwargs): | |||
| input_type = InputType.LIST.value | |||
| return {'input_type': input_type} | |||
| @@ -0,0 +1,43 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from ...base import ONNXToMindSporeMapper | |||
| class ReduceMeanMapper(ONNXToMindSporeMapper): | |||
| """ReduceMean mapper.""" | |||
| @staticmethod | |||
| def _operation_name_in_ms(*args, **kwargs): | |||
| return "P.ReduceMean" | |||
| @staticmethod | |||
| def _convert_params(**kwargs): | |||
| params = kwargs['params'] | |||
| keep_dims = not params['keepdims'] == 0 | |||
| return {'keep_dims': keep_dims} | |||
| @staticmethod | |||
| def _convert_trained_weights(**kwargs): | |||
| return dict() | |||
| @staticmethod | |||
| def _convert_settings(**kwargs): | |||
| params = kwargs['params'] | |||
| if params.get('axes'): | |||
| axis = params['axes'][0] if len(params['axes']) == 1 else tuple(params['axes']) | |||
| else: | |||
| axis = tuple() | |||
| return {'values': {'axis': axis}} | |||
| @@ -8,5 +8,8 @@ | |||
| "onnx::GlobalAveragePool": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.global_pool_mapper.GlobalPoolMapper", | |||
| "onnx::Flatten": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.flatten_mapper.FlattenMapper", | |||
| "onnx::Add": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.add_mapper.AddMapper", | |||
| "onnx::Pad": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.pad_mapper.PadMapper" | |||
| "onnx::Pad": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.pad_mapper.PadMapper", | |||
| "onnx::ReduceMean": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.reduce_mean_mapper.ReduceMeanMapper", | |||
| "onnx::Concat": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.concat_mapper.ConcatMapper", | |||
| "onnx::Clip": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.relu_mapper.ReLUMapper" | |||
| } | |||
| @@ -96,6 +96,21 @@ class ReportGenerator(metaclass=abc.ABCMeta): | |||
| f"[UnConvert] '{unconverted_operator_name}' didn't convert." | |||
| return content | |||
| @staticmethod | |||
| def _get_unsupported_params(num_line, code_line): | |||
| """Get unsupported params in converted operator.""" | |||
| if 'UNSUPPORTED' in code_line: | |||
| unsupported_params = re.findall(r"(.*).*[=][{]SUPPORTED", code_line) | |||
| unsupported_msg = re.findall(r".*UNSUPPORTED[:] (.*)[}]", code_line) | |||
| location = [f"{num_line + 1}", f"{code_line.index('UNSUPPORTED') + 1}"] | |||
| unsupported_params_info = \ | |||
| f"line {':'.join(location)}: " \ | |||
| f"[Unsupported params] {unsupported_params}: {unsupported_msg}." | |||
| else: | |||
| unsupported_params_info = None | |||
| return unsupported_params_info | |||
| def gen_report(self, code: str): | |||
| """ | |||
| Generate report. | |||
| @@ -111,6 +126,7 @@ class ReportGenerator(metaclass=abc.ABCMeta): | |||
| num_unconverted_operator = 0 | |||
| num_converted_operator = 0 | |||
| converted_operator = None | |||
| self._content = self._extra['start'] | |||
| for num_line in range(0, num_all_lines): | |||
| code_line = code_lines[num_line] | |||
| @@ -132,6 +148,14 @@ class ReportGenerator(metaclass=abc.ABCMeta): | |||
| if converted_operator: | |||
| num_converted_operator += 1 | |||
| info_unsupported_params = self._get_unsupported_params( | |||
| num_line, | |||
| code_line | |||
| ) | |||
| if info_unsupported_params: | |||
| self._content = f"{NEW_LINE}".join((self._content, | |||
| info_unsupported_params)) | |||
| self._content = f"{NEW_LINE}".join((self._content, self._extra['end'])) | |||
| converted_rate = \ | |||
| @@ -289,12 +289,16 @@ class GraphNode(abc.ABC): | |||
| self._op_in_ms = None | |||
| # Params in mindspore. | |||
| self._params_in_ms = dict() | |||
| # Settings in mindspore. | |||
| self._settings_in_ms = dict() | |||
| # Node type of current node, e.g. class, module, operation. | |||
| self._node_type = None | |||
| # Tag name on tree. | |||
| self._tag_on_tree = None | |||
| # Function, class or operation needed args. | |||
| self._args_in_code = dict() | |||
| # Operation needed settings. | |||
| self._settings_in_code = dict() | |||
| # Variable name declared in init block. | |||
| self._variable_name = None | |||
| # Output variable name declared in construct block. | |||
| @@ -364,6 +368,27 @@ class GraphNode(abc.ABC): | |||
| """ | |||
| self._args_in_code = args | |||
| @property | |||
| def settings_in_code(self): | |||
| """ | |||
| Settings in code. | |||
| Returns: | |||
| dict, settings. | |||
| """ | |||
| return self._settings_in_code | |||
| @settings_in_code.setter | |||
| def settings_in_code(self, settings): | |||
| """ | |||
| Settings in code. | |||
| Args: | |||
| settings(dict): Settings. | |||
| """ | |||
| self._settings_in_code = settings | |||
| @property | |||
| def input_shape(self): | |||
| """ | |||
| @@ -567,14 +592,16 @@ class GraphNode(abc.ABC): | |||
| params.update({"input_shape": self.input_shape, | |||
| "output_shape": self.output_shape}) | |||
| op_name_in_mindspore, ms_params = mapper.convert(op_name=self.op_name, | |||
| params=params, | |||
| weights=self._weight) | |||
| op_name_in_mindspore, ms_params, ms_settings = mapper.convert(op_name=self.op_name, | |||
| params=params, | |||
| weights=self._weight) | |||
| if op_name_in_mindspore: | |||
| self._op_in_ms = op_name_in_mindspore | |||
| self._params_in_ms = ms_params | |||
| self._settings_in_ms = ms_settings | |||
| else: | |||
| self._op_in_ms = self._op_name | |||
| self._params_in_ms = self._op_params | |||
| self._settings_in_ms = dict() | |||
| return self._op_in_ms, self._params_in_ms | |||
| return self._op_in_ms, self._params_in_ms, self._settings_in_ms | |||
| @@ -18,8 +18,9 @@ from copy import deepcopy | |||
| from .base import GraphNode | |||
| from ..constant import NodeType, SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, \ | |||
| SEPARATOR_IN_ONNX_OP | |||
| SEPARATOR_IN_ONNX_OP, InputType | |||
| from ..mapper.base import Mapper | |||
| from ...common.exceptions import NodeInputTypeNotSupport | |||
| class PyTorchGraphNode(GraphNode): | |||
| @@ -186,6 +187,8 @@ class PyTorchGraphNode(GraphNode): | |||
| self._opt_var_name = output_var | |||
| args = self.args_in_code | |||
| settings = self.settings_in_code | |||
| if self._node_type == NodeType.OPERATION.value and not self.convert_successful(): | |||
| args.update({"input_shape": self.input_shape, | |||
| "output_shape": self.output_shape}) | |||
| @@ -193,16 +196,49 @@ class PyTorchGraphNode(GraphNode): | |||
| if self._node_type == NodeType.OPERATION.value: | |||
| expr = ", ".join([f"{k.replace(f'_{self._variable_name}', '')}={v}" | |||
| for k, v in args.items()]) | |||
| ipt_args_settings_in_construct = self._generate_ipt_args_settings_in_construct(ipt_args_in_construct, | |||
| settings) | |||
| else: | |||
| # When it's type is module, class or func, | |||
| # it's not necessary to replace var. | |||
| expr = ", ".join([f"{k.replace(f'_{self._variable_name}', '')}={v}" | |||
| for k, v in args.items()]) | |||
| ipt_args_settings_in_construct = ipt_args_in_construct | |||
| declare = f"self.{self._variable_name} = {operator}({expr})" | |||
| call = f"{self._opt_var_name} = self.{self._variable_name}({ipt_args_in_construct})" | |||
| call = f"{self._opt_var_name} = self.{self._variable_name}({ipt_args_settings_in_construct})" | |||
| return declare, call | |||
| @staticmethod | |||
| def _generate_ipt_args_settings_in_construct(ipt_args_in_construct, settings): | |||
| """ | |||
| Generate input with args and settings in construct. | |||
| Args: | |||
| ipt_args_in_construct(str): input args in construct. | |||
| settings(dict): settings in operator. | |||
| """ | |||
| if settings.get('input_type'): | |||
| input_type = settings['input_type'] | |||
| if input_type == InputType.TENSOR.value: | |||
| ipt_args_settings_in_construct = ipt_args_in_construct | |||
| elif input_type == InputType.LIST.value: | |||
| ipt_args_settings_in_construct = f"({ipt_args_in_construct})" | |||
| else: | |||
| raise NodeInputTypeNotSupport(f"Input type[{input_type}] is not supported now.") | |||
| else: | |||
| ipt_args_settings_in_construct = ipt_args_in_construct | |||
| if settings.get('values'): | |||
| settings_value = settings['values'] | |||
| if settings_value: | |||
| settings_in_construct = ', '.join([f"{setting_val}" for _, setting_val in settings_value.items()]) | |||
| ipt_args_settings_in_construct = ', '.join((ipt_args_settings_in_construct, settings_in_construct)) | |||
| return ipt_args_settings_in_construct | |||
| def to_ir(self): | |||
| """ | |||
| No need to implement for now. | |||
| @@ -266,17 +302,20 @@ class PyTorchGraphNode(GraphNode): | |||
| self._args_in_code = dict() | |||
| for arg, value in args.items(): | |||
| self._args_in_code[self._get_arg_name(arg)] = value | |||
| return None, None | |||
| return None, None, None | |||
| if not self.transformed: | |||
| _, _ = super(PyTorchGraphNode, self).param_transform(mapper) | |||
| _, _, _ = super(PyTorchGraphNode, self).param_transform(mapper) | |||
| for arg, value in self._params_in_ms.items(): | |||
| self._args_in_code[self._get_arg_name(arg)] = value | |||
| for arg, value in self._settings_in_ms.items(): | |||
| self._settings_in_code[arg] = value | |||
| self.transformed = True | |||
| return self._op_in_ms, self._params_in_ms | |||
| return self._op_in_ms, self._params_in_ms, self._settings_in_ms | |||
| def froze_node_type_and_module_name(self, node_type, module_name): | |||
| """ | |||
| @@ -144,6 +144,18 @@ class TestHierarchicalTree: | |||
| if report_folder: | |||
| shutil.rmtree(report_folder) | |||
| @staticmethod | |||
| def _create_node(key, val, weight, input_shape, output_shape): | |||
| """Create node.""" | |||
| node = PyTorchGraphNode(weight=weight) | |||
| node.add_input_and_output_shape(input_shape, output_shape) | |||
| node.tag = key.split('/')[-1] if len(key.split('/')) > 1 else key | |||
| node.op_name = val['op_name'] if val.get('op_name') else None | |||
| node.precursor_nodes = val['precursor_nodes'] if val.get('precursor_nodes') else [] | |||
| node.successor_nodes = val['successor_nodes'] if val.get('successor_nodes') else [] | |||
| node.node_type = val['node_type'] if val.get('node_type') else None | |||
| return node | |||
| @staticmethod | |||
| def _create_tree(get_raw_params, params): | |||
| """Create tree.""" | |||
| @@ -154,13 +166,7 @@ class TestHierarchicalTree: | |||
| get_raw_params.return_value = val['op_params'] if val.get('op_params') else dict() | |||
| weight = val['weight'] if val.get('weight') else None | |||
| node = PyTorchGraphNode(weight=weight) | |||
| node.add_input_and_output_shape(input_shape, output_shape) | |||
| node.tag = key.split('/')[-1] if len(key.split('/')) > 1 else key | |||
| node.op_name = val['op_name'] if val.get('op_name') else None | |||
| node.precursor_nodes = val['precursor_nodes'] if val.get('precursor_nodes') else [] | |||
| node.successor_nodes = val['successor_nodes'] if val.get('successor_nodes') else [] | |||
| node.node_type = val['node_type'] if val.get('node_type') else None | |||
| node = TestHierarchicalTree._create_node(key, val, weight, input_shape, output_shape) | |||
| tree.create_node( | |||
| tag=node.tag, | |||
| @@ -13,15 +13,60 @@ | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Test all operator mappers on transformation from pytorch to mindspore.""" | |||
| import numpy as np | |||
| import pytest | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| from tests.utils import mindspore | |||
| class TestMappers: | |||
| """Test Mappers.""" | |||
| @pytest.mark.parametrize('params', [{ | |||
| 'input': {'op_name': 'onnx::Conv', | |||
| 'params': {'dilations': [1, 1], | |||
| 'group': 1, | |||
| 'pads': [1, 2, 3, 4], | |||
| 'strides': [1, 1]}, | |||
| 'weights': {'weight': mindspore.Tensor(np.zeros([64, 3, 1, 1], dtype=np.int32))}}, | |||
| 'expected_output': {'converter_name': 'nn.Conv2d', | |||
| 'converted_params': {'in_channels': 3, | |||
| 'out_channels': 64, | |||
| 'kernel_size': (1, 1), | |||
| 'stride': (1, 1), | |||
| 'padding': (1, 3, 2, 4), | |||
| 'pad_mode': '\"pad\"', | |||
| 'dilation': (1, 1), | |||
| 'group': 1}, | |||
| 'converted_settings': dict()} | |||
| }, { | |||
| 'input': {'op_name': 'onnx::Conv', | |||
| 'params': {'dilations': [1, 1], | |||
| 'group': 1, | |||
| 'pads': [0, 0, 0, 0], | |||
| 'strides': [1, 1]}, | |||
| 'weights': {'weight': mindspore.Tensor(np.zeros([64, 3, 2, 2], dtype=np.int32))}}, | |||
| 'expected_output': {'converter_name': 'nn.Conv2d', | |||
| 'converted_params': {'in_channels': 3, | |||
| 'out_channels': 64, | |||
| 'kernel_size': (2, 2), | |||
| 'stride': (1, 1), | |||
| 'padding': 0, | |||
| 'pad_mode': '\"valid\"', | |||
| 'dilation': (1, 1), | |||
| 'group': 1}, | |||
| 'converted_settings': dict()} | |||
| }, { | |||
| 'input': {'op_name': 'onnx::Gemm', | |||
| 'params': dict(), | |||
| 'weights': {'weight': mindspore.Tensor(np.zeros([10, 3], dtype=np.int32)), | |||
| 'bias': mindspore.Tensor(np.zeros([10, 1], dtype=np.int32))}}, | |||
| 'expected_output': {'converter_name': 'nn.Dense', | |||
| 'converted_params': {'in_channels': 3, | |||
| 'out_channels': 10, | |||
| 'has_bias': True}, | |||
| 'converted_settings': dict()} | |||
| }, { | |||
| 'input': {'op_name': 'onnx::BatchNormalization', | |||
| 'params': {'epsilon': 1e-5, | |||
| 'momentum': 0.9, | |||
| @@ -30,13 +75,15 @@ class TestMappers: | |||
| 'expected_output': {'converter_name': 'nn.BatchNorm2d', | |||
| 'converted_params': {'num_features': 6, | |||
| 'eps': 1e-5, | |||
| 'momentum': 0.9}} | |||
| 'momentum': 0.9}, | |||
| 'converted_settings': dict()} | |||
| }, { | |||
| 'input': {'op_name': 'onnx::Relu', | |||
| 'params': dict(), | |||
| 'weights': dict()}, | |||
| 'expected_output': {'converter_name': 'nn.ReLU', | |||
| 'converted_params': dict()} | |||
| 'converted_params': dict(), | |||
| 'converted_settings': dict()} | |||
| }, { | |||
| 'input': {'op_name': 'onnx::MaxPool', | |||
| 'params': {'kernel_shape': [3, 3], | |||
| @@ -46,7 +93,8 @@ class TestMappers: | |||
| 'expected_output': {'converter_name': 'nn.MaxPool2d', | |||
| 'converted_params': {'kernel_size': (3, 3), | |||
| 'stride': (2, 2), | |||
| 'pad_mode': '"same"'}} | |||
| 'pad_mode': '"same"'}, | |||
| 'converted_settings': dict()} | |||
| }, { | |||
| 'input': {'op_name': 'onnx::AveragePool', | |||
| 'params': {'kernel_shape': [3, 3], | |||
| @@ -56,31 +104,120 @@ class TestMappers: | |||
| 'expected_output': {'converter_name': 'nn.AvgPool2d', | |||
| 'converted_params': {'kernel_size': (3, 3), | |||
| 'stride': (2, 2), | |||
| 'pad_mode': '"same"'}} | |||
| 'pad_mode': '"same"'}, | |||
| 'converted_settings': dict()} | |||
| }, { | |||
| 'input': {'op_name': 'onnx::GlobalAveragePool', | |||
| 'params': {'input_shape': (1, 3, 10, 10), | |||
| 'output_shape': (1, 3, 1, 1)}, | |||
| 'weights': ''}, | |||
| 'expected_output': {'converter_name': 'nn.AvgPool2d', | |||
| 'converted_params': {'kernel_size': (10, 10)}} | |||
| 'converted_params': {'kernel_size': (10, 10)}, | |||
| 'converted_settings': dict()} | |||
| }, { | |||
| 'input': {'op_name': 'onnx::Flatten', | |||
| 'params': dict(), | |||
| 'weights': dict()}, | |||
| 'expected_output': {'converter_name': 'nn.Flatten', | |||
| 'converted_params': dict()} | |||
| 'converted_params': dict(), | |||
| 'converted_settings': dict()} | |||
| }, { | |||
| 'input': {'op_name': 'onnx::Add', | |||
| 'params': dict(), | |||
| 'weights': dict()}, | |||
| 'expected_output': {'converter_name': 'P.TensorAdd', | |||
| 'converted_params': dict()} | |||
| 'converted_params': dict(), | |||
| 'converted_settings': dict()} | |||
| }, { | |||
| 'input': {'op_name': 'onnx::Pad', | |||
| 'params': {'pads': [0, 1, 2, 3], | |||
| 'value': 0, | |||
| 'mode': 'constant'}, | |||
| 'weights': dict()}, | |||
| 'expected_output': {'converter_name': 'nn.Pad', | |||
| 'converted_params': {'paddings': ((0, 2), (1, 3)), | |||
| 'mode': '\"CONSTANT\"'}, | |||
| 'converted_settings': dict()} | |||
| }, { | |||
| 'input': {'op_name': 'onnx::Pad', | |||
| 'params': {'pads': [0, 1, 2, 3], | |||
| 'mode': 'reflect'}, | |||
| 'weights': dict()}, | |||
| 'expected_output': {'converter_name': 'nn.Pad', | |||
| 'converted_params': {'paddings': ((0, 2), (1, 3)), | |||
| 'mode': '\"REFLECT\"'}, | |||
| 'converted_settings': dict()} | |||
| }, { | |||
| 'input': {'op_name': 'onnx::Pad', | |||
| 'params': {'pads': [0, 1, 2, 3], | |||
| 'value': 1, | |||
| 'mode': 'constant'}, | |||
| 'weights': dict()}, | |||
| 'expected_output': {'converter_name': 'nn.Pad', | |||
| 'converted_params': {'paddings': ((0, 2), (1, 3)), | |||
| 'mode': '{UNSUPPORTED: value is NOT 0}\"CONSTANT\"'}, | |||
| 'converted_settings': dict()} | |||
| }, { | |||
| 'input': {'op_name': 'onnx::Pad', | |||
| 'params': {'pads': [0, 1, 2, 3], | |||
| 'mode': 'edge'}, | |||
| 'weights': dict()}, | |||
| 'expected_output': {'converter_name': 'nn.Pad', | |||
| 'converted_params': {'paddings': ((0, 2), (1, 3)), | |||
| 'mode': '{UNSUPPORTED: \"edge\"}\"UNKNOWN\"'}, | |||
| 'converted_settings': dict()} | |||
| }, { | |||
| 'input': {'op_name': 'onnx::ReduceMean', | |||
| 'params': {'keepdims': 0, | |||
| 'axes': [1, 2]}, | |||
| 'weights': dict()}, | |||
| 'expected_output': {'converter_name': 'P.ReduceMean', | |||
| 'converted_params': {'keep_dims': False}, | |||
| 'converted_settings': {'values': {'axis': (1, 2)}}} | |||
| }, { | |||
| 'input': {'op_name': 'onnx::ReduceMean', | |||
| 'params': {'keepdims': 1, | |||
| 'axes': [1]}, | |||
| 'weights': dict()}, | |||
| 'expected_output': {'converter_name': 'P.ReduceMean', | |||
| 'converted_params': {'keep_dims': True}, | |||
| 'converted_settings': {'values': {'axis': 1}}} | |||
| }, { | |||
| 'input': {'op_name': 'onnx::Concat', | |||
| 'params': {'axis': 0}, | |||
| 'weights': dict()}, | |||
| 'expected_output': {'converter_name': 'P.Concat', | |||
| 'converted_params': {'axis': 0}, | |||
| 'converted_settings': {'input_type': "list"}} | |||
| }, { | |||
| 'input': {'op_name': 'onnx::Clip', | |||
| 'params': {'max': 6, | |||
| 'min': 0}, | |||
| 'weights': dict()}, | |||
| 'expected_output': {'converter_name': 'nn.ReLU6', | |||
| 'converted_params': dict(), | |||
| 'converted_settings': dict()} | |||
| }, { | |||
| 'input': {'op_name': 'onnx::Clip', | |||
| 'params': dict(), | |||
| 'weights': dict()}, | |||
| 'expected_output': {'converter_name': 'nn.ReLU', | |||
| 'converted_params': dict(), | |||
| 'converted_settings': dict()} | |||
| }, { | |||
| 'input': {'op_name': 'onnx::Clip', | |||
| 'params': {'max': 3, | |||
| 'min': 2}, | |||
| 'weights': dict()}, | |||
| 'expected_output': {'converter_name': None, | |||
| 'converted_params': dict(), | |||
| 'converted_settings': dict()} | |||
| }]) | |||
| def test_mapper(self, params): | |||
| """Test mapper function.""" | |||
| mapper = ONNXToMindSporeMapper() | |||
| converter_name, converted_params = \ | |||
| converter_name, converted_params, converted_settings = \ | |||
| mapper.convert(params['input']['op_name'], params['input']['params'], params['input']['weights']) | |||
| assert params['expected_output']['converter_name'] == converter_name | |||
| assert params['expected_output']['converted_params'] == converted_params | |||
| assert params['expected_output']['converted_settings'] == converted_settings | |||
| @@ -28,3 +28,7 @@ class Tensor: | |||
| def __repr__(self): | |||
| return str(self.asnumpy()) | |||
| def numpy(self): | |||
| """Get value in numpy format, the torch format.""" | |||
| return np.array(self._value) | |||