From: @ghty0625 Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -304,8 +304,30 @@ class NewFragment: | |||||
| """ | """ | ||||
| rewrite_data = {var: data[ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value]} | rewrite_data = {var: data[ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value]} | ||||
| if ExchangeMessageKeywords.VariableScope.value.INPUTS.value in data: | if ExchangeMessageKeywords.VariableScope.value.INPUTS.value in data: | ||||
| rewrite_data[ExchangeMessageKeywords.VariableScope.value.INPUTS.value] = ", ".join( | |||||
| data[ExchangeMessageKeywords.VariableScope.value.INPUTS.value]) | |||||
| group_inputs = ExchangeMessageKeywords.VariableScope.value.GROUP_INPUTS.value | |||||
| if group_inputs in data: | |||||
| input_tuple_list = [] | |||||
| tuple_index = 0 | |||||
| tuple_id = 0 | |||||
| while tuple_index < len(data[ExchangeMessageKeywords.VariableScope.value.INPUTS.value]): | |||||
| if tuple_id < len(data[group_inputs]) and tuple_index in data[group_inputs][tuple_id]: | |||||
| tuple_added = ", ".join(data[ExchangeMessageKeywords.VariableScope.value.INPUTS.value] | |||||
| [data[group_inputs][tuple_id][0]: | |||||
| data[group_inputs][tuple_id][-1]+1]) | |||||
| tuple_added = f"({tuple_added})" | |||||
| input_tuple_list.append(tuple_added) | |||||
| tuple_index = data[group_inputs][tuple_id][-1]+1 | |||||
| tuple_id += 1 | |||||
| continue | |||||
| input_tuple_list.append(data[ExchangeMessageKeywords.VariableScope.value.INPUTS.value] | |||||
| [tuple_index]) | |||||
| tuple_index += 1 | |||||
| rewrite_data[ExchangeMessageKeywords.VariableScope.value.INPUTS.value] = \ | |||||
| ", ".join(input_tuple_list) | |||||
| else: | |||||
| rewrite_data[ExchangeMessageKeywords.VariableScope.value.INPUTS.value] = ", ".join( | |||||
| data[ExchangeMessageKeywords.VariableScope.value.INPUTS.value]) | |||||
| if ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value in data: | if ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value in data: | ||||
| rewrite_data.update(data[ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value]) | rewrite_data.update(data[ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value]) | ||||
| if ExchangeMessageKeywords.VariableScope.value.PARAMETERS_DECLARED.value in data: | if ExchangeMessageKeywords.VariableScope.value.PARAMETERS_DECLARED.value in data: | ||||
| @@ -108,6 +108,7 @@ class ExchangeMessageKeywords(Enum): | |||||
| WEIGHTS = "weights" | WEIGHTS = "weights" | ||||
| TRAINABLE_PARAMS = "trainable_params" | TRAINABLE_PARAMS = "trainable_params" | ||||
| PARAMETERS_DECLARED = "parameters" | PARAMETERS_DECLARED = "parameters" | ||||
| GROUP_INPUTS = "group_inputs" | |||||
| BINARY_HEADER_PYTORCH_FILE = \ | BINARY_HEADER_PYTORCH_FILE = \ | ||||
| @@ -21,7 +21,7 @@ from typing import List | |||||
| from importlib import import_module | from importlib import import_module | ||||
| from importlib.util import find_spec | from importlib.util import find_spec | ||||
| from functools import partial | from functools import partial | ||||
| from google.protobuf.internal import api_implementation | |||||
| from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext | from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext | ||||
| from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied, onnx_satisfied, \ | from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied, onnx_satisfied, \ | ||||
| save_code_file_and_report, get_framework_type, check_dependency_integrity, get_third_part_lib_validation_error_info | save_code_file_and_report, get_framework_type, check_dependency_integrity, get_third_part_lib_validation_error_info | ||||
| @@ -35,7 +35,7 @@ from mindinsight.mindconverter.common.exceptions import GraphInitError, TreeCrea | |||||
| BadParamError | BadParamError | ||||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph import GraphFactory | from mindinsight.mindconverter.graph_based_converter.third_party_graph import GraphFactory | ||||
| from google.protobuf.internal import api_implementation | |||||
| check_common_dependency_integrity = partial(check_dependency_integrity, | check_common_dependency_integrity = partial(check_dependency_integrity, | ||||
| "onnx", "onnxruntime", "onnxoptimizer") | "onnx", "onnxruntime", "onnxoptimizer") | ||||
| @@ -267,7 +267,8 @@ def main_graph_base_converter(file_config): | |||||
| if api_implementation.Type() != 'cpp' or os.getenv('PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION') != 'cpp': | if api_implementation.Type() != 'cpp' or os.getenv('PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION') != 'cpp': | ||||
| log_console.warning("Protobuf is currently implemented in \"Python\". " | log_console.warning("Protobuf is currently implemented in \"Python\". " | ||||
| "The conversion process may take a long time. Please use the \"C++\" backend version.") | |||||
| "The conversion process may take a long time. " | |||||
| "Please use `export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=cpp` to enable cpp backend.") | |||||
| graph_path = file_config['model_file'] | graph_path = file_config['model_file'] | ||||
| frame_type = get_framework_type(graph_path) | frame_type = get_framework_type(graph_path) | ||||
| @@ -0,0 +1,108 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================== | |||||
| """Mapper module.""" | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||||
| from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords | |||||
| class LSTMMapper(ONNXToMindSporeMapper): | |||||
| """LSTM mapper.""" | |||||
| @staticmethod | |||||
| def _operation_name_in_ms(*args, **kwargs): | |||||
| return "nn.LSTM" | |||||
| @staticmethod | |||||
| def _convert_params(**kwargs): | |||||
| """convert params""" | |||||
| weights = kwargs["weights"] | |||||
| input_weights = LSTMMapper._find_val_by_index(0, weights) | |||||
| embed_dim = input_weights.shape[2] | |||||
| params = kwargs['params'] | |||||
| output_shape_list = kwargs.get("params").get("output_shape") | |||||
| output_shape = output_shape_list[0].node_output_shape | |||||
| # Here the first element determine if the lstm is bidirectional | |||||
| # `1` means unidirectional. `2` means bidirectional | |||||
| if output_shape[1] == 2: | |||||
| return { | |||||
| "input_size": embed_dim, | |||||
| "hidden_size": params["hidden_size"], | |||||
| "bidirectional": True | |||||
| } | |||||
| return { | |||||
| "input_size": embed_dim, | |||||
| "hidden_size": params["hidden_size"] | |||||
| } | |||||
| @staticmethod | |||||
| def _convert_trained_weights(**kwargs): | |||||
| return dict() | |||||
| @staticmethod | |||||
| def _generate_snippet_template(**kwargs): | |||||
| """generate snippet template""" | |||||
| op = kwargs.get("operation") | |||||
| args = kwargs.get("converted_params", dict()) | |||||
| weights = kwargs.get("weights") | |||||
| output_shape_list = kwargs.get("raw_params").get("output_shape") | |||||
| output_shape = output_shape_list[0].node_output_shape | |||||
| output_reshape = (output_shape[0], output_shape[2], output_shape[1], output_shape[3]) | |||||
| trainable_params = kwargs.get("trainable_params", dict()) | |||||
| if not op: | |||||
| raise ValueError("Can not get MindSpore operation name.") | |||||
| variable_slot = "var_0" | |||||
| init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})" | |||||
| init_reshape = f"self.{{{variable_slot}}}_reshape = P.Reshape()" | |||||
| init_transpose = f"self.{{{variable_slot}}}_transpose = P.Transpose()" | |||||
| init_cast = f"self.{{{variable_slot}}}_cast = P.Cast()" | |||||
| construct_template = f"opt_{{{variable_slot}}}, (opt_{{{variable_slot}}}_h, " \ | |||||
| f"opt_{{{variable_slot}}}_c) = self.{{{variable_slot}}}" \ | |||||
| f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}})" | |||||
| construct_template_cast = f"opt_{{{variable_slot}}} = " \ | |||||
| f"self.{{{variable_slot}}}_cast(" \ | |||||
| f"opt_{{{variable_slot}}}, mindspore.float32)" | |||||
| construct_template_reshape = f"opt_{{{variable_slot}}} = " \ | |||||
| f"self.{{{variable_slot}}}_reshape(" \ | |||||
| f"opt_{{{variable_slot}}}, {output_reshape})" | |||||
| construct_template_transpose = f"opt_{{{variable_slot}}} = " \ | |||||
| f"self.{{{variable_slot}}}_transpose(" \ | |||||
| f"opt_{{{variable_slot}}}, (0, 2, 1, 3))" | |||||
| template = { | |||||
| variable_slot: { | |||||
| TemplateKeywords.INIT.value: [init_template, init_cast, init_reshape, init_transpose], | |||||
| TemplateKeywords.CONSTRUCT.value: [construct_template, | |||||
| construct_template_cast, | |||||
| construct_template_reshape, | |||||
| construct_template_transpose] | |||||
| } | |||||
| } | |||||
| exchange_msg = { | |||||
| variable_slot: { | |||||
| ExchangeMessageKeywords.VariableScope.value.OPERATION.value: op, | |||||
| ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value: None, | |||||
| ExchangeMessageKeywords.VariableScope.value.OUTPUT_TYPE.value: | |||||
| ExchangeMessageKeywords.VariableScope.value.TSR_TYPE.value, | |||||
| ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [], | |||||
| ExchangeMessageKeywords.VariableScope.value.GROUP_INPUTS.value: [(1, 2)], | |||||
| ExchangeMessageKeywords.VariableScope.value.ARGS.value: args, | |||||
| ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights, | |||||
| ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: trainable_params | |||||
| } | |||||
| } | |||||
| outputs_list = [f"opt_{{{variable_slot}}}", f"opt_{{{variable_slot}}}_h", f"opt_{{{variable_slot}}}_c"] | |||||
| outputs_mapping = ((0, 0), (1, 1), (2, 2),) | |||||
| return template, exchange_msg, outputs_list, outputs_mapping | |||||
| @@ -13,7 +13,6 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================== | # ============================================================================== | ||||
| """Mapper module.""" | """Mapper module.""" | ||||
| from mindinsight.mindconverter.graph_based_converter.common.utils import reset_init_or_construct | |||||
| from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords | from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords | ||||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | ||||
| @@ -36,11 +35,56 @@ class ConcatMapper(ONNXToMindSporeMapper): | |||||
| @staticmethod | @staticmethod | ||||
| def _generate_snippet_template(**kwargs): | def _generate_snippet_template(**kwargs): | ||||
| template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template( | |||||
| **kwargs) | |||||
| weights = kwargs.get("weights") | |||||
| op = kwargs.get("operation") | |||||
| args = kwargs.get("converted_params", dict()) | |||||
| trainable_params = kwargs.get("trainable_params", dict()) | |||||
| if not op: | |||||
| raise ValueError("Can not get MindSpore operation name.") | |||||
| variable_slot = "var_0" | variable_slot = "var_0" | ||||
| init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})" | |||||
| construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ | construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ | ||||
| f"(({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}},))" | f"(({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}},))" | ||||
| template = reset_init_or_construct(template, variable_slot, [construct_template], | |||||
| TemplateKeywords.CONSTRUCT.value) | |||||
| if weights: | |||||
| tensor = ConcatMapper._find_val_by_index(0, weights) | |||||
| weight_shape = tensor.shape | |||||
| weight_type = tensor.dtype | |||||
| args["weight_shape"] = weight_shape | |||||
| args["weight_type"] = weight_type | |||||
| init_tensor = f"self.{{{variable_slot}}}_w = " \ | |||||
| f"Parameter(Tensor(np.zeros({{weight_shape}}).astype(np.{{weight_type}})), " \ | |||||
| f"name=None)" | |||||
| construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ | |||||
| f"(({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}," \ | |||||
| f"self.{{{variable_slot}}}_w))" | |||||
| template = { | |||||
| variable_slot: { | |||||
| TemplateKeywords.INIT.value: [init_template, init_tensor], | |||||
| TemplateKeywords.CONSTRUCT.value: [construct_template] | |||||
| } | |||||
| } | |||||
| else: | |||||
| template = { | |||||
| variable_slot: { | |||||
| TemplateKeywords.INIT.value: [init_template], | |||||
| TemplateKeywords.CONSTRUCT.value: [construct_template] | |||||
| } | |||||
| } | |||||
| exchange_msg = { | |||||
| variable_slot: { | |||||
| ExchangeMessageKeywords.VariableScope.value.OPERATION.value: op, | |||||
| ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value: None, | |||||
| ExchangeMessageKeywords.VariableScope.value.OUTPUT_TYPE.value: | |||||
| ExchangeMessageKeywords.VariableScope.value.TSR_TYPE.value, | |||||
| ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [], | |||||
| ExchangeMessageKeywords.VariableScope.value.ARGS.value: args, | |||||
| ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights, | |||||
| ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: trainable_params | |||||
| } | |||||
| } | |||||
| outputs_list = [f"opt_{{{variable_slot}}}"] | |||||
| outputs_mapping = ((0, 0),) | |||||
| return template, exchange_msg, outputs_list, outputs_mapping | return template, exchange_msg, outputs_list, outputs_mapping | ||||
| @@ -0,0 +1,32 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================== | |||||
| """Mapper module.""" | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||||
| class FloorMapper(ONNXToMindSporeMapper): | |||||
| """Floor mapper.""" | |||||
| @staticmethod | |||||
| def _operation_name_in_ms(*args, **kwargs): | |||||
| return "P.Floor" | |||||
| @staticmethod | |||||
| def _convert_params(**kwargs): | |||||
| return dict() | |||||
| @staticmethod | |||||
| def _convert_trained_weights(**kwargs): | |||||
| return dict() | |||||
| @@ -37,6 +37,7 @@ class SliceMapper(ONNXToMindSporeMapper): | |||||
| @staticmethod | @staticmethod | ||||
| def _generate_snippet_template(**kwargs): | def _generate_snippet_template(**kwargs): | ||||
| """Generate snippet template.""" | |||||
| template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template( | template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template( | ||||
| **kwargs) | **kwargs) | ||||
| op = kwargs.get("operation") | op = kwargs.get("operation") | ||||
| @@ -47,13 +48,14 @@ class SliceMapper(ONNXToMindSporeMapper): | |||||
| starts = SliceMapper._find_val_by_index(0, weights) | starts = SliceMapper._find_val_by_index(0, weights) | ||||
| ends = SliceMapper._find_val_by_index(1, weights) | ends = SliceMapper._find_val_by_index(1, weights) | ||||
| axes = SliceMapper._find_val_by_index(2, weights, np.array([i for i in range(len(ipt_shape))])) | |||||
| axes = SliceMapper._find_val_by_index(2, weights, np.array(list(range(len(ipt_shape))))) | |||||
| steps = SliceMapper._find_val_by_index(3, weights, np.array([1 for _ in range(len(ipt_shape))])) | steps = SliceMapper._find_val_by_index(3, weights, np.array([1 for _ in range(len(ipt_shape))])) | ||||
| if not op: | if not op: | ||||
| raise ValueError("Can not get MindSpore operation name.") | raise ValueError("Can not get MindSpore operation name.") | ||||
| if not weights: | if not weights: | ||||
| raise ValueError("Cannot get required params from slice.") | |||||
| raise ValueError("Can not get required params from slice.") | |||||
| if axes.shape != (1,): | if axes.shape != (1,): | ||||
| ordered_begin = sorted(zip(starts.tolist(), axes.tolist()), key=lambda x: x[1], reverse=False) | ordered_begin = sorted(zip(starts.tolist(), axes.tolist()), key=lambda x: x[1], reverse=False) | ||||
| @@ -71,9 +73,9 @@ class SliceMapper(ONNXToMindSporeMapper): | |||||
| end[axis] = min(ends.tolist()[0], end[axis]) | end[axis] = min(ends.tolist()[0], end[axis]) | ||||
| strides[axis] = steps.tolist()[0] | strides[axis] = steps.tolist()[0] | ||||
| args["begin"] = tuple(begin) | |||||
| args["end"] = tuple(end) | |||||
| args["strides"] = tuple(strides) | |||||
| args['begin'] = tuple(begin) | |||||
| args['end'] = tuple(end) | |||||
| args['strides'] = tuple(strides) | |||||
| variable_slot = "var_0" | variable_slot = "var_0" | ||||
| init_template = f"self.{{{variable_slot}}} = {op}()" | init_template = f"self.{{{variable_slot}}} = {op}()" | ||||
| @@ -84,7 +86,8 @@ class SliceMapper(ONNXToMindSporeMapper): | |||||
| f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}, " \ | f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}, " \ | ||||
| f"self.{{{variable_slot}}}_begin, self.{{{variable_slot}}}_end, " \ | f"self.{{{variable_slot}}}_begin, self.{{{variable_slot}}}_end, " \ | ||||
| f"self.{{{variable_slot}}}_strides)" | f"self.{{{variable_slot}}}_strides)" | ||||
| template = reset_init_or_construct(template, variable_slot, [init_template, init_begin, init_end, init_strides], | |||||
| template = reset_init_or_construct(template, variable_slot, | |||||
| [init_template, init_begin, init_end, init_strides], | |||||
| TemplateKeywords.INIT.value) | TemplateKeywords.INIT.value) | ||||
| template = reset_init_or_construct(template, variable_slot, [construct_template], | template = reset_init_or_construct(template, variable_slot, [construct_template], | ||||
| TemplateKeywords.CONSTRUCT.value) | TemplateKeywords.CONSTRUCT.value) | ||||
| @@ -27,8 +27,8 @@ class SplitMapper(ONNXToMindSporeMapper): | |||||
| @staticmethod | @staticmethod | ||||
| def _convert_params(**kwargs): | def _convert_params(**kwargs): | ||||
| axis = kwargs["params"]["axis"] | axis = kwargs["params"]["axis"] | ||||
| split = kwargs["params"]["split"] | |||||
| output_num = len(split) | |||||
| output_shape_list = kwargs["params"]["output_shape"] | |||||
| output_num = len(output_shape_list) | |||||
| return {"axis": axis, | return {"axis": axis, | ||||
| "output_num": output_num} | "output_num": output_num} | ||||
| @@ -43,9 +43,16 @@ class SplitMapper(ONNXToMindSporeMapper): | |||||
| weights = kwargs.get("weights") | weights = kwargs.get("weights") | ||||
| if not op: | if not op: | ||||
| raise ValueError("Can not get MindSpore operation name.") | raise ValueError("Can not get MindSpore operation name.") | ||||
| converted_params = kwargs["converted_params"] | |||||
| output_num = converted_params["output_num"] | |||||
| variable_slot = "var_0" | variable_slot = "var_0" | ||||
| init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})" | init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})" | ||||
| construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ | |||||
| slot_list = [f"opt_{{{variable_slot}}}"] | |||||
| for i in range(1, output_num): # Here `1` means the second output of the operator | |||||
| slot_list.append(f"opt_{{{variable_slot}}}_{i}") | |||||
| slot_final = ", ".join(slot_list) | |||||
| construct_template = f"{slot_final} = self.{{{variable_slot}}}" \ | |||||
| f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}})" | f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}})" | ||||
| template = { | template = { | ||||
| variable_slot: { | variable_slot: { | ||||
| @@ -58,13 +65,16 @@ class SplitMapper(ONNXToMindSporeMapper): | |||||
| ExchangeMessageKeywords.VariableScope.value.OPERATION.value: op, | ExchangeMessageKeywords.VariableScope.value.OPERATION.value: op, | ||||
| ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value: None, | ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value: None, | ||||
| ExchangeMessageKeywords.VariableScope.value.OUTPUT_TYPE.value: | ExchangeMessageKeywords.VariableScope.value.OUTPUT_TYPE.value: | ||||
| ExchangeMessageKeywords.VariableScope.value.ARR_TYPE.value, | |||||
| ExchangeMessageKeywords.VariableScope.value.TSR_TYPE.value, | |||||
| ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [], | ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [], | ||||
| ExchangeMessageKeywords.VariableScope.value.ARGS.value: args, | ExchangeMessageKeywords.VariableScope.value.ARGS.value: args, | ||||
| ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights, | ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights, | ||||
| ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: {} | ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: {} | ||||
| } | } | ||||
| } | } | ||||
| outputs_list = [f"opt_{{{variable_slot}}}"] | |||||
| outputs_mapping = ((0, 0),) | |||||
| outputs_list = slot_list | |||||
| outputs_mapping = [] | |||||
| for i in range(output_num): | |||||
| outputs_mapping.append((i, i)) | |||||
| outputs_mapping = tuple(outputs_mapping) | |||||
| return template, exchange_msg, outputs_list, outputs_mapping | return template, exchange_msg, outputs_list, outputs_mapping | ||||
| @@ -0,0 +1,40 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================== | |||||
| """Mapper module.""" | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||||
| class SqueezeMapper(ONNXToMindSporeMapper): | |||||
| """Squeeze mapper.""" | |||||
| @staticmethod | |||||
| def _operation_name_in_ms(*args, **kwargs): | |||||
| return "P.Squeeze" | |||||
| @staticmethod | |||||
| def _convert_params(**kwargs): | |||||
| params = kwargs["params"] | |||||
| if len(params['axes']) == 1: | |||||
| return { | |||||
| "axis": params['axes'][0] | |||||
| } | |||||
| return { | |||||
| "axis": tuple(params['axes']) | |||||
| } | |||||
| @staticmethod | |||||
| def _convert_trained_weights(**kwargs): | |||||
| return dict() | |||||
| @@ -30,5 +30,8 @@ | |||||
| "onnx::Erf": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.erf_mapper.ErfMapper", | "onnx::Erf": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.erf_mapper.ErfMapper", | ||||
| "onnx::Pow": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.pow_mapper.PowMapper", | "onnx::Pow": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.pow_mapper.PowMapper", | ||||
| "onnx::Einsum": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.einsum_mapper.EinSumMapper", | "onnx::Einsum": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.einsum_mapper.EinSumMapper", | ||||
| "onnx::Tanh": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.tanh_mapper.TanhMapper" | |||||
| "onnx::Tanh": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.tanh_mapper.TanhMapper", | |||||
| "onnx::LSTM": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.lstm_mapper.LSTMMapper", | |||||
| "onnx::Squeeze": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.squeeze_mapper.SqueezeMapper", | |||||
| "onnx::Floor": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.floor_mapper.FloorMapper" | |||||
| } | } | ||||
| @@ -23,7 +23,8 @@ from mindinsight.mindconverter.graph_based_converter.third_party_graph.input_nod | |||||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_graph_node import OnnxGraphNode | from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_graph_node import OnnxGraphNode | ||||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph.pytorch_graph_parser import PyTorchGraphParser | from mindinsight.mindconverter.graph_based_converter.third_party_graph.pytorch_graph_parser import PyTorchGraphParser | ||||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph.tf_graph_parser import TFGraphParser | from mindinsight.mindconverter.graph_based_converter.third_party_graph.tf_graph_parser import TFGraphParser | ||||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import OnnxDataLoader, NodeWeight | |||||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import OnnxDataLoader, \ | |||||
| NodeWeight, NodeOutputShape | |||||
| from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher import generate_scope_name | from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher import generate_scope_name | ||||
| NONE_SCOPE_OP = { | NONE_SCOPE_OP = { | ||||
| @@ -139,6 +140,9 @@ class OnnxGraph(Graph): | |||||
| super(OnnxGraph, self).build() | super(OnnxGraph, self).build() | ||||
| self._collect_input_shape_of_each_node() | self._collect_input_shape_of_each_node() | ||||
| for node_name in self._shape_dict: | |||||
| if len(self._shape_dict[node_name]) == 1: | |||||
| self._shape_dict[node_name] = self._shape_dict[node_name][0].node_output_shape | |||||
| def _collect_input_shape_of_each_node(self): | def _collect_input_shape_of_each_node(self): | ||||
| """ | """ | ||||
| @@ -156,11 +160,17 @@ class OnnxGraph(Graph): | |||||
| input_nodes[ipt].set_scope_name(node.scope_name) | input_nodes[ipt].set_scope_name(node.scope_name) | ||||
| node.precursor_nodes.insert(0, ipt) | node.precursor_nodes.insert(0, ipt) | ||||
| input_nodes[ipt].set_successor_nodes(node_name) | input_nodes[ipt].set_successor_nodes(node_name) | ||||
| self._shape_dict[ipt] = input_nodes[ipt].output_shape | |||||
| output_shape_single = NodeOutputShape(ipt, None, input_nodes[ipt].output_shape) | |||||
| if ipt not in self._shape_dict: | |||||
| self._shape_dict.setdefault(ipt, []).append(output_shape_single) | |||||
| ipt_shape = [] | ipt_shape = [] | ||||
| for p_nd in node.precursor_nodes: | |||||
| shp = self._shape_dict.get(p_nd) | |||||
| ipt_shape.append(tuple(shp) if isinstance(shp, list) else shp) | |||||
| for ipt_nd_name in node.ir_node_inputs: | |||||
| for p_nd in node.precursor_nodes: | |||||
| shp_list = self._shape_dict.get(p_nd) | |||||
| for shp in shp_list: | |||||
| if ipt_nd_name == shp.node_opt_name: | |||||
| shp_single = shp.node_output_shape | |||||
| ipt_shape.append(tuple(shp_single) if isinstance(shp_single, list) else shp_single) | |||||
| self._input_shape[node_name] = ipt_shape[0] if len(ipt_shape) == 1 else ipt_shape | self._input_shape[node_name] = ipt_shape[0] if len(ipt_shape) == 1 else ipt_shape | ||||
| @@ -438,6 +438,15 @@ class OnnxDataLoader: | |||||
| for tensor in tensors: | for tensor in tensors: | ||||
| t = OnnxTensor(tensor) | t = OnnxTensor(tensor) | ||||
| self.tensors_dict[t.name] = t | self.tensors_dict[t.name] = t | ||||
| idx = 0 | |||||
| while idx < len(self.model.graph.output): | |||||
| cur_opt = self.model.graph.output[idx] | |||||
| if cur_opt.name not in self.output_nodes: | |||||
| self.model.graph.output.remove(cur_opt) | |||||
| continue | |||||
| idx += 1 | |||||
| self._global_context.onnx_tensors_collection = self.tensors_dict | self._global_context.onnx_tensors_collection = self.tensors_dict | ||||
| def _parse_node_output_shape(self): | def _parse_node_output_shape(self): | ||||
| @@ -465,7 +474,8 @@ class OnnxDataLoader: | |||||
| node_name = self.output_name_to_node_name[node_opt_name] | node_name = self.output_name_to_node_name[node_opt_name] | ||||
| if not node_name: | if not node_name: | ||||
| raise GraphInitError(msg=f"Cannot find where edge {node_opt_name} comes from.") | raise GraphInitError(msg=f"Cannot find where edge {node_opt_name} comes from.") | ||||
| self.node_output_shape_dict[node_name] = shape | |||||
| node_output_shape = NodeOutputShape(node_opt_name, node_name, shape) | |||||
| self.node_output_shape_dict.setdefault(node_name, []).append(node_output_shape) | |||||
| def get_node(self, node_name): | def get_node(self, node_name): | ||||
| """Get the OnnxNode instance by node name.""" | """Get the OnnxNode instance by node name.""" | ||||
| @@ -576,12 +586,16 @@ class OnnxDataLoader: | |||||
| """Create a PASS to optimize shape and reshape operations in ONNX ir graph.""" | """Create a PASS to optimize shape and reshape operations in ONNX ir graph.""" | ||||
| to_be_eliminated_op = {"Cast", "Concat", "Squeeze", "Unsqueeze", "Slice", | to_be_eliminated_op = {"Cast", "Concat", "Squeeze", "Unsqueeze", "Slice", | ||||
| "Gather", "Shape"} | "Gather", "Shape"} | ||||
| def _is_origin_inputs(node): | |||||
| for ipt in node.input_name_list: | |||||
| if ipt not in self.input_nodes: | |||||
| return False | |||||
| return True | |||||
| def _traceback_precursor_nodes_until_shape_op(node_ref): | def _traceback_precursor_nodes_until_shape_op(node_ref): | ||||
| nonlocal self | nonlocal self | ||||
| e_nodes = [] | e_nodes = [] | ||||
| node = self._nodes_dict[self.output_name_to_node_name[node_ref]] | node = self._nodes_dict[self.output_name_to_node_name[node_ref]] | ||||
| if node.op_type not in to_be_eliminated_op: | |||||
| if node.op_type not in to_be_eliminated_op or _is_origin_inputs(node): | |||||
| return e_nodes | return e_nodes | ||||
| e_nodes.append(node.name) | e_nodes.append(node.name) | ||||
| for ipt in node.input_name_list: | for ipt in node.input_name_list: | ||||
| @@ -646,3 +660,23 @@ class NodeWeight: | |||||
| @property | @property | ||||
| def location(self): | def location(self): | ||||
| return self._weight_location | return self._weight_location | ||||
| class NodeOutputShape: | |||||
| """Node output shape and its name.""" | |||||
| def __init__(self, node_opt_name, node_name, node_output_shape): | |||||
| self._node_opt_name = node_opt_name | |||||
| self._node_name = node_name | |||||
| self._node_output_shape = node_output_shape | |||||
| @property | |||||
| def node_opt_name(self): | |||||
| return self._node_opt_name | |||||
| @property | |||||
| def node_name(self): | |||||
| return self._node_name | |||||
| @property | |||||
| def node_output_shape(self): | |||||
| return self._node_output_shape | |||||