Merge pull request !769 from moran/network_validation
@@ -25,6 +25,7 @@ class ConverterErrors(ScriptConverterErrors):
     SCRIPT_NOT_SUPPORT = 1
     NODE_TYPE_NOT_SUPPORT = 2
     CODE_SYNTAX_ERROR = 3
+    NODE_INPUT_TYPE_NOT_SUPPORT = 4
 
 
 class ScriptNotSupport(MindInsightException):
@@ -52,3 +53,12 @@ class CodeSyntaxError(MindInsightException):
         super(CodeSyntaxError, self).__init__(ConverterErrors.CODE_SYNTAX_ERROR,
                                               msg,
                                               http_code=400)
+
+
+class NodeInputTypeNotSupport(MindInsightException):
+    """The node input type NOT support error."""
+
+    def __init__(self, msg):
+        super(NodeInputTypeNotSupport, self).__init__(ConverterErrors.NODE_INPUT_TYPE_NOT_SUPPORT,
+                                                      msg,
+                                                      http_code=400)
@@ -40,3 +40,9 @@ class NodeType(Enum):
     CLASS = "class"
     FUNC = "func"
     INPUTS = "DataInput"
+
+
+@unique
+class InputType(Enum):
+    TENSOR = "tensor"
+    LIST = "list"
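The InputType values are stored as plain strings in a mapper's settings dict and mapped back onto the enum at code-generation time. A minimal, self-contained sketch of that round trip (the lookup at the end is illustrative, not part of this change):

    from enum import Enum, unique

    @unique
    class InputType(Enum):
        TENSOR = "tensor"
        LIST = "list"

    # Mappers store the string value; the generator compares it against the enum.
    setting = {'input_type': InputType.LIST.value}
    assert InputType(setting['input_type']) is InputType.LIST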
@@ -633,6 +633,7 @@ class HierarchicalTree(Tree):
         """
         # All args and value pair in current node module.
         module_args = dict()
+        module_settings = dict()
 
         module_key = self.hash_key(node)
         created = False
@@ -658,6 +659,7 @@ class HierarchicalTree(Tree):
                 nd_inst.data.param_transform(mapper)
                 module_args.update(nd_inst.data.args_in_code)
+                module_settings.update(nd_inst.data.settings_in_code)
 
                 if not created:
                     self._module_vars[module_key].append(nd_inst.data.variable_name)
@@ -35,6 +35,7 @@ with open(OPERATION_TABLE) as file:
 GET_OP_NAME = "_operation_name_in_ms"
 GET_OP_PARAMS = "_convert_params"
 GET_OP_WEIGHTS = "_convert_trained_weights"
+GET_OP_SETTINGS = "_convert_settings"
 
 
 class Mapper(metaclass=abc.ABCMeta):
@@ -47,14 +48,19 @@ class Mapper(metaclass=abc.ABCMeta):
     @staticmethod
     @abc.abstractmethod
-    def _convert_params(params, weights):
+    def _convert_params(**kwargs):
         """Convert third party operation's param into MindSpore operation."""
 
     @staticmethod
     @abc.abstractmethod
-    def _convert_trained_weights(weights):
+    def _convert_trained_weights(**kwargs):
         """Convert third party operation's weights into MindSpore operation."""
 
+    @staticmethod
+    @abc.abstractmethod
+    def _convert_settings(**kwargs):
+        """Convert third party operation's params into MindSpore OP operator."""
+
     @classmethod
     @abc.abstractmethod
     def convert(cls, op_name: str, params: Dict, weights: Dict = None):
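Under the new contract every converter hook receives keyword arguments only, so concrete mappers unpack params/weights themselves, and hooks an operator does not need simply return an empty dict. A sketch of a minimal mapper against this interface (SigmoidMapper is a made-up example, not part of this change):

    from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper

    class SigmoidMapper(ONNXToMindSporeMapper):
        """Hypothetical mapper built on the keyword-only hooks."""

        @staticmethod
        def _operation_name_in_ms(*args, **kwargs):
            return "nn.Sigmoid"

        @staticmethod
        def _convert_params(**kwargs):
            # params/weights arrive as keywords; Sigmoid needs neither.
            return dict()

        @staticmethod
        def _convert_trained_weights(**kwargs):
            return dict()

        @staticmethod
        def _convert_settings(**kwargs):
            return dict()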
@@ -75,13 +81,13 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC):
             weights (dict): Weights in onnx.
 
         Returns:
-            Tuple[str, dict], operation name and params.
+            Tuple[str, dict, dict], operation name, params and settings.
         """
         global TABLE
         module_name = TABLE.get(op_name)
 
         if not module_name:
-            return None, dict()
+            return None, dict(), dict()
 
         pos = module_name.rfind(".")
         try:
@@ -90,32 +96,38 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC):
             op_name_converter = getattr(converter, GET_OP_NAME)
             params_converter = getattr(converter, GET_OP_PARAMS)
             weights_converter = getattr(converter, GET_OP_WEIGHTS)
+            settings_converter = getattr(converter, GET_OP_SETTINGS)
         except (ModuleNotFoundError,) as e:
             # If mapper can not be found, then skip it.
             err_msg = f"Converting {op_name} failed, see {str(e)}"
             log.error(err_msg)
-            return None, dict()
+            return None, dict(), dict()
 
         try:
             converter_name = op_name_converter(params=params, weights=weights, op_name=op_name)
-            converted_params = params_converter(params, weights)
-            converted_weights = weights_converter(weights) if weights else dict()
+            converted_params = params_converter(params=params, weights=weights)
+            converted_weights = weights_converter(weights=weights) if weights else dict()
             converted_params.update(converted_weights)
+            converted_settings = settings_converter(params=params)
         except (AttributeError, KeyError, ValueError, TypeError, IndexError) as e:
             err_msg = f"Converting {op_name} failed, see {str(e)}"
             log.error(err_msg)
-            return None, dict()
-        return converter_name, converted_params
+            return None, dict(), dict()
+        return converter_name, converted_params, converted_settings
 
     @staticmethod
     def _operation_name_in_ms(*args, **kwargs):
         raise NotImplementedError
 
     @staticmethod
-    def _convert_params(params, weights):
+    def _convert_params(**kwargs):
+        raise NotImplementedError
+
+    @staticmethod
+    def _convert_trained_weights(**kwargs):
         raise NotImplementedError
 
     @staticmethod
-    def _convert_trained_weights(weights):
+    def _convert_settings(**kwargs):
         raise NotImplementedError
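Since every exit path of convert now yields a (name, params, settings) triple, all call sites must unpack three values. For example, using the onnx::Concat inputs from the test cases added below:

    from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper

    mapper = ONNXToMindSporeMapper()
    name, params, settings = mapper.convert('onnx::Concat', {'axis': 0}, dict())
    assert (name, params, settings) == ('P.Concat', {'axis': 0}, {'input_type': 'list'})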
@@ -25,7 +25,8 @@ class BatchNormMapper(ONNXToMindSporeMapper):
         return f"nn.BatchNorm{dim}d"
 
     @staticmethod
-    def _convert_params(params, weights):
+    def _convert_params(**kwargs):
+        params = kwargs['params']
         return {
             'num_features': params['output_shape'][1],
             'eps': params['epsilon'],
@@ -33,7 +34,9 @@ class BatchNormMapper(ONNXToMindSporeMapper):
         }
 
     @staticmethod
-    def _convert_trained_weights(weights):
-        if weights:
-            pass
+    def _convert_trained_weights(**kwargs):
+        return dict()
+
+    @staticmethod
+    def _convert_settings(**kwargs):
         return dict()
@@ -27,7 +27,10 @@ class ConvMapper(ONNXToMindSporeMapper):
         return f"nn.Conv{dim}d"
 
     @staticmethod
-    def _convert_params(params, weights):
+    def _convert_params(**kwargs):
+        weights = kwargs['weights']
+        params = kwargs['params']
         weight = weights['weight'].numpy()
         weight = np.transpose(weight, list(range(2, weight.ndim)) + [1, 0])
         if isinstance(params['dilations'], list):
@@ -46,7 +49,7 @@ class ConvMapper(ONNXToMindSporeMapper):
             kernel_size = kernel_size[0]
         else:
             kernel_size = tuple(kernel_size)
-        pad_mode, padding = ConvMapper._convert_padding(params)
+        pad_mode, padding = ConvMapper._convert_padding(params=params)
         return {
             'in_channels': in_channels,
             'out_channels': out_channels,
@@ -58,13 +61,13 @@ class ConvMapper(ONNXToMindSporeMapper):
             'group': params['group']}
 
     @staticmethod
-    def _convert_trained_weights(weights):
-        if weights:
-            pass
+    def _convert_trained_weights(**kwargs):
         return dict()
 
     @staticmethod
-    def _convert_padding(params):
+    def _convert_padding(**kwargs):
+        """Convert padding."""
+        params = kwargs['params']
         if sum(params['pads']) == 0:
             return '\"valid\"', 0
         pads_onnx = params['pads']
@@ -73,3 +76,7 @@ class ConvMapper(ONNXToMindSporeMapper):
         for num_begin, num_end in zip(pads_onnx[:half_index], pads_onnx[half_index:]):
             padding += [num_begin, num_end]
         return '\"pad\"', tuple(padding)
+
+    @staticmethod
+    def _convert_settings(**kwargs):
+        return dict()
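ONNX lists pads as all begins followed by all ends ([top, left, bottom, right] for 2D), while MindSpore's nn.Conv2d expects (begin, end) pairs per axis, so _convert_padding interleaves the two halves. The same loop rerun standalone on the values used in the Conv test case below:

    pads_onnx = [1, 2, 3, 4]            # ONNX order: begins (H, W), then ends (H, W)
    half_index = len(pads_onnx) // 2
    padding = []
    for num_begin, num_end in zip(pads_onnx[:half_index], pads_onnx[half_index:]):
        padding += [num_begin, num_end]
    assert tuple(padding) == (1, 3, 2, 4)   # per-axis (begin, end) pairs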
@@ -24,7 +24,8 @@ class DenseMapper(ONNXToMindSporeMapper):
         return "nn.Dense"
 
     @staticmethod
-    def _convert_params(params, weights):
+    def _convert_params(**kwargs):
+        weights = kwargs['weights']
         has_bias = bool('bias' in weights)
         weight = weights['weight'].numpy().transpose()
         in_channels, out_channels = weight.shape
@@ -35,7 +36,9 @@ class DenseMapper(ONNXToMindSporeMapper):
         }
 
     @staticmethod
-    def _convert_trained_weights(weights):
-        if weights:
-            pass
+    def _convert_trained_weights(**kwargs):
+        return dict()
+
+    @staticmethod
+    def _convert_settings(**kwargs):
         return dict()
@@ -24,13 +24,13 @@ class FlattenMapper(ONNXToMindSporeMapper):
         return "nn.Flatten"
 
     @staticmethod
-    def _convert_params(params, weights):
-        if params:
-            pass
+    def _convert_params(**kwargs):
         return dict()
 
     @staticmethod
-    def _convert_trained_weights(weights):
-        if weights:
-            pass
+    def _convert_trained_weights(**kwargs):
+        return dict()
+
+    @staticmethod
+    def _convert_settings(**kwargs):
         return dict()
@@ -30,7 +30,8 @@ class GlobalPoolMapper(ONNXToMindSporeMapper):
         return op_name.format(dim)
 
     @staticmethod
-    def _convert_params(params, weights):
+    def _convert_params(**kwargs):
+        params = kwargs['params']
         dim = 1 if len(params['input_shape']) == 3 else 2
         if dim == 1:
             kernel_size = params['input_shape'][-1] // params['output_shape'][-1]
@@ -43,7 +44,9 @@ class GlobalPoolMapper(ONNXToMindSporeMapper):
         }
 
     @staticmethod
-    def _convert_trained_weights(weights):
-        if weights:
-            pass
+    def _convert_trained_weights(**kwargs):
+        return dict()
+
+    @staticmethod
+    def _convert_settings(**kwargs):
         return dict()
@@ -24,17 +24,18 @@ class PadMapper(ONNXToMindSporeMapper):
         return "nn.Pad"
 
     @staticmethod
-    def _convert_params(params, weights):
+    def _convert_params(**kwargs):
+        params = kwargs['params']
         if params['mode'] == 'constant':
             if params['value'] == 0:
                 mode = '\"CONSTANT\"'
             else:
-                msg = f"[NOT support value is NOT 0]\"CONSTANT\""
+                msg = "{UNSUPPORTED: value is NOT 0}\"CONSTANT\""
                 mode = msg
         elif params['mode'] == 'reflect':
             mode = '\"REFLECT\"'
         else:
-            msg = f"[NOT support {params['mode']}]\"UNKNOWN\""
+            msg = f"{{UNSUPPORTED: \"{params['mode']}\"}}\"UNKNOWN\""
             mode = msg
         pads_onnx = params['pads']
         half_index = len(pads_onnx) // 2
@@ -44,7 +45,9 @@ class PadMapper(ONNXToMindSporeMapper):
                 'mode': mode}
 
     @staticmethod
-    def _convert_trained_weights(weights):
-        if weights:
-            pass
+    def _convert_trained_weights(**kwargs):
+        return dict()
+
+    @staticmethod
+    def _convert_settings(**kwargs):
         return dict()
@@ -29,7 +29,8 @@ class PoolMapper(ONNXToMindSporeMapper):
         return op_name.format(dim)
 
     @staticmethod
-    def _convert_params(params, weights):
+    def _convert_params(**kwargs):
+        params = kwargs['params']
         transformed_params = dict()
         transformed_params["kernel_size"] = tuple(params['kernel_shape'])
         transformed_params["stride"] = tuple(params['strides'])
@@ -43,7 +44,9 @@ class PoolMapper(ONNXToMindSporeMapper):
         return transformed_params
 
     @staticmethod
-    def _convert_trained_weights(weights):
-        if weights:
-            pass
+    def _convert_trained_weights(**kwargs):
+        return dict()
+
+    @staticmethod
+    def _convert_settings(**kwargs):
         return dict()
@@ -21,16 +21,28 @@ class ReLUMapper(ONNXToMindSporeMapper):
     @staticmethod
     def _operation_name_in_ms(*args, **kwargs):
-        return "nn.ReLU"
+        if not kwargs.get('params'):
+            name = "nn.ReLU"
+        else:
+            params = kwargs['params']
+            max_clip = params['max'] if params.get('max') else 0
+            min_clip = params['min'] if params.get('min') else 0
+            if max_clip == 6 and min_clip == 0:
+                name = "nn.ReLU6"
+            elif max_clip == min_clip == 0:
+                name = "nn.ReLU"
+            else:
+                name = None
+        return name
 
     @staticmethod
-    def _convert_params(params, weights):
-        if params:
-            pass
+    def _convert_params(**kwargs):
         return dict()
 
     @staticmethod
-    def _convert_trained_weights(weights):
-        if weights:
-            pass
+    def _convert_trained_weights(**kwargs):
+        return dict()
+
+    @staticmethod
+    def _convert_settings(**kwargs):
         return dict()
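onnx::Clip carries optional min/max attributes, so the mapper folds the common (0, 6) clamp into nn.ReLU6, treats a missing or all-zero range as plain nn.ReLU, and returns None for any other range, which marks the node as unconverted. The decision logic extracted as a standalone function for illustration:

    def relu_name(params):
        # Mirrors ReLUMapper._operation_name_in_ms for onnx::Relu / onnx::Clip.
        if not params:
            return "nn.ReLU"
        max_clip = params['max'] if params.get('max') else 0
        min_clip = params['min'] if params.get('min') else 0
        if max_clip == 6 and min_clip == 0:
            return "nn.ReLU6"
        if max_clip == min_clip == 0:
            return "nn.ReLU"
        return None

    assert relu_name(dict()) == "nn.ReLU"
    assert relu_name({'max': 6, 'min': 0}) == "nn.ReLU6"
    assert relu_name({'max': 3, 'min': 2}) is None   # reported as unconverted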
@@ -24,13 +24,13 @@ class AddMapper(ONNXToMindSporeMapper):
         return "P.TensorAdd"
 
     @staticmethod
-    def _convert_params(params, weights):
-        if params:
-            pass
+    def _convert_params(**kwargs):
         return dict()
 
     @staticmethod
-    def _convert_trained_weights(weights):
-        if weights:
-            pass
+    def _convert_trained_weights(**kwargs):
+        return dict()
+
+    @staticmethod
+    def _convert_settings(**kwargs):
         return dict()
@@ -0,0 +1,39 @@
+# Copyright 2020 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Mapper module."""
+from mindinsight.mindconverter.graph_based_converter.constant import InputType
+from ...base import ONNXToMindSporeMapper
+
+
+class ConcatMapper(ONNXToMindSporeMapper):
+    """Concat mapper."""
+
+    @staticmethod
+    def _operation_name_in_ms(*args, **kwargs):
+        return "P.Concat"
+
+    @staticmethod
+    def _convert_params(**kwargs):
+        params = kwargs['params']
+        return {'axis': params['axis']}
+
+    @staticmethod
+    def _convert_trained_weights(**kwargs):
+        return dict()
+
+    @staticmethod
+    def _convert_settings(**kwargs):
+        input_type = InputType.LIST.value
+        return {'input_type': input_type}
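P.Concat takes one tuple of input tensors rather than separate tensor arguments, which is what the 'input_type': 'list' setting encodes: at generation time the construct-call inputs get wrapped in parentheses. A standalone rerun of that wrapping rule (variable names illustrative):

    ipt_args = "x_1, x_2"
    settings = {'input_type': 'list'}
    call_args = f"({ipt_args})" if settings.get('input_type') == 'list' else ipt_args
    assert call_args == "(x_1, x_2)"    # generated: out = self.concat((x_1, x_2))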
@@ -0,0 +1,43 @@
+# Copyright 2020 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Mapper module."""
+from ...base import ONNXToMindSporeMapper
+
+
+class ReduceMeanMapper(ONNXToMindSporeMapper):
+    """ReduceMean mapper."""
+
+    @staticmethod
+    def _operation_name_in_ms(*args, **kwargs):
+        return "P.ReduceMean"
+
+    @staticmethod
+    def _convert_params(**kwargs):
+        params = kwargs['params']
+        keep_dims = not params['keepdims'] == 0
+        return {'keep_dims': keep_dims}
+
+    @staticmethod
+    def _convert_trained_weights(**kwargs):
+        return dict()
+
+    @staticmethod
+    def _convert_settings(**kwargs):
+        params = kwargs['params']
+        if params.get('axes'):
+            axis = params['axes'][0] if len(params['axes']) == 1 else tuple(params['axes'])
+        else:
+            axis = tuple()
+        return {'values': {'axis': axis}}
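Unlike input_type, the 'values' setting injects extra positional arguments into the generated construct call: axis is not an __init__ parameter of P.ReduceMean, so it must follow the input tensor at call time. The axes handling worked through for the first ReduceMean test case below:

    params = {'keepdims': 0, 'axes': [1, 2]}
    keep_dims = not params['keepdims'] == 0              # False
    axes = params.get('axes') or []
    axis = axes[0] if len(axes) == 1 else tuple(axes)    # a single axis stays scalar
    assert (keep_dims, axis) == (False, (1, 2))
    # Generated call shape (illustrative): opt = self.reduce_mean(x, (1, 2))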
@@ -8,5 +8,8 @@
   "onnx::GlobalAveragePool": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.global_pool_mapper.GlobalPoolMapper",
   "onnx::Flatten": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.flatten_mapper.FlattenMapper",
   "onnx::Add": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.add_mapper.AddMapper",
-  "onnx::Pad": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.pad_mapper.PadMapper"
+  "onnx::Pad": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.pad_mapper.PadMapper",
+  "onnx::ReduceMean": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.reduce_mean_mapper.ReduceMeanMapper",
+  "onnx::Concat": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.concat_mapper.ConcatMapper",
+  "onnx::Clip": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.relu_mapper.ReLUMapper"
 }
@@ -96,6 +96,21 @@ class ReportGenerator(metaclass=abc.ABCMeta):
             f"[UnConvert] '{unconverted_operator_name}' didn't convert."
         return content
 
+    @staticmethod
+    def _get_unsupported_params(num_line, code_line):
+        """Get unsupported params in converted operator."""
+        if 'UNSUPPORTED' in code_line:
+            unsupported_params = re.findall(r"(.*).*[=][{]UNSUPPORTED", code_line)
+            unsupported_msg = re.findall(r".*UNSUPPORTED[:] (.*)[}]", code_line)
+            location = [f"{num_line + 1}", f"{code_line.index('UNSUPPORTED') + 1}"]
+            unsupported_params_info = \
+                f"line {':'.join(location)}: " \
+                f"[Unsupported params] {unsupported_params}: {unsupported_msg}."
+        else:
+            unsupported_params_info = None
+        return unsupported_params_info
+
     def gen_report(self, code: str):
         """
         Generate report.
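The two patterns split a marker such as mode={UNSUPPORTED: value is NOT 0}"CONSTANT" into the parameter name and the reason, and the column of the marker becomes the reported position. A quick standalone check on a simplified generated line:

    import re

    code_line = 'mode={UNSUPPORTED: value is NOT 0}"CONSTANT"'
    unsupported_params = re.findall(r"(.*).*[=][{]UNSUPPORTED", code_line)
    unsupported_msg = re.findall(r".*UNSUPPORTED[:] (.*)[}]", code_line)
    assert unsupported_params == ['mode']
    assert unsupported_msg == ['value is NOT 0']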
@@ -111,6 +126,7 @@ class ReportGenerator(metaclass=abc.ABCMeta):
         num_unconverted_operator = 0
         num_converted_operator = 0
+        converted_operator = None
         self._content = self._extra['start']
         for num_line in range(0, num_all_lines):
             code_line = code_lines[num_line]
@@ -132,6 +148,14 @@ class ReportGenerator(metaclass=abc.ABCMeta):
             if converted_operator:
                 num_converted_operator += 1
 
+            info_unsupported_params = self._get_unsupported_params(
+                num_line,
+                code_line
+            )
+            if info_unsupported_params:
+                self._content = f"{NEW_LINE}".join((self._content,
+                                                    info_unsupported_params))
+
         self._content = f"{NEW_LINE}".join((self._content, self._extra['end']))
 
         converted_rate = \
@@ -289,12 +289,16 @@ class GraphNode(abc.ABC):
         self._op_in_ms = None
         # Params in mindspore.
         self._params_in_ms = dict()
+        # Settings in mindspore.
+        self._settings_in_ms = dict()
         # Node type of current node, e.g. class, module, operation.
         self._node_type = None
         # Tag name on tree.
         self._tag_on_tree = None
         # Function, class or operation needed args.
         self._args_in_code = dict()
+        # Operation needed settings.
+        self._settings_in_code = dict()
         # Variable name declared in init block.
         self._variable_name = None
         # Output variable name declared in construct block.
@@ -364,6 +368,27 @@ class GraphNode(abc.ABC):
         """
         self._args_in_code = args
 
+    @property
+    def settings_in_code(self):
+        """
+        Settings in code.
+
+        Returns:
+            dict, settings.
+        """
+        return self._settings_in_code
+
+    @settings_in_code.setter
+    def settings_in_code(self, settings):
+        """
+        Settings in code.
+
+        Args:
+            settings (dict): Settings.
+        """
+        self._settings_in_code = settings
+
     @property
     def input_shape(self):
         """
@@ -567,14 +592,16 @@ class GraphNode(abc.ABC):
         params.update({"input_shape": self.input_shape,
                        "output_shape": self.output_shape})
-        op_name_in_mindspore, ms_params = mapper.convert(op_name=self.op_name,
-                                                         params=params,
-                                                         weights=self._weight)
+        op_name_in_mindspore, ms_params, ms_settings = mapper.convert(op_name=self.op_name,
+                                                                      params=params,
+                                                                      weights=self._weight)
         if op_name_in_mindspore:
             self._op_in_ms = op_name_in_mindspore
             self._params_in_ms = ms_params
+            self._settings_in_ms = ms_settings
         else:
             self._op_in_ms = self._op_name
             self._params_in_ms = self._op_params
+            self._settings_in_ms = dict()
 
-        return self._op_in_ms, self._params_in_ms
+        return self._op_in_ms, self._params_in_ms, self._settings_in_ms
@@ -18,8 +18,9 @@ from copy import deepcopy
 from .base import GraphNode
 from ..constant import NodeType, SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, \
-    SEPARATOR_IN_ONNX_OP
+    SEPARATOR_IN_ONNX_OP, InputType
 from ..mapper.base import Mapper
+from ...common.exceptions import NodeInputTypeNotSupport
 
 
 class PyTorchGraphNode(GraphNode):
@@ -186,6 +187,8 @@ class PyTorchGraphNode(GraphNode):
         self._opt_var_name = output_var
         args = self.args_in_code
+        settings = self.settings_in_code
+
         if self._node_type == NodeType.OPERATION.value and not self.convert_successful():
             args.update({"input_shape": self.input_shape,
                          "output_shape": self.output_shape})
@@ -193,16 +196,49 @@ class PyTorchGraphNode(GraphNode):
         if self._node_type == NodeType.OPERATION.value:
             expr = ", ".join([f"{k.replace(f'_{self._variable_name}', '')}={v}"
                               for k, v in args.items()])
+            ipt_args_settings_in_construct = self._generate_ipt_args_settings_in_construct(ipt_args_in_construct,
+                                                                                           settings)
         else:
             # When its type is module, class or func,
             # it's not necessary to replace var.
             expr = ", ".join([f"{k.replace(f'_{self._variable_name}', '')}={v}"
                               for k, v in args.items()])
+            ipt_args_settings_in_construct = ipt_args_in_construct
 
         declare = f"self.{self._variable_name} = {operator}({expr})"
-        call = f"{self._opt_var_name} = self.{self._variable_name}({ipt_args_in_construct})"
+        call = f"{self._opt_var_name} = self.{self._variable_name}({ipt_args_settings_in_construct})"
 
         return declare, call
 
+    @staticmethod
+    def _generate_ipt_args_settings_in_construct(ipt_args_in_construct, settings):
+        """
+        Generate input with args and settings in construct.
+
+        Args:
+            ipt_args_in_construct (str): Input args in construct.
+            settings (dict): Settings in operator.
+        """
+        if settings.get('input_type'):
+            input_type = settings['input_type']
+            if input_type == InputType.TENSOR.value:
+                ipt_args_settings_in_construct = ipt_args_in_construct
+            elif input_type == InputType.LIST.value:
+                ipt_args_settings_in_construct = f"({ipt_args_in_construct})"
+            else:
+                raise NodeInputTypeNotSupport(f"Input type[{input_type}] is not supported now.")
+        else:
+            ipt_args_settings_in_construct = ipt_args_in_construct
+
+        if settings.get('values'):
+            settings_value = settings['values']
+            if settings_value:
+                settings_in_construct = ', '.join([f"{setting_val}" for _, setting_val in settings_value.items()])
+                ipt_args_settings_in_construct = ', '.join((ipt_args_settings_in_construct, settings_in_construct))
+        return ipt_args_settings_in_construct
+
     def to_ir(self):
         """
         No need to implement for now.
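End to end, a Concat node with axis=0 and two inputs would come out of to_code() as the following declare/call pair; rerunning the f-string templates with illustrative names:

    variable_name, opt_var_name = "concat_0", "opt_concat_0"
    operator, expr = "P.Concat", "axis=0"
    ipt_args_settings_in_construct = "(x_1, x_2)"    # wrapped because input_type == "list"

    declare = f"self.{variable_name} = {operator}({expr})"
    call = f"{opt_var_name} = self.{variable_name}({ipt_args_settings_in_construct})"
    assert declare == "self.concat_0 = P.Concat(axis=0)"
    assert call == "opt_concat_0 = self.concat_0((x_1, x_2))"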
@@ -266,17 +302,20 @@ class PyTorchGraphNode(GraphNode):
             self._args_in_code = dict()
             for arg, value in args.items():
                 self._args_in_code[self._get_arg_name(arg)] = value
-            return None, None
+            return None, None, None
 
         if not self.transformed:
-            _, _ = super(PyTorchGraphNode, self).param_transform(mapper)
+            _, _, _ = super(PyTorchGraphNode, self).param_transform(mapper)
 
             for arg, value in self._params_in_ms.items():
                 self._args_in_code[self._get_arg_name(arg)] = value
 
+            for arg, value in self._settings_in_ms.items():
+                self._settings_in_code[arg] = value
+
             self.transformed = True
 
-        return self._op_in_ms, self._params_in_ms
+        return self._op_in_ms, self._params_in_ms, self._settings_in_ms
 
     def froze_node_type_and_module_name(self, node_type, module_name):
         """
@@ -144,6 +144,18 @@ class TestHierarchicalTree:
         if report_folder:
             shutil.rmtree(report_folder)
 
+    @staticmethod
+    def _create_node(key, val, weight, input_shape, output_shape):
+        """Create node."""
+        node = PyTorchGraphNode(weight=weight)
+        node.add_input_and_output_shape(input_shape, output_shape)
+        node.tag = key.split('/')[-1] if len(key.split('/')) > 1 else key
+        node.op_name = val['op_name'] if val.get('op_name') else None
+        node.precursor_nodes = val['precursor_nodes'] if val.get('precursor_nodes') else []
+        node.successor_nodes = val['successor_nodes'] if val.get('successor_nodes') else []
+        node.node_type = val['node_type'] if val.get('node_type') else None
+        return node
+
     @staticmethod
     def _create_tree(get_raw_params, params):
         """Create tree."""
@@ -154,13 +166,7 @@ class TestHierarchicalTree:
             get_raw_params.return_value = val['op_params'] if val.get('op_params') else dict()
             weight = val['weight'] if val.get('weight') else None
 
-            node = PyTorchGraphNode(weight=weight)
-            node.add_input_and_output_shape(input_shape, output_shape)
-            node.tag = key.split('/')[-1] if len(key.split('/')) > 1 else key
-            node.op_name = val['op_name'] if val.get('op_name') else None
-            node.precursor_nodes = val['precursor_nodes'] if val.get('precursor_nodes') else []
-            node.successor_nodes = val['successor_nodes'] if val.get('successor_nodes') else []
-            node.node_type = val['node_type'] if val.get('node_type') else None
+            node = TestHierarchicalTree._create_node(key, val, weight, input_shape, output_shape)
 
             tree.create_node(
                 tag=node.tag,
@@ -13,15 +13,60 @@
 # limitations under the License.
 # ==============================================================================
 """Test all operator mappers on transformation from pytorch to mindspore."""
+import numpy as np
 import pytest
 
 from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper
+from tests.utils import mindspore
 
 
 class TestMappers:
     """Test Mappers."""
     @pytest.mark.parametrize('params', [{
+        'input': {'op_name': 'onnx::Conv',
+                  'params': {'dilations': [1, 1],
+                             'group': 1,
+                             'pads': [1, 2, 3, 4],
+                             'strides': [1, 1]},
+                  'weights': {'weight': mindspore.Tensor(np.zeros([64, 3, 1, 1], dtype=np.int32))}},
+        'expected_output': {'converter_name': 'nn.Conv2d',
+                            'converted_params': {'in_channels': 3,
+                                                 'out_channels': 64,
+                                                 'kernel_size': (1, 1),
+                                                 'stride': (1, 1),
+                                                 'padding': (1, 3, 2, 4),
+                                                 'pad_mode': '\"pad\"',
+                                                 'dilation': (1, 1),
+                                                 'group': 1},
+                            'converted_settings': dict()}
+    }, {
+        'input': {'op_name': 'onnx::Conv',
+                  'params': {'dilations': [1, 1],
+                             'group': 1,
+                             'pads': [0, 0, 0, 0],
+                             'strides': [1, 1]},
+                  'weights': {'weight': mindspore.Tensor(np.zeros([64, 3, 2, 2], dtype=np.int32))}},
+        'expected_output': {'converter_name': 'nn.Conv2d',
+                            'converted_params': {'in_channels': 3,
+                                                 'out_channels': 64,
+                                                 'kernel_size': (2, 2),
+                                                 'stride': (1, 1),
+                                                 'padding': 0,
+                                                 'pad_mode': '\"valid\"',
+                                                 'dilation': (1, 1),
+                                                 'group': 1},
+                            'converted_settings': dict()}
+    }, {
+        'input': {'op_name': 'onnx::Gemm',
+                  'params': dict(),
+                  'weights': {'weight': mindspore.Tensor(np.zeros([10, 3], dtype=np.int32)),
+                              'bias': mindspore.Tensor(np.zeros([10, 1], dtype=np.int32))}},
+        'expected_output': {'converter_name': 'nn.Dense',
+                            'converted_params': {'in_channels': 3,
+                                                 'out_channels': 10,
+                                                 'has_bias': True},
+                            'converted_settings': dict()}
+    }, {
         'input': {'op_name': 'onnx::BatchNormalization',
                   'params': {'epsilon': 1e-5,
                              'momentum': 0.9,
@@ -30,13 +75,15 @@ class TestMappers:
         'expected_output': {'converter_name': 'nn.BatchNorm2d',
                             'converted_params': {'num_features': 6,
                                                  'eps': 1e-5,
-                                                 'momentum': 0.9}}
+                                                 'momentum': 0.9},
+                            'converted_settings': dict()}
     }, {
         'input': {'op_name': 'onnx::Relu',
                   'params': dict(),
                   'weights': dict()},
         'expected_output': {'converter_name': 'nn.ReLU',
-                            'converted_params': dict()}
+                            'converted_params': dict(),
+                            'converted_settings': dict()}
     }, {
         'input': {'op_name': 'onnx::MaxPool',
                   'params': {'kernel_shape': [3, 3],
@@ -46,7 +93,8 @@ class TestMappers:
         'expected_output': {'converter_name': 'nn.MaxPool2d',
                             'converted_params': {'kernel_size': (3, 3),
                                                  'stride': (2, 2),
-                                                 'pad_mode': '"same"'}}
+                                                 'pad_mode': '"same"'},
+                            'converted_settings': dict()}
     }, {
         'input': {'op_name': 'onnx::AveragePool',
                   'params': {'kernel_shape': [3, 3],
@@ -56,31 +104,120 @@ class TestMappers:
         'expected_output': {'converter_name': 'nn.AvgPool2d',
                             'converted_params': {'kernel_size': (3, 3),
                                                  'stride': (2, 2),
-                                                 'pad_mode': '"same"'}}
+                                                 'pad_mode': '"same"'},
+                            'converted_settings': dict()}
     }, {
         'input': {'op_name': 'onnx::GlobalAveragePool',
                   'params': {'input_shape': (1, 3, 10, 10),
                              'output_shape': (1, 3, 1, 1)},
                   'weights': ''},
         'expected_output': {'converter_name': 'nn.AvgPool2d',
-                            'converted_params': {'kernel_size': (10, 10)}}
+                            'converted_params': {'kernel_size': (10, 10)},
+                            'converted_settings': dict()}
     }, {
         'input': {'op_name': 'onnx::Flatten',
                   'params': dict(),
                   'weights': dict()},
         'expected_output': {'converter_name': 'nn.Flatten',
-                            'converted_params': dict()}
+                            'converted_params': dict(),
+                            'converted_settings': dict()}
     }, {
         'input': {'op_name': 'onnx::Add',
                   'params': dict(),
                   'weights': dict()},
         'expected_output': {'converter_name': 'P.TensorAdd',
-                            'converted_params': dict()}
+                            'converted_params': dict(),
+                            'converted_settings': dict()}
+    }, {
+        'input': {'op_name': 'onnx::Pad',
+                  'params': {'pads': [0, 1, 2, 3],
+                             'value': 0,
+                             'mode': 'constant'},
+                  'weights': dict()},
+        'expected_output': {'converter_name': 'nn.Pad',
+                            'converted_params': {'paddings': ((0, 2), (1, 3)),
+                                                 'mode': '\"CONSTANT\"'},
+                            'converted_settings': dict()}
+    }, {
+        'input': {'op_name': 'onnx::Pad',
+                  'params': {'pads': [0, 1, 2, 3],
+                             'mode': 'reflect'},
+                  'weights': dict()},
+        'expected_output': {'converter_name': 'nn.Pad',
+                            'converted_params': {'paddings': ((0, 2), (1, 3)),
+                                                 'mode': '\"REFLECT\"'},
+                            'converted_settings': dict()}
+    }, {
+        'input': {'op_name': 'onnx::Pad',
+                  'params': {'pads': [0, 1, 2, 3],
+                             'value': 1,
+                             'mode': 'constant'},
+                  'weights': dict()},
+        'expected_output': {'converter_name': 'nn.Pad',
+                            'converted_params': {'paddings': ((0, 2), (1, 3)),
+                                                 'mode': '{UNSUPPORTED: value is NOT 0}\"CONSTANT\"'},
+                            'converted_settings': dict()}
+    }, {
+        'input': {'op_name': 'onnx::Pad',
+                  'params': {'pads': [0, 1, 2, 3],
+                             'mode': 'edge'},
+                  'weights': dict()},
+        'expected_output': {'converter_name': 'nn.Pad',
+                            'converted_params': {'paddings': ((0, 2), (1, 3)),
+                                                 'mode': '{UNSUPPORTED: \"edge\"}\"UNKNOWN\"'},
+                            'converted_settings': dict()}
+    }, {
+        'input': {'op_name': 'onnx::ReduceMean',
+                  'params': {'keepdims': 0,
+                             'axes': [1, 2]},
+                  'weights': dict()},
+        'expected_output': {'converter_name': 'P.ReduceMean',
+                            'converted_params': {'keep_dims': False},
+                            'converted_settings': {'values': {'axis': (1, 2)}}}
+    }, {
+        'input': {'op_name': 'onnx::ReduceMean',
+                  'params': {'keepdims': 1,
+                             'axes': [1]},
+                  'weights': dict()},
+        'expected_output': {'converter_name': 'P.ReduceMean',
+                            'converted_params': {'keep_dims': True},
+                            'converted_settings': {'values': {'axis': 1}}}
+    }, {
+        'input': {'op_name': 'onnx::Concat',
+                  'params': {'axis': 0},
+                  'weights': dict()},
+        'expected_output': {'converter_name': 'P.Concat',
+                            'converted_params': {'axis': 0},
+                            'converted_settings': {'input_type': "list"}}
+    }, {
+        'input': {'op_name': 'onnx::Clip',
+                  'params': {'max': 6,
+                             'min': 0},
+                  'weights': dict()},
+        'expected_output': {'converter_name': 'nn.ReLU6',
+                            'converted_params': dict(),
+                            'converted_settings': dict()}
+    }, {
+        'input': {'op_name': 'onnx::Clip',
+                  'params': dict(),
+                  'weights': dict()},
+        'expected_output': {'converter_name': 'nn.ReLU',
+                            'converted_params': dict(),
+                            'converted_settings': dict()}
+    }, {
+        'input': {'op_name': 'onnx::Clip',
+                  'params': {'max': 3,
+                             'min': 2},
+                  'weights': dict()},
+        'expected_output': {'converter_name': None,
+                            'converted_params': dict(),
+                            'converted_settings': dict()}
     }])
     def test_mapper(self, params):
         """Test mapper function."""
         mapper = ONNXToMindSporeMapper()
-        converter_name, converted_params = \
+        converter_name, converted_params, converted_settings = \
             mapper.convert(params['input']['op_name'], params['input']['params'], params['input']['weights'])
         assert params['expected_output']['converter_name'] == converter_name
         assert params['expected_output']['converted_params'] == converted_params
+        assert params['expected_output']['converted_settings'] == converted_settings
@@ -28,3 +28,7 @@ class Tensor:
     def __repr__(self):
         return str(self.asnumpy())
+
+    def numpy(self):
+        """Get value in numpy format, mimicking the torch API."""
+        return np.array(self._value)
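The stub only needs enough of the torch.Tensor surface for the mappers: DenseMapper and ConvMapper call weights['weight'].numpy(). A usage sketch against the stub, matching the Gemm test case above (assumes the stub stores its constructor argument in _value, as the diff suggests):

    import numpy as np
    from tests.utils import mindspore

    weight = mindspore.Tensor(np.zeros([10, 3], dtype=np.int32))
    in_channels, out_channels = weight.numpy().transpose().shape
    assert (in_channels, out_channels) == (3, 10)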