You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

base.py 6.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Mapper module."""
  16. import abc
  17. import importlib
  18. import json
  19. import os
  20. from typing import Dict
  21. from mindinsight.mindconverter.common.log import logger as log
  22. from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords
  23. CONFIG_JSON = "onnx_to_ms.json"
  24. OPERATION_TABLE = os.path.join(
  25. os.path.abspath(os.path.dirname(__file__)),
  26. CONFIG_JSON
  27. )
  28. with open(OPERATION_TABLE) as file:
  29. # Load mapping table which key is operation name in ONNX and
  30. # value is corresponding module path.
  31. TABLE = json.load(file)
  32. # Define global func name.
  33. GET_OP_NAME = "_operation_name_in_ms"
  34. GET_OP_PARAMS = "_convert_params"
  35. GET_OP_WEIGHTS = "_convert_trained_weights"
  36. GET_OP_SETTINGS = "_convert_settings"
  37. GET_OP_TEMPLATE = "_generate_snippet_template"
  38. class Mapper(metaclass=abc.ABCMeta):
  39. """Mapper between third-party-operation and MindSpore."""
  40. @staticmethod
  41. @abc.abstractmethod
  42. def _operation_name_in_ms(*args, **kwargs):
  43. """Corresponding operation name in MindSpore."""
  44. @staticmethod
  45. @abc.abstractmethod
  46. def _convert_params(**kwargs):
  47. """Convert third party operation's param into MindSpore operation."""
  48. @staticmethod
  49. @abc.abstractmethod
  50. def _convert_trained_weights(**kwargs):
  51. """Convert third party operation's weights into MindSpore operation."""
  52. @staticmethod
  53. @abc.abstractmethod
  54. def _convert_settings(**kwargs):
  55. """Convert third party operation's params into MindSpore OP operator."""
  56. @classmethod
  57. @abc.abstractmethod
  58. def convert(cls, op_name: str, params: Dict, weights: Dict = None):
  59. """Convert third party operation's param into MindSpore operation."""
  60. @staticmethod
  61. @abc.abstractmethod
  62. def _generate_snippet_template(**kwargs):
  63. """Generate code template according to node info."""
  64. class ONNXToMindSporeMapper(Mapper, abc.ABC):
  65. """ONNX operation to MindSpore."""
  66. @classmethod
  67. def convert(cls, op_name: str, params: Dict, weights: Dict = None):
  68. """
  69. Convert third party operation's param into MindSpore operation.
  70. Args:
  71. op_name (str): Operation name in ONNX.
  72. params (dict): Params in onnx.
  73. weights (dict): Weights in onnx.
  74. Returns:
  75. Tuple[str, dict, dict], operation name and params and settings.
  76. """
  77. global TABLE
  78. module_name = TABLE.get(op_name)
  79. if not module_name:
  80. return None, dict(), None, dict()
  81. pos = module_name.rfind(".")
  82. try:
  83. converter = getattr(importlib.import_module(module_name[:pos]),
  84. module_name[pos + 1:])
  85. op_name_converter = getattr(converter, GET_OP_NAME)
  86. params_converter = getattr(converter, GET_OP_PARAMS)
  87. weights_converter = getattr(converter, GET_OP_WEIGHTS)
  88. settings_converter = getattr(converter, GET_OP_SETTINGS)
  89. except (ModuleNotFoundError,) as e:
  90. # If mapper can not be found, then skip it.
  91. err_msg = f"Converting {op_name} failed, see {str(e)}"
  92. log.error(err_msg)
  93. return None, dict(), None, dict()
  94. try:
  95. converter_name = op_name_converter(params=params, weights=weights, op_name=op_name)
  96. converted_params = params_converter(params=params, weights=weights)
  97. converted_weights = weights_converter(weights=weights) if weights else dict()
  98. converted_params.update(converted_weights)
  99. converted_settings = settings_converter(params=params, weights=weights)
  100. except (AttributeError, KeyError, ValueError, TypeError, IndexError) as e:
  101. err_msg = f"Converting {op_name} failed, see {str(e)}"
  102. log.error(err_msg)
  103. return None, dict(), None, dict()
  104. return converter_name, converted_params, converted_settings, converted_weights
  105. @staticmethod
  106. def _operation_name_in_ms(*args, **kwargs):
  107. raise NotImplementedError
  108. @staticmethod
  109. def _convert_params(**kwargs):
  110. raise NotImplementedError
  111. @staticmethod
  112. def _convert_trained_weights(**kwargs):
  113. raise NotImplementedError
  114. @staticmethod
  115. def _convert_settings(**kwargs):
  116. raise NotImplementedError
  117. @staticmethod
  118. def _generate_snippet_template(**kwargs):
  119. op = kwargs.get("operation")
  120. args = kwargs.get("converted_params")
  121. weights = kwargs.get("weights")
  122. if not op:
  123. raise ValueError("Can not get MindSpore operation name.")
  124. variable_slot = "var_0"
  125. init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})"
  126. construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \
  127. f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}})"
  128. template = {
  129. variable_slot: {
  130. TemplateKeywords.INIT.value: [init_template],
  131. TemplateKeywords.CONSTRUCT.value: [construct_template]
  132. }
  133. }
  134. exchange_msg = {
  135. variable_slot: {
  136. ExchangeMessageKeywords.VariableScope.value.OPERATION.value: op,
  137. ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value: None,
  138. ExchangeMessageKeywords.VariableScope.value.OUTPUT_TYPE.value:
  139. ExchangeMessageKeywords.VariableScope.value.TSR_TYPE.value,
  140. ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [],
  141. ExchangeMessageKeywords.VariableScope.value.ARGS.value: args,
  142. ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights,
  143. ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: {}
  144. }
  145. }
  146. outputs_list = [f"opt_{{{variable_slot}}}"]
  147. outputs_mapping = ((0, 0),)
  148. return template, exchange_msg, outputs_list, outputs_mapping