You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

exceptions.py 16 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Define custom exception."""
import abc
import functools
import sys
from enum import unique, Enum
from importlib import import_module
from lib2to3.pgen2 import parse

from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console
from mindinsight.utils.constant import ScriptConverterErrors
from mindinsight.utils.exceptions import MindInsightException
  24. @unique
  25. class ConverterErrors(ScriptConverterErrors):
  26. """Converter error codes."""
  27. SCRIPT_NOT_SUPPORT = 1
  28. NODE_TYPE_NOT_SUPPORT = 2
  29. CODE_SYNTAX_ERROR = 3
  30. BASE_CONVERTER_FAIL = 000
  31. GRAPH_INIT_FAIL = 100
  32. SOURCE_FILES_SAVE_FAIL = 200
  33. GENERATOR_FAIL = 300
  34. SUB_GRAPH_SEARCHING_FAIL = 400
  35. class ScriptNotSupport(MindInsightException):
  36. """The script can not support to process."""
  37. def __init__(self, msg):
  38. super(ScriptNotSupport, self).__init__(ConverterErrors.SCRIPT_NOT_SUPPORT,
  39. msg,
  40. http_code=400)
  41. class NodeTypeNotSupport(MindInsightException):
  42. """The astNode can not support to process."""
  43. def __init__(self, msg):
  44. super(NodeTypeNotSupport, self).__init__(ConverterErrors.NODE_TYPE_NOT_SUPPORT,
  45. msg,
  46. http_code=400)
  47. class CodeSyntaxError(MindInsightException):
  48. """The CodeSyntaxError class definition."""
  49. def __init__(self, msg):
  50. super(CodeSyntaxError, self).__init__(ConverterErrors.CODE_SYNTAX_ERROR,
  51. msg,
  52. http_code=400)
  53. class MindConverterException(Exception):
  54. """MindConverter exception."""
  55. BASE_ERROR_CODE = None # ConverterErrors.BASE_CONVERTER_FAIL.value
  56. # ERROR_CODE should be declared in child exception.
  57. ERROR_CODE = None
  58. def __init__(self, **kwargs):
  59. """Initialization of MindInsightException."""
  60. user_msg = kwargs.get('user_msg', '')
  61. if isinstance(user_msg, str):
  62. user_msg = ' '.join(user_msg.split())
  63. super(MindConverterException, self).__init__()
  64. self.user_msg = user_msg
  65. self.root_exception_error_code = None
  66. def __str__(self):
  67. return '[{}] code: {}, msg: {}'.format(self.__class__.__name__, self.error_code(), self.user_msg)
  68. def __repr__(self):
  69. return self.__str__()
  70. def error_code(self):
  71. """"
  72. Calculate error code.
  73. code compose(2bytes)
  74. error: 16bits.
  75. num = 0xFFFF & error
  76. error_cods
  77. Returns:
  78. str, Hex string representing the composed MindConverter error code.
  79. """
  80. if self.root_exception_error_code:
  81. return self.root_exception_error_code
  82. if self.BASE_ERROR_CODE is None or self.ERROR_CODE is None:
  83. raise ValueError("MindConverterException has not been initialized.")
  84. num = 0xFFFF & self.ERROR_CODE # 0xFFFF & self.error.value
  85. error_code = f"{str(self.BASE_ERROR_CODE).zfill(3)}{hex(num)[2:].zfill(4).upper()}"
  86. return error_code
  87. @classmethod
  88. @abc.abstractmethod
  89. def raise_from(cls):
  90. """Raise from below exceptions."""
  91. @classmethod
  92. def normalize_error_msg(cls, error_msg):
  93. """Normalize error msg for common python exception."""
  94. if cls.BASE_ERROR_CODE is None or cls.ERROR_CODE is None:
  95. raise ValueError("MindConverterException has not been initialized.")
  96. num = 0xFFFF & cls.ERROR_CODE # 0xFFFF & self.error.value
  97. error_code = f"{str(cls.BASE_ERROR_CODE).zfill(3)}{hex(num)[2:].zfill(4).upper()}"
  98. return f"[{cls.__name__}] code: {error_code}, msg: {error_msg}"
  99. @classmethod
  100. def uniform_catcher(cls, msg: str = ""):
  101. """Uniform exception catcher."""
  102. def decorator(func):
  103. def _f(*args, **kwargs):
  104. try:
  105. res = func(*args, **kwargs)
  106. except cls.raise_from() as e:
  107. error = cls() if not msg else cls(msg=msg)
  108. detail_info = str(e)
  109. if not isinstance(e, MindConverterException):
  110. detail_info = cls.normalize_error_msg(str(e))
  111. log.error(error)
  112. log_console.error(detail_info)
  113. log.exception(e)
  114. sys.exit(0)
  115. except ModuleNotFoundError as e:
  116. detail_info = "Error detail: Required package not found, please check the runtime environment."
  117. log_console.error(f"{str(e)}\n{detail_info}")
  118. log.exception(e)
  119. sys.exit(0)
  120. return res
  121. return _f
  122. return decorator
  123. @classmethod
  124. def check_except(cls, msg):
  125. """Check except."""
  126. def decorator(func):
  127. def _f(*args, **kwargs):
  128. try:
  129. output = func(*args, **kwargs)
  130. except cls.raise_from() as e:
  131. error = cls(msg=msg)
  132. error_code = e.error_code() if isinstance(e, MindConverterException) else None
  133. error.root_exception_error_code = error_code
  134. log.error(msg)
  135. log.exception(e)
  136. raise error
  137. except Exception as e:
  138. log.error(msg)
  139. log.exception(e)
  140. raise e
  141. return output
  142. return _f
  143. return decorator
  144. class BaseConverterError(MindConverterException):
  145. """Base converter failed."""
  146. @unique
  147. class ErrCode(Enum):
  148. """Define error code of BaseConverterError."""
  149. UNKNOWN_ERROR = 0
  150. UNKNOWN_MODEL = 1
  151. PARAM_MISSING = 2
  152. BAD_PARAM = 3
  153. BASE_ERROR_CODE = ConverterErrors.BASE_CONVERTER_FAIL.value
  154. ERROR_CODE = ErrCode.UNKNOWN_ERROR.value
  155. DEFAULT_MSG = "Failed to start base converter."
  156. def __init__(self, msg=DEFAULT_MSG):
  157. super(BaseConverterError, self).__init__(user_msg=msg)
  158. @classmethod
  159. def raise_from(cls):
  160. """Raise from exceptions below."""
  161. except_source = Exception, UnknownModelError, ParamMissingError, cls
  162. return except_source
  163. class UnknownModelError(BaseConverterError):
  164. """The unknown model error."""
  165. ERROR_CODE = BaseConverterError.ErrCode.UNKNOWN_MODEL.value
  166. def __init__(self, msg):
  167. super(UnknownModelError, self).__init__(msg=msg)
  168. @classmethod
  169. def raise_from(cls):
  170. return cls
  171. class ParamMissingError(BaseConverterError):
  172. """Define cli params missing error."""
  173. ERROR_CODE = BaseConverterError.ErrCode.PARAM_MISSING.value
  174. def __init__(self, msg):
  175. super(ParamMissingError, self).__init__(msg=msg)
  176. @classmethod
  177. def raise_from(cls):
  178. return cls
  179. class BadParamError(BaseConverterError):
  180. """Define cli bad params error."""
  181. ERROR_CODE = BaseConverterError.ErrCode.BAD_PARAM.value
  182. def __init__(self, msg):
  183. super(BadParamError, self).__init__(msg=msg)
  184. @classmethod
  185. def raise_from(cls):
  186. return cls
  187. class GraphInitError(MindConverterException):
  188. """The graph init fail error."""
  189. @unique
  190. class ErrCode(Enum):
  191. """Define error code of GraphInitError."""
  192. UNKNOWN_ERROR = 0
  193. MODEL_LOADING_ERROR = 1
  194. TF_RUNTIME_ERROR = 2
  195. MI_RUNTIME_ERROR = 3
  196. BASE_ERROR_CODE = ConverterErrors.GRAPH_INIT_FAIL.value
  197. ERROR_CODE = ErrCode.UNKNOWN_ERROR.value
  198. DEFAULT_MSG = "Error occurred when init graph object."
  199. def __init__(self, msg=DEFAULT_MSG):
  200. super(GraphInitError, self).__init__(user_msg=msg)
  201. @classmethod
  202. def raise_from(cls):
  203. """Raise from exceptions below."""
  204. except_source = (FileNotFoundError,
  205. ModuleNotFoundError,
  206. ModelLoadingError,
  207. RuntimeIntegrityError,
  208. TypeError,
  209. ZeroDivisionError,
  210. RuntimeError,
  211. cls)
  212. return except_source
  213. class SourceFilesSaveError(MindConverterException):
  214. """The source files save fail error."""
  215. @unique
  216. class ErrCode(Enum):
  217. """Define error code of SourceFilesSaveError."""
  218. UNKNOWN_ERROR = 0
  219. NODE_INPUT_TYPE_NOT_SUPPORT = 1
  220. SCRIPT_GENERATE_FAIL = 2
  221. REPORT_GENERATE_FAIL = 3
  222. CKPT_GENERATE_FAIL = 4
  223. MAP_GENERATE_FAIL = 5
  224. MODEL_SAVE_FAIL = 6
  225. BASE_ERROR_CODE = ConverterErrors.SOURCE_FILES_SAVE_FAIL.value
  226. ERROR_CODE = ErrCode.UNKNOWN_ERROR.value
  227. DEFAULT_MSG = "Error occurred when save source files."
  228. def __init__(self, msg=DEFAULT_MSG):
  229. super(SourceFilesSaveError, self).__init__(user_msg=msg)
  230. @classmethod
  231. def raise_from(cls):
  232. """Raise from exceptions below."""
  233. except_source = (NodeInputTypeNotSupportError,
  234. ScriptGenerationError,
  235. ReportGenerationError,
  236. CheckPointGenerationError,
  237. WeightMapGenerationError,
  238. OnnxModelSaveError,
  239. IOError, cls)
  240. return except_source
  241. class ModelLoadingError(GraphInitError):
  242. """The model not support error."""
  243. ERROR_CODE = GraphInitError.ErrCode.MODEL_LOADING_ERROR.value
  244. def __init__(self, msg):
  245. super(ModelLoadingError, self).__init__(msg=msg)
  246. @classmethod
  247. def raise_from(cls):
  248. """Raise from exceptions below."""
  249. onnxruntime_error = getattr(import_module('onnxruntime.capi'), 'onnxruntime_pybind11_state')
  250. except_source = (RuntimeError,
  251. ModuleNotFoundError,
  252. ValueError,
  253. AssertionError,
  254. TypeError,
  255. OSError,
  256. ZeroDivisionError,
  257. onnxruntime_error.Fail,
  258. onnxruntime_error.InvalidArgument,
  259. onnxruntime_error.NoSuchFile,
  260. onnxruntime_error.NoModel,
  261. onnxruntime_error.EngineError,
  262. onnxruntime_error.RuntimeException,
  263. onnxruntime_error.InvalidProtobuf,
  264. onnxruntime_error.ModelLoaded,
  265. onnxruntime_error.NotImplemented,
  266. onnxruntime_error.InvalidGraph,
  267. onnxruntime_error.EPFail,
  268. cls)
  269. return except_source
  270. class TfRuntimeError(GraphInitError):
  271. """Catch tf runtime error."""
  272. ERROR_CODE = GraphInitError.ErrCode.TF_RUNTIME_ERROR.value
  273. DEFAULT_MSG = "Error occurred when init graph, TensorFlow runtime error."
  274. def __init__(self, msg=DEFAULT_MSG):
  275. super(TfRuntimeError, self).__init__(msg=msg)
  276. @classmethod
  277. def raise_from(cls):
  278. tf_error_module = import_module('tensorflow.python.framework.errors_impl')
  279. tf_error = getattr(tf_error_module, 'OpError')
  280. return tf_error, ValueError, RuntimeError, cls
  281. class RuntimeIntegrityError(GraphInitError):
  282. """Catch runtime error."""
  283. ERROR_CODE = GraphInitError.ErrCode.MI_RUNTIME_ERROR.value
  284. def __init__(self, msg):
  285. super(RuntimeIntegrityError, self).__init__(msg=msg)
  286. @classmethod
  287. def raise_from(cls):
  288. return RuntimeError, AttributeError, ImportError, ModuleNotFoundError, cls
  289. class NodeInputTypeNotSupportError(SourceFilesSaveError):
  290. """The node input type NOT support error."""
  291. ERROR_CODE = SourceFilesSaveError.ErrCode.NODE_INPUT_TYPE_NOT_SUPPORT.value
  292. def __init__(self, msg):
  293. super(NodeInputTypeNotSupportError, self).__init__(msg=msg)
  294. @classmethod
  295. def raise_from(cls):
  296. return ValueError, TypeError, IndexError, cls
  297. class ScriptGenerationError(SourceFilesSaveError):
  298. """The script generate fail error."""
  299. ERROR_CODE = SourceFilesSaveError.ErrCode.SCRIPT_GENERATE_FAIL.value
  300. def __init__(self, msg):
  301. super(ScriptGenerationError, self).__init__(msg=msg)
  302. @classmethod
  303. def raise_from(cls):
  304. """Raise from exceptions below."""
  305. except_source = (RuntimeError,
  306. parse.ParseError,
  307. AttributeError, cls)
  308. return except_source
  309. class ReportGenerationError(SourceFilesSaveError):
  310. """The report generate fail error."""
  311. ERROR_CODE = SourceFilesSaveError.ErrCode.REPORT_GENERATE_FAIL.value
  312. def __init__(self, msg):
  313. super(ReportGenerationError, self).__init__(msg=msg)
  314. @classmethod
  315. def raise_from(cls):
  316. """Raise from exceptions below."""
  317. return ZeroDivisionError, cls
  318. class CheckPointGenerationError(SourceFilesSaveError):
  319. """The checkpoint generate fail error."""
  320. ERROR_CODE = SourceFilesSaveError.ErrCode.CKPT_GENERATE_FAIL.value
  321. def __init__(self, msg):
  322. super(CheckPointGenerationError, self).__init__(msg=msg)
  323. @classmethod
  324. def raise_from(cls):
  325. """Raise from exceptions below."""
  326. return cls
  327. class WeightMapGenerationError(SourceFilesSaveError):
  328. """The weight names map generate fail error."""
  329. ERROR_CODE = SourceFilesSaveError.ErrCode.MAP_GENERATE_FAIL.value
  330. def __init__(self, msg):
  331. super(WeightMapGenerationError, self).__init__(msg=msg)
  332. @classmethod
  333. def raise_from(cls):
  334. """Raise from exception below."""
  335. return cls
  336. class OnnxModelSaveError(SourceFilesSaveError):
  337. """The onnx model save fail error."""
  338. ERROR_CODE = SourceFilesSaveError.ErrCode.MODEL_SAVE_FAIL.value
  339. def __init__(self, msg):
  340. super(OnnxModelSaveError, self).__init__(msg=msg)
  341. @classmethod
  342. def raise_from(cls):
  343. """Raise from exception below."""
  344. return cls
  345. class SubGraphSearchingError(MindConverterException):
  346. """Sub-graph searching exception."""
  347. @unique
  348. class ErrCode(Enum):
  349. """Define error code of SourceFilesSaveError."""
  350. BASE_ERROR = 0
  351. CANNOT_FIND_VALID_PATTERN = 1
  352. MODEL_NOT_SUPPORT = 2
  353. BASE_ERROR_CODE = ConverterErrors.SUB_GRAPH_SEARCHING_FAIL.value
  354. ERROR_CODE = ErrCode.BASE_ERROR.value
  355. DEFAULT_MSG = "Sub-Graph pattern searching fail."
  356. def __init__(self, msg=DEFAULT_MSG):
  357. super(SubGraphSearchingError, self).__init__(user_msg=msg)
  358. @classmethod
  359. def raise_from(cls):
  360. """Define exception in sub-graph searching module."""
  361. return IndexError, KeyError, ValueError, AttributeError, ZeroDivisionError, cls
  362. class GeneratorError(MindConverterException):
  363. """The Generator fail error."""
  364. @unique
  365. class ErrCode(Enum):
  366. """Define error code of SourceFilesSaveError."""
  367. BASE_ERROR = 0
  368. STATEMENT_GENERATION_ERROR = 1
  369. CONVERTED_OPERATOR_LOADING_ERROR = 2
  370. BASE_ERROR_CODE = ConverterErrors.GENERATOR_FAIL.value
  371. ERROR_CODE = ErrCode.BASE_ERROR.value
  372. DEFAULT_MSG = "Error occurred when generate code."
  373. def __init__(self, msg=DEFAULT_MSG):
  374. super(GeneratorError, self).__init__(user_msg=msg)
  375. @classmethod
  376. def raise_from(cls):
  377. """Raise from exceptions below."""
  378. except_source = (ValueError, TypeError, SyntaxError, cls)
  379. return except_source