You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

tools.py 7.8 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Common Tools."""
  16. import imghdr
  17. import math
  18. import os
  19. from numbers import Number
  20. from urllib.parse import unquote
  21. from mindinsight.datavisual.common.exceptions import MaxCountExceededError
  22. from mindinsight.datavisual.common.exceptions import PathNotDirectoryError
  23. from mindinsight.datavisual.common.log import logger
  24. from mindinsight.utils import exceptions
  25. from mindinsight.utils.exceptions import UnknownError
  26. _IMG_EXT_TO_MIMETYPE = {
  27. 'bmp': 'image/bmp',
  28. 'gif': 'image/gif',
  29. 'jpeg': 'image/jpeg',
  30. 'png': 'image/png',
  31. }
  32. _DEFAULT_IMAGE_MIMETYPE = 'application/octet-stream'
  33. def find_app_package():
  34. """Find package in current directory."""
  35. backend_dir = os.path.realpath(os.path.join(__file__, os.pardir, os.pardir, os.pardir, "backend"))
  36. packages = []
  37. for file in os.listdir(backend_dir):
  38. file_path = os.path.join(backend_dir, file)
  39. if os.path.isfile(file_path):
  40. continue
  41. if not os.path.isfile(os.path.join(file_path, '__init__.py')):
  42. continue
  43. rel_path = os.path.relpath(file_path, backend_dir)
  44. package = rel_path.replace(os.path.sep, '.')
  45. package = f"mindinsight.backend.{package}"
  46. packages.append(package)
  47. return packages
  48. def to_str(bytes_or_text, encode="utf-8"):
  49. """Bytes transform string."""
  50. if isinstance(bytes_or_text, bytes):
  51. return bytes_or_text.decode(encode)
  52. if isinstance(bytes_or_text, str):
  53. return bytes_or_text
  54. raise TypeError("Param isn't str or bytes type, param={}".format(bytes_or_text))
  55. def to_int(param, param_name):
  56. """
  57. Transfer param to int type.
  58. Args:
  59. param (Any): A param transformed.
  60. param_name (str): Param name.
  61. Returns:
  62. int, value after transformed.
  63. """
  64. try:
  65. param = int(param)
  66. except ValueError:
  67. raise exceptions.ParamTypeError(param_name, 'Integer')
  68. return param
  69. def to_float(param, param_name):
  70. """
  71. Transfer param to float type.
  72. Args:
  73. param (Any): A param transformed.
  74. param_name (str): Param name.
  75. Returns:
  76. float, value after transformed.
  77. """
  78. try:
  79. param = float(param)
  80. except ValueError:
  81. raise exceptions.ParamTypeError(param_name, 'Float')
  82. return param
  83. def str_to_bool(param, param_name):
  84. """
  85. Check param and transform it to bool.
  86. Args:
  87. param (str): 'true' or 'false' is valid.
  88. param_name (str): Param name.
  89. Returns:
  90. bool, if param is 'true', case insensitive.
  91. Raises:
  92. ParamValueError: If the value of param is not 'false' and 'true'.
  93. """
  94. if not isinstance(param, str):
  95. raise exceptions.ParamTypeError(param_name, 'str')
  96. if param.lower() not in ['false', 'true']:
  97. raise exceptions.ParamValueError("The value of %s must be 'false' or 'true'." % param_name)
  98. param = (param.lower() == 'true')
  99. return param
  100. def get_img_mimetype(img_data):
  101. """
  102. Recognize image headers and generate image MIMETYPE.
  103. Args:
  104. img_data (bin): Binary character stream of image.
  105. Returns:
  106. str, a MIMETYPE of the give image.
  107. """
  108. image_type = imghdr.what(None, img_data)
  109. mimetype = _IMG_EXT_TO_MIMETYPE.get(image_type, _DEFAULT_IMAGE_MIMETYPE)
  110. return mimetype
  111. def get_train_id(request):
  112. """
  113. Get train ID from request query string and unquote content.
  114. Args:
  115. request (FlaskRequest): Http request instance.
  116. Returns:
  117. str, unquoted train ID.
  118. """
  119. train_id = request.args.get('train_id')
  120. if train_id is not None:
  121. try:
  122. train_id = unquote(train_id, errors='strict')
  123. except UnicodeDecodeError:
  124. raise exceptions.UrlDecodeError('Unquote train id error with strict mode')
  125. return train_id
  126. def get_profiler_dir(request):
  127. """
  128. Get train ID from request query string and unquote content.
  129. Args:
  130. request (FlaskRequest): Http request instance.
  131. Returns:
  132. str, unquoted train ID.
  133. """
  134. profiler_dir = request.args.get('profile')
  135. if profiler_dir is not None:
  136. try:
  137. profiler_dir = unquote(profiler_dir, errors='strict')
  138. except UnicodeDecodeError:
  139. raise exceptions.UrlDecodeError('Unquote profiler_dir error with strict mode')
  140. return profiler_dir
  141. def unquote_args(request, arg_name):
  142. """
  143. Get args from request query string and unquote content.
  144. Args:
  145. request (FlaskRequest): Http request instance.
  146. arg_name (str): The name of arg.
  147. Returns:
  148. str, unquoted arg.
  149. """
  150. arg_value = request.args.get(arg_name, "")
  151. if arg_value is not None:
  152. try:
  153. arg_value = unquote(arg_value, errors='strict')
  154. except UnicodeDecodeError:
  155. raise exceptions.ParamValueError('Unquote error with strict mode')
  156. return arg_value
  157. def get_device_id(request):
  158. """
  159. Get device ID from request query string and unquote content.
  160. Args:
  161. request (FlaskRequest): Http request instance.
  162. Returns:
  163. str, unquoted device ID.
  164. """
  165. device_id = request.args.get('device_id')
  166. if device_id is not None:
  167. try:
  168. device_id = unquote(device_id, errors='strict')
  169. except UnicodeDecodeError:
  170. raise exceptions.UrlDecodeError('Unquote train id error with strict mode')
  171. else:
  172. device_id = "0"
  173. return device_id
  174. def if_nan_inf_to_none(name, value):
  175. """
  176. Transform value to None if it is NaN or Inf.
  177. Args:
  178. name (str): Name of value.
  179. value (float): A number transformed.
  180. Returns:
  181. float, if value is NaN or Inf, return None.
  182. """
  183. if not isinstance(value, Number):
  184. raise exceptions.ParamTypeError(name, 'number')
  185. if math.isnan(value) or math.isinf(value):
  186. value = None
  187. return value
  188. def exception_wrapper(func):
  189. """Exception wrapper"""
  190. def wrapper(*args, **kwargs):
  191. try:
  192. return func(*args, **kwargs)
  193. except (PathNotDirectoryError, FileNotFoundError) as err:
  194. # except PathNotDirectoryError and FileNotFoundError as they are on warning level
  195. logger.warning(str(err))
  196. except Exception as exc:
  197. logger.exception(exc)
  198. raise UnknownError(str(exc))
  199. return wrapper
  200. def exception_no_raise_wrapper(func):
  201. """Don't raise exception to avoid printing error in stdout and log error in the log file."""
  202. def wrapper(*args, **kwargs):
  203. try:
  204. return exception_wrapper(func)(*args, **kwargs)
  205. except UnknownError as err:
  206. logger.error(str(err))
  207. return wrapper
  208. class Counter:
  209. """Count accumulator with limit checking."""
  210. def __init__(self, max_count=None, init_count=0):
  211. self._count = init_count
  212. self._max_count = max_count
  213. def add(self, value=1):
  214. """Add value."""
  215. if self._max_count is not None and self._count + value > self._max_count:
  216. raise MaxCountExceededError()
  217. self._count += value