From: @liuchongming74 Reviewed-by: @ouwenchang,@yelihua Signed-off-by: @yelihuatags/v1.2.0-rc1
| @@ -14,8 +14,6 @@ | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting, Tensor, get_dtype | |||
| from mindinsight.mindconverter.graph_based_converter.common.utils import reset_init_or_construct | |||
| from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords | |||
| @@ -34,18 +32,6 @@ class MatMulMapper(ONNXToMindSporeMapper): | |||
| def _convert_trained_weights(**kwargs): | |||
| return dict() | |||
| @staticmethod | |||
| def _convert_settings(**kwargs): | |||
| weights = kwargs.get("weights") | |||
| if not weights: | |||
| return Setting() | |||
| tensor, ref = None, "" | |||
| for t_name, t_value in weights.items(): | |||
| tensor = t_value | |||
| ref = t_name | |||
| return Setting(op_extra_tensor=Tensor(shape=tensor.shape, | |||
| dtype=get_dtype(tensor), reference=ref)) | |||
| @staticmethod | |||
| def _generate_snippet_template(**kwargs): | |||
| template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template( | |||
| @@ -53,6 +39,8 @@ class MatMulMapper(ONNXToMindSporeMapper): | |||
| op = kwargs.get("operation") | |||
| args = kwargs.get("converted_params") | |||
| weights = kwargs.get("weights") | |||
| if not op: | |||
| raise ValueError("Can not get MindSpore operation name.") | |||
| if not weights: | |||
| return template, exchange_msg, outputs_list, outputs_mapping | |||
| @@ -61,14 +49,31 @@ class MatMulMapper(ONNXToMindSporeMapper): | |||
| variable_slot = "var_0" | |||
| init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})" | |||
| args["weight_shape"] = tensor.shape | |||
| args["weight_dtype"] = tensor.dtype | |||
| init_tensor = f"self.{{{variable_slot}}}_w = " \ | |||
| f"Tensor(np.random.uniform(0, 1, {tensor.shape}).astype(np.{tensor.dtype}))" | |||
| f"Tensor(np.random.uniform(0, 1, {{weight_shape}}).astype(np.{{weight_dtype}}))" | |||
| construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ | |||
| f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}," \ | |||
| f"self.{{{variable_slot}}}_w)" | |||
| template = reset_init_or_construct(template, variable_slot, [init_template, init_tensor], | |||
| TemplateKeywords.INIT.value) | |||
| template = reset_init_or_construct(template, variable_slot, [construct_template], | |||
| TemplateKeywords.CONSTRUCT.value) | |||
| template = { | |||
| variable_slot: { | |||
| TemplateKeywords.INIT.value: [init_template, init_tensor], | |||
| TemplateKeywords.CONSTRUCT.value: [construct_template] | |||
| } | |||
| } | |||
| exchange_msg = { | |||
| variable_slot: { | |||
| ExchangeMessageKeywords.VariableScope.value.OPERATION.value: op, | |||
| ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value: None, | |||
| ExchangeMessageKeywords.VariableScope.value.OUTPUT_TYPE.value: | |||
| ExchangeMessageKeywords.VariableScope.value.TSR_TYPE.value, | |||
| ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [], | |||
| ExchangeMessageKeywords.VariableScope.value.ARGS.value: args, | |||
| ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights, | |||
| ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: {} | |||
| } | |||
| } | |||
| outputs_list = [f"opt_{{{variable_slot}}}"] | |||
| outputs_mapping = ((0, 0),) | |||
| return template, exchange_msg, outputs_list, outputs_mapping | |||
| @@ -13,10 +13,8 @@ | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from mindinsight.mindconverter.graph_based_converter.common.utils import reset_init_or_construct | |||
| from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting, Tensor, get_dtype | |||
| class AddMapper(ONNXToMindSporeMapper): | |||
| @@ -34,18 +32,6 @@ class AddMapper(ONNXToMindSporeMapper): | |||
| def _convert_trained_weights(**kwargs): | |||
| return dict() | |||
| @staticmethod | |||
| def _convert_settings(**kwargs): | |||
| weights = kwargs.get("weights") | |||
| if not weights: | |||
| return Setting() | |||
| tensor, ref = None, "" | |||
| for t_name, t_value in weights.items(): | |||
| tensor = t_value | |||
| ref = t_name | |||
| return Setting(op_extra_tensor=Tensor(shape=tensor.shape, | |||
| dtype=get_dtype(tensor), reference=ref)) | |||
| @staticmethod | |||
| def _generate_snippet_template(**kwargs): | |||
| template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template( | |||
| @@ -53,6 +39,8 @@ class AddMapper(ONNXToMindSporeMapper): | |||
| op = kwargs.get("operation") | |||
| args = kwargs.get("converted_params") | |||
| weights = kwargs.get("weights") | |||
| if not op: | |||
| raise ValueError("Can not get MindSpore operation name.") | |||
| if not weights: | |||
| return template, exchange_msg, outputs_list, outputs_mapping | |||
| @@ -61,14 +49,31 @@ class AddMapper(ONNXToMindSporeMapper): | |||
| variable_slot = "var_0" | |||
| init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})" | |||
| args["bias_shape"] = tensor.shape | |||
| args["bias_dtype"] = tensor.dtype | |||
| init_tensor = f"self.{{{variable_slot}}}_bias = " \ | |||
| f"Tensor(np.random.uniform(0, 1, {tensor.shape}).astype(np.{tensor.dtype}))" | |||
| f"Tensor(np.random.uniform(0, 1, {{bias_shape}}).astype(np.{{bias_dtype}}))" | |||
| construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ | |||
| f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}," \ | |||
| f"self.{{{variable_slot}}}_bias)" | |||
| template = reset_init_or_construct(template, variable_slot, [init_template, init_tensor], | |||
| TemplateKeywords.INIT.value) | |||
| template = reset_init_or_construct(template, variable_slot, [construct_template], | |||
| TemplateKeywords.CONSTRUCT.value) | |||
| template = { | |||
| variable_slot: { | |||
| TemplateKeywords.INIT.value: [init_template, init_tensor], | |||
| TemplateKeywords.CONSTRUCT.value: [construct_template] | |||
| } | |||
| } | |||
| exchange_msg = { | |||
| variable_slot: { | |||
| ExchangeMessageKeywords.VariableScope.value.OPERATION.value: op, | |||
| ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value: None, | |||
| ExchangeMessageKeywords.VariableScope.value.OUTPUT_TYPE.value: | |||
| ExchangeMessageKeywords.VariableScope.value.TSR_TYPE.value, | |||
| ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [], | |||
| ExchangeMessageKeywords.VariableScope.value.ARGS.value: args, | |||
| ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights, | |||
| ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: {} | |||
| } | |||
| } | |||
| outputs_list = [f"opt_{{{variable_slot}}}"] | |||
| outputs_mapping = ((0, 0),) | |||
| return template, exchange_msg, outputs_list, outputs_mapping | |||