diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/mat_mul_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/mat_mul_mapper.py similarity index 90% rename from mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/mat_mul_mapper.py rename to mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/mat_mul_mapper.py index fd088566..1cb26605 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/mat_mul_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/mat_mul_mapper.py @@ -23,7 +23,7 @@ class MatMulMapper(ONNXToMindSporeMapper): @staticmethod def _operation_name_in_ms(*args, **kwargs): - return "P.matmul" + return "nn.MatMul" @staticmethod def _convert_params(**kwargs): @@ -50,19 +50,19 @@ class MatMulMapper(ONNXToMindSporeMapper): variable_slot = "var_0" w_location = MatMulMapper._find_location_by_index(0, weights) - init_tensor_list = list() + init_template_list = [f"self.{{{variable_slot}}} = {op}()"] inputs_in_construct = [f"{{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}"] if w_location != -1: # Note: adding weight shape to args is now deprecated due to conflict of partial weights share processing. variable_slot_param_name = f"{variable_slot}/w" - init_tensor_list.append(f"self.{{{variable_slot}}}_w = {{{variable_slot_param_name}}}") + init_template_list.append(f"self.{{{variable_slot}}}_w = {{{variable_slot_param_name}}}") inputs_in_construct.insert(w_location, f"self.{{{variable_slot}}}_w") - construct_template = f"opt_{{{variable_slot}}} = {op}({', '.join(inputs_in_construct)})" + construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}({', '.join(inputs_in_construct)})" template = { variable_slot: { - TemplateKeywords.INIT.value: init_tensor_list, + TemplateKeywords.INIT.value: init_template_list, TemplateKeywords.CONSTRUCT.value: [construct_template] } } diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/onnx_to_ms.json b/mindinsight/mindconverter/graph_based_converter/mapper/onnx_to_ms.json index 19dcaa9e..0cb96d15 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/onnx_to_ms.json +++ b/mindinsight/mindconverter/graph_based_converter/mapper/onnx_to_ms.json @@ -13,7 +13,7 @@ "onnx::Concat": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.concat_mapper.ConcatMapper", "onnx::Clip": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.relu_mapper.ReLUMapper", "onnx::Transpose": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.transpose_mapper.TransposeMapper", - "onnx::MatMul": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.mat_mul_mapper.MatMulMapper", + "onnx::MatMul": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.mat_mul_mapper.MatMulMapper", "onnx::Softmax": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.softmax_mapper.SoftmaxMapper", "onnx::OneHot": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.one_hot_mapper.OneHotMapper", "onnx::Neg": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.neg_mapper.NegMapper",