You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tools.py 4.3 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Common Tools."""
  16. import imghdr
  17. import math
  18. import os
  19. from numbers import Number
  20. from urllib.parse import unquote
  21. from mindinsight.utils import exceptions
  22. _IMG_EXT_TO_MIMETYPE = {
  23. 'bmp': 'image/bmp',
  24. 'gif': 'image/gif',
  25. 'jpeg': 'image/jpeg',
  26. 'png': 'image/png',
  27. }
  28. _DEFAULT_IMAGE_MIMETYPE = 'application/octet-stream'
  29. def find_app_package():
  30. """Find package in current directory."""
  31. backend_dir = os.path.realpath(os.path.join(__file__, os.pardir, os.pardir, os.pardir, "backend"))
  32. packages = []
  33. for file in os.listdir(backend_dir):
  34. file_path = os.path.join(backend_dir, file)
  35. if os.path.isfile(file_path):
  36. continue
  37. if not os.path.isfile(os.path.join(file_path, '__init__.py')):
  38. continue
  39. rel_path = os.path.relpath(file_path, backend_dir)
  40. package = rel_path.replace(os.path.sep, '.')
  41. package = f"mindinsight.backend.{package}"
  42. packages.append(package)
  43. return packages
  44. def to_str(bytes_or_text, encode="utf-8"):
  45. """Bytes transform string."""
  46. if isinstance(bytes_or_text, bytes):
  47. return bytes_or_text.decode(encode)
  48. if isinstance(bytes_or_text, str):
  49. return bytes_or_text
  50. raise TypeError("Param isn't str or bytes type, param={}".format(bytes_or_text))
  51. def to_int(param, param_name):
  52. """
  53. Transfer param to int type.
  54. Args:
  55. param (Any): A param transformed.
  56. param_name (str): Param name.
  57. Returns:
  58. int, value after transformed.
  59. """
  60. try:
  61. param = int(param)
  62. except ValueError:
  63. raise exceptions.ParamTypeError(param_name, 'Integer')
  64. return param
  65. def str_to_bool(param, param_name):
  66. """
  67. Check param and transform it to bool.
  68. Args:
  69. param (str): 'true' or 'false' is valid.
  70. param_name (str): Param name.
  71. Returns:
  72. bool, if param is 'true', case insensitive.
  73. Raises:
  74. ParamValueError: If the value of param is not 'false' and 'true'.
  75. """
  76. if not isinstance(param, str):
  77. raise exceptions.ParamTypeError(param_name, 'str')
  78. if param.lower() not in ['false', 'true']:
  79. raise exceptions.ParamValueError("The value of %s must be 'false' or 'true'." % param_name)
  80. param = (param.lower() == 'true')
  81. return param
  82. def get_img_mimetype(img_data):
  83. """
  84. Recognize image headers and generate image MIMETYPE.
  85. Args:
  86. img_data (bin): Binary character stream of image.
  87. Returns:
  88. str, a MIMETYPE of the give image.
  89. """
  90. image_type = imghdr.what(None, img_data)
  91. mimetype = _IMG_EXT_TO_MIMETYPE.get(image_type, _DEFAULT_IMAGE_MIMETYPE)
  92. return mimetype
  93. def get_train_id(request):
  94. """
  95. Get train ID from requst query string and unquote content.
  96. Args:
  97. request (FlaskRequest): Http request instance.
  98. Returns:
  99. str, unquoted train ID.
  100. """
  101. train_id = request.args.get('train_id')
  102. if train_id is not None:
  103. try:
  104. train_id = unquote(train_id, errors='strict')
  105. except UnicodeDecodeError:
  106. raise exceptions.ParamValueError('Unquote error with strict mode')
  107. return train_id
  108. def if_nan_inf_to_none(name, value):
  109. """
  110. Transform value to None if it is NaN or Inf.
  111. Args:
  112. name (str): Name of value.
  113. value (float): A number transformed.
  114. Returns:
  115. float, if value is NaN or Inf, return None.
  116. """
  117. if not isinstance(value, Number):
  118. raise exceptions.ParamTypeError(name, 'number')
  119. if math.isnan(value) or math.isinf(value):
  120. value = None
  121. return value

MindInsight为MindSpore提供了简单易用的调优调试能力。在训练过程中，可以将标量、张量、图像、计算图、模型超参、训练耗时等数据记录到文件中，通过MindInsight可视化页面进行查看及分析。

Contributors (1)