You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

validate.py 9.6 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Validate the parameters."""
  16. from mindinsight.lineagemgr.common.exceptions.error_code import LineageErrors, LineageErrorMsg
  17. from mindinsight.lineagemgr.common.exceptions.exceptions import LineageParamTypeError, LineageParamValueError
  18. from mindinsight.lineagemgr.common.log import logger as log
  19. from mindinsight.lineagemgr.querier.query_model import FIELD_MAPPING
  20. from mindinsight.utils.exceptions import MindInsightException, ParamValueError
  21. SEARCH_MODEL_ERROR_MAPPING = {
  22. 'summary_dir': LineageErrors.LINEAGE_PARAM_SUMMARY_DIR_ERROR,
  23. 'loss_function': LineageErrors.LINEAGE_PARAM_LOSS_FUNCTION_ERROR,
  24. 'train_dataset_path': LineageErrors.LINEAGE_PARAM_TRAIN_DATASET_PATH_ERROR,
  25. 'train_dataset_count': LineageErrors.LINEAGE_PARAM_TRAIN_DATASET_COUNT_ERROR,
  26. 'test_dataset_path': LineageErrors.LINEAGE_PARAM_TEST_DATASET_PATH_ERROR,
  27. 'test_dataset_count': LineageErrors.LINEAGE_PARAM_TEST_DATASET_COUNT_ERROR,
  28. 'network': LineageErrors.LINEAGE_PARAM_NETWORK_ERROR,
  29. 'optimizer': LineageErrors.LINEAGE_PARAM_OPTIMIZER_ERROR,
  30. 'learning_rate': LineageErrors.LINEAGE_PARAM_LEARNING_RATE_ERROR,
  31. 'epoch': LineageErrors.LINEAGE_PARAM_EPOCH_ERROR,
  32. 'batch_size': LineageErrors.LINEAGE_PARAM_BATCH_SIZE_ERROR,
  33. 'device_num': LineageErrors.LINEAGE_PARAM_DEVICE_NUM_ERROR,
  34. 'limit': LineageErrors.PARAM_VALUE_ERROR,
  35. 'offset': LineageErrors.PARAM_VALUE_ERROR,
  36. 'loss': LineageErrors.LINEAGE_PARAM_LOSS_ERROR,
  37. 'model_size': LineageErrors.LINEAGE_PARAM_MODEL_SIZE_ERROR,
  38. 'sorted_name': LineageErrors.LINEAGE_PARAM_SORTED_NAME_ERROR,
  39. 'sorted_type': LineageErrors.LINEAGE_PARAM_SORTED_TYPE_ERROR,
  40. 'dataset_mark': LineageErrors.LINEAGE_PARAM_DATASET_MARK_ERROR,
  41. 'lineage_type': LineageErrors.LINEAGE_PARAM_LINEAGE_TYPE_ERROR
  42. }
  43. SEARCH_MODEL_ERROR_MSG_MAPPING = {
  44. 'summary_dir': LineageErrorMsg.LINEAGE_PARAM_SUMMARY_DIR_ERROR.value,
  45. 'loss_function': LineageErrorMsg.LINEAGE_LOSS_FUNCTION_ERROR.value,
  46. 'train_dataset_path': LineageErrorMsg.LINEAGE_TRAIN_DATASET_PATH_ERROR.value,
  47. 'train_dataset_count': LineageErrorMsg.LINEAGE_TRAIN_DATASET_COUNT_ERROR.value,
  48. 'test_dataset_path': LineageErrorMsg.LINEAGE_TEST_DATASET_PATH_ERROR.value,
  49. 'test_dataset_count': LineageErrorMsg.LINEAGE_TEST_DATASET_COUNT_ERROR.value,
  50. 'network': LineageErrorMsg.LINEAGE_NETWORK_ERROR.value,
  51. 'optimizer': LineageErrorMsg.LINEAGE_OPTIMIZER_ERROR.value,
  52. 'learning_rate': LineageErrorMsg.LINEAGE_LEARNING_RATE_ERROR.value,
  53. 'epoch': LineageErrorMsg.PARAM_EPOCH_NUM_ERROR.value,
  54. 'batch_size': LineageErrorMsg.PARAM_BATCH_SIZE_ERROR.value,
  55. 'device_num': LineageErrorMsg.PARAM_DEVICE_NUM_ERROR.value,
  56. 'limit': LineageErrorMsg.PARAM_LIMIT_ERROR.value,
  57. 'offset': LineageErrorMsg.PARAM_OFFSET_ERROR.value,
  58. 'loss': LineageErrorMsg.LINEAGE_LOSS_ERROR.value,
  59. 'model_size': LineageErrorMsg.LINEAGE_MODEL_SIZE_ERROR.value,
  60. 'sorted_name': LineageErrorMsg.LINEAGE_PARAM_SORTED_NAME_ERROR.value,
  61. 'sorted_type': LineageErrorMsg.LINEAGE_PARAM_SORTED_TYPE_ERROR.value,
  62. 'dataset_mark': LineageErrorMsg.LINEAGE_PARAM_DATASET_MARK_ERROR.value,
  63. 'lineage_type': LineageErrorMsg.LINEAGE_PARAM_LINEAGE_TYPE_ERROR.value
  64. }
  65. def validate_search_model_condition(schema, data):
  66. """
  67. Validate search model condition.
  68. Args:
  69. schema (Schema): Data schema.
  70. data (dict): Data to check schema.
  71. Raises:
  72. MindInsightException: If the parameters are invalid.
  73. """
  74. error = schema().validate(data)
  75. for (error_key, error_msgs) in error.items():
  76. if error_key in SEARCH_MODEL_ERROR_MAPPING.keys():
  77. error_code = SEARCH_MODEL_ERROR_MAPPING.get(error_key)
  78. error_msg = SEARCH_MODEL_ERROR_MSG_MAPPING.get(error_key)
  79. for err_msg in error_msgs:
  80. if 'operation' in err_msg.lower():
  81. error_msg = f'The parameter {error_key} is invalid. {err_msg}'
  82. break
  83. log.error(error_msg)
  84. raise MindInsightException(error=error_code, message=error_msg)
  85. def validate_condition(search_condition):
  86. """
  87. Verify the param in search_condition is valid or not.
  88. Args:
  89. search_condition (dict): The search condition.
  90. Raises:
  91. LineageParamTypeError: If the type of the param in search_condition is invalid.
  92. LineageParamValueError: If the value of the param in search_condition is invalid.
  93. """
  94. if not isinstance(search_condition, dict):
  95. log.error("Invalid search_condition type, it should be dict.")
  96. raise LineageParamTypeError("Invalid search_condition type, "
  97. "it should be dict.")
  98. if "limit" in search_condition:
  99. if isinstance(search_condition.get("limit"), bool) \
  100. or not isinstance(search_condition.get("limit"), int):
  101. log.error("The limit must be int.")
  102. raise LineageParamTypeError("The limit must be int.")
  103. if "offset" in search_condition:
  104. if isinstance(search_condition.get("offset"), bool) \
  105. or not isinstance(search_condition.get("offset"), int):
  106. log.error("The offset must be int.")
  107. raise LineageParamTypeError("The offset must be int.")
  108. if "sorted_name" in search_condition:
  109. sorted_name = search_condition.get("sorted_name")
  110. err_msg = "The sorted_name must be in {} or start with " \
  111. "`metric/` or `user_defined/`.".format(list(FIELD_MAPPING.keys()))
  112. if not isinstance(sorted_name, str):
  113. log.error(err_msg)
  114. raise LineageParamValueError(err_msg)
  115. if not (sorted_name in FIELD_MAPPING
  116. or (sorted_name.startswith('metric/') and len(sorted_name) > len('metric/'))
  117. or (sorted_name.startswith('user_defined/') and len(sorted_name) > len('user_defined/'))
  118. or sorted_name in ['tag']):
  119. log.error(err_msg)
  120. raise LineageParamValueError(err_msg)
  121. sorted_type_param = ['ascending', 'descending', None]
  122. if "sorted_type" in search_condition:
  123. if "sorted_name" not in search_condition:
  124. log.error("The sorted_name must exist when sorted_type exists.")
  125. raise LineageParamValueError("The sorted_name must exist when sorted_type exists.")
  126. if search_condition.get("sorted_type") not in sorted_type_param:
  127. err_msg = "The sorted_type must be ascending or descending."
  128. log.error(err_msg)
  129. raise LineageParamValueError(err_msg)
  130. def validate_train_id(relative_path):
  131. """
  132. Check if train_id is valid.
  133. Args:
  134. relative_path (str): Train ID of a summary directory, e.g. './log1'.
  135. Returns:
  136. bool, if train id is valid, return True.
  137. """
  138. if not relative_path.startswith('./'):
  139. log.warning("The relative_path does not start with './'.")
  140. raise ParamValueError(
  141. "Summary dir should be relative path starting with './'."
  142. )
  143. if len(relative_path.split("/")) > 2:
  144. log.warning("The relative_path contains multiple '/'.")
  145. raise ParamValueError(
  146. "Summary dir should be relative path starting with './'."
  147. )
  148. def validate_range(name, value, min_value, max_value):
  149. """
  150. Check if value is in [min_value, max_value].
  151. Args:
  152. name (str): Value name.
  153. value (Union[int, float]): Value to be check.
  154. min_value (Union[int, float]): Min value.
  155. max_value (Union[int, float]): Max value.
  156. Raises:
  157. LineageParamValueError, if value type is invalid or value is out of [min_value, max_value].
  158. """
  159. if not isinstance(value, (int, float)):
  160. raise LineageParamValueError("Value should be int or float.")
  161. if value < min_value or value > max_value:
  162. raise LineageParamValueError("The %s should in [%d, %d]." % (name, min_value, max_value))
  163. def validate_added_info(added_info: dict):
  164. """
  165. Check if added_info is valid.
  166. Args:
  167. added_info (dict): The added info.
  168. Raises:
  169. bool, if added_info is valid, return True.
  170. """
  171. added_info_keys = ["tag", "remark"]
  172. if not set(added_info.keys()).issubset(added_info_keys):
  173. err_msg = "Keys of added_info must be in {}.".format(added_info_keys)
  174. raise LineageParamValueError(err_msg)
  175. for key, value in added_info.items():
  176. if key == "tag":
  177. if not isinstance(value, int):
  178. raise LineageParamValueError("'tag' must be int.")
  179. # tag should be in [0, 10].
  180. validate_range("tag", value, min_value=0, max_value=10)
  181. elif key == "remark":
  182. if not isinstance(value, str):
  183. raise LineageParamValueError("'remark' must be str.")
  184. # length of remark should be in [0, 128].
  185. validate_range("length of remark", len(value), min_value=0, max_value=128)