You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

validate.py 16 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Validate the parameters."""
  16. import os
  17. from marshmallow import ValidationError
  18. from mindinsight.lineagemgr.common.exceptions.error_code import LineageErrors, LineageErrorMsg
  19. from mindinsight.lineagemgr.common.exceptions.exceptions import LineageParamMissingError, \
  20. LineageParamTypeError, LineageParamValueError, LineageDirNotExistError
  21. from mindinsight.lineagemgr.common.log import logger as log
  22. from mindinsight.lineagemgr.common.validator.validate_path import safe_normalize_path
  23. from mindinsight.lineagemgr.querier.query_model import FIELD_MAPPING
  24. from mindinsight.utils.exceptions import MindInsightException
  25. try:
  26. from mindspore.nn import Cell
  27. from mindspore.train.summary import SummaryRecord
  28. except (ImportError, ModuleNotFoundError):
  29. log.warning('MindSpore Not Found!')
  30. TRAIN_RUN_CONTEXT_ERROR_MAPPING = {
  31. 'optimizer': LineageErrors.PARAM_OPTIMIZER_ERROR,
  32. 'loss_fn': LineageErrors.PARAM_LOSS_FN_ERROR,
  33. 'net_outputs': LineageErrors.PARAM_NET_OUTPUTS_ERROR,
  34. 'train_network': LineageErrors.PARAM_TRAIN_NETWORK_ERROR,
  35. 'train_dataset': LineageErrors.PARAM_DATASET_ERROR,
  36. 'epoch_num': LineageErrors.PARAM_EPOCH_NUM_ERROR,
  37. 'batch_num': LineageErrors.PARAM_BATCH_NUM_ERROR,
  38. 'parallel_mode': LineageErrors.PARAM_TRAIN_PARALLEL_ERROR,
  39. 'device_number': LineageErrors.PARAM_DEVICE_NUMBER_ERROR,
  40. 'list_callback': LineageErrors.PARAM_CALLBACK_LIST_ERROR,
  41. 'train_dataset_size': LineageErrors.PARAM_DATASET_SIZE_ERROR,
  42. }
  43. SEARCH_MODEL_ERROR_MAPPING = {
  44. 'summary_dir': LineageErrors.LINEAGE_PARAM_SUMMARY_DIR_ERROR,
  45. 'loss_function': LineageErrors.LINEAGE_PARAM_LOSS_FUNCTION_ERROR,
  46. 'train_dataset_path': LineageErrors.LINEAGE_PARAM_TRAIN_DATASET_PATH_ERROR,
  47. 'train_dataset_count': LineageErrors.LINEAGE_PARAM_TRAIN_DATASET_COUNT_ERROR,
  48. 'test_dataset_path': LineageErrors.LINEAGE_PARAM_TEST_DATASET_PATH_ERROR,
  49. 'test_dataset_count': LineageErrors.LINEAGE_PARAM_TEST_DATASET_COUNT_ERROR,
  50. 'network': LineageErrors.LINEAGE_PARAM_NETWORK_ERROR,
  51. 'optimizer': LineageErrors.LINEAGE_PARAM_OPTIMIZER_ERROR,
  52. 'learning_rate': LineageErrors.LINEAGE_PARAM_LEARNING_RATE_ERROR,
  53. 'epoch': LineageErrors.LINEAGE_PARAM_EPOCH_ERROR,
  54. 'batch_size': LineageErrors.LINEAGE_PARAM_BATCH_SIZE_ERROR,
  55. 'limit': LineageErrors.PARAM_VALUE_ERROR,
  56. 'offset': LineageErrors.PARAM_VALUE_ERROR,
  57. 'loss': LineageErrors.LINEAGE_PARAM_LOSS_ERROR,
  58. 'model_size': LineageErrors.LINEAGE_PARAM_MODEL_SIZE_ERROR,
  59. 'sorted_name': LineageErrors.LINEAGE_PARAM_SORTED_NAME_ERROR,
  60. 'sorted_type': LineageErrors.LINEAGE_PARAM_SORTED_TYPE_ERROR,
  61. 'lineage_type': LineageErrors.LINEAGE_PARAM_LINEAGE_TYPE_ERROR
  62. }
  63. TRAIN_RUN_CONTEXT_ERROR_MSG_MAPPING = {
  64. 'optimizer': LineageErrorMsg.PARAM_OPTIMIZER_ERROR.value,
  65. 'loss_fn': LineageErrorMsg.PARAM_LOSS_FN_ERROR.value,
  66. 'net_outputs': LineageErrorMsg.PARAM_NET_OUTPUTS_ERROR.value,
  67. 'train_network': LineageErrorMsg.PARAM_TRAIN_NETWORK_ERROR.value,
  68. 'epoch_num': LineageErrorMsg.PARAM_EPOCH_NUM_ERROR.value,
  69. 'batch_num': LineageErrorMsg.PARAM_BATCH_NUM_ERROR.value,
  70. 'parallel_mode': LineageErrorMsg.PARAM_TRAIN_PARALLEL_ERROR.value,
  71. 'device_number': LineageErrorMsg.PARAM_DEVICE_NUMBER_ERROR.value,
  72. 'list_callback': LineageErrorMsg.PARAM_CALLBACK_LIST_ERROR.value
  73. }
  74. SEARCH_MODEL_ERROR_MSG_MAPPING = {
  75. 'summary_dir': LineageErrorMsg.LINEAGE_PARAM_SUMMARY_DIR_ERROR.value,
  76. 'loss_function': LineageErrorMsg.LINEAGE_LOSS_FUNCTION_ERROR.value,
  77. 'train_dataset_path': LineageErrorMsg.LINEAGE_TRAIN_DATASET_PATH_ERROR.value,
  78. 'train_dataset_count': LineageErrorMsg.LINEAGE_TRAIN_DATASET_COUNT_ERROR.value,
  79. 'test_dataset_path': LineageErrorMsg.LINEAGE_TEST_DATASET_PATH_ERROR.value,
  80. 'test_dataset_count': LineageErrorMsg.LINEAGE_TEST_DATASET_COUNT_ERROR.value,
  81. 'network': LineageErrorMsg.LINEAGE_NETWORK_ERROR.value,
  82. 'optimizer': LineageErrorMsg.LINEAGE_OPTIMIZER_ERROR.value,
  83. 'learning_rate': LineageErrorMsg.LINEAGE_LEARNING_RATE_ERROR.value,
  84. 'epoch': LineageErrorMsg.PARAM_EPOCH_NUM_ERROR.value,
  85. 'batch_size': LineageErrorMsg.PARAM_BATCH_SIZE_ERROR.value,
  86. 'limit': LineageErrorMsg.PARAM_LIMIT_ERROR.value,
  87. 'offset': LineageErrorMsg.PARAM_OFFSET_ERROR.value,
  88. 'loss': LineageErrorMsg.LINEAGE_LOSS_ERROR.value,
  89. 'model_size': LineageErrorMsg.LINEAGE_MODEL_SIZE_ERROR.value,
  90. 'sorted_name': LineageErrorMsg.LINEAGE_PARAM_SORTED_NAME_ERROR.value,
  91. 'sorted_type': LineageErrorMsg.LINEAGE_PARAM_SORTED_TYPE_ERROR.value,
  92. 'lineage_type': LineageErrorMsg.LINEAGE_PARAM_LINEAGE_TYPE_ERROR.value
  93. }
  94. EVAL_RUN_CONTEXT_ERROR_MAPPING = {
  95. 'valid_dataset': LineageErrors.PARAM_DATASET_ERROR,
  96. 'metrics': LineageErrors.PARAM_EVAL_METRICS_ERROR
  97. }
  98. EVAL_RUN_CONTEXT_ERROR_MSG_MAPPING = {
  99. 'metrics': LineageErrorMsg.PARAM_EVAL_METRICS_ERROR.value,
  100. }
  101. def validate_int_params(int_param, param_name):
  102. """
  103. Verify the parameter which type is integer valid or not.
  104. Args:
  105. int_param (int): parameter that is integer,
  106. including epoch, dataset_batch_size, step_num
  107. param_name (str): the name of parameter,
  108. including epoch, dataset_batch_size, step_num
  109. Raises:
  110. MindInsightException: If the parameters are invalid.
  111. """
  112. if not isinstance(int_param, int) or int_param <= 0 or int_param > pow(2, 63) - 1:
  113. if param_name == 'step_num':
  114. log.error('Invalid step_num. The step number should be a positive integer.')
  115. raise MindInsightException(error=LineageErrors.PARAM_STEP_NUM_ERROR,
  116. message=LineageErrorMsg.PARAM_STEP_NUM_ERROR.value)
  117. if param_name == 'dataset_batch_size':
  118. log.error('Invalid dataset_batch_size. '
  119. 'The batch size should be a positive integer.')
  120. raise MindInsightException(error=LineageErrors.PARAM_BATCH_SIZE_ERROR,
  121. message=LineageErrorMsg.PARAM_BATCH_SIZE_ERROR.value)
  122. def validate_network(network):
  123. """
  124. Verify if the network is valid.
  125. Args:
  126. network (Cell): See mindspore.nn.Cell.
  127. Raises:
  128. LineageParamMissingError: If the network is None.
  129. MindInsightException: If the network is invalid.
  130. """
  131. if not network:
  132. error_msg = "The input network for TrainLineage should not be None."
  133. log.error(error_msg)
  134. raise LineageParamMissingError(error_msg)
  135. if not isinstance(network, Cell):
  136. log.error("Invalid network. Network should be an instance"
  137. "of mindspore.nn.Cell.")
  138. raise MindInsightException(
  139. error=LineageErrors.PARAM_TRAIN_NETWORK_ERROR,
  140. message=LineageErrorMsg.PARAM_TRAIN_NETWORK_ERROR.value
  141. )
  142. def validate_file_path(file_path, allow_empty=False):
  143. """
  144. Verify that the file_path is valid.
  145. Args:
  146. file_path (str): Input file path.
  147. allow_empty (bool): Whether file_path can be empty.
  148. Raises:
  149. MindInsightException: If the parameters are invalid.
  150. """
  151. try:
  152. if allow_empty and not file_path:
  153. return
  154. safe_normalize_path(file_path, raise_key='dataset_path', safe_prefixes=None)
  155. except ValidationError as error:
  156. log.error(str(error))
  157. raise MindInsightException(error=LineageErrors.PARAM_FILE_PATH_ERROR,
  158. message=str(error))
  159. def validate_train_run_context(schema, data):
  160. """
  161. Validate mindspore train run_context data according to schema.
  162. Args:
  163. schema (Schema): data schema.
  164. data (dict): data to check schema.
  165. Raises:
  166. MindInsightException: If the parameters are invalid.
  167. """
  168. errors = schema().validate(data)
  169. for error_key, error_msg in errors.items():
  170. if error_key in TRAIN_RUN_CONTEXT_ERROR_MAPPING.keys():
  171. error_code = TRAIN_RUN_CONTEXT_ERROR_MAPPING.get(error_key)
  172. if TRAIN_RUN_CONTEXT_ERROR_MSG_MAPPING.get(error_key):
  173. error_msg = TRAIN_RUN_CONTEXT_ERROR_MSG_MAPPING.get(error_key)
  174. log.error(error_msg)
  175. raise MindInsightException(error=error_code, message=error_msg)
  176. def validate_eval_run_context(schema, data):
  177. """
  178. Validate mindspore evaluation job run_context data according to schema.
  179. Args:
  180. schema (Schema): data schema.
  181. data (dict): data to check schema.
  182. Raises:
  183. MindInsightException: If the parameters are invalid.
  184. """
  185. errors = schema().validate(data)
  186. for error_key, error_msg in errors.items():
  187. if error_key in EVAL_RUN_CONTEXT_ERROR_MAPPING.keys():
  188. error_code = EVAL_RUN_CONTEXT_ERROR_MAPPING.get(error_key)
  189. if EVAL_RUN_CONTEXT_ERROR_MSG_MAPPING.get(error_key):
  190. error_msg = EVAL_RUN_CONTEXT_ERROR_MSG_MAPPING.get(error_key)
  191. log.error(error_msg)
  192. raise MindInsightException(error=error_code, message=error_msg)
  193. def validate_search_model_condition(schema, data):
  194. """
  195. Validate search model condition.
  196. Args:
  197. schema (Schema): Data schema.
  198. data (dict): Data to check schema.
  199. Raises:
  200. MindInsightException: If the parameters are invalid.
  201. """
  202. error = schema().validate(data)
  203. for error_key in error.keys():
  204. if error_key in SEARCH_MODEL_ERROR_MAPPING.keys():
  205. error_code = SEARCH_MODEL_ERROR_MAPPING.get(error_key)
  206. error_msg = SEARCH_MODEL_ERROR_MSG_MAPPING.get(error_key)
  207. log.error(error_msg)
  208. raise MindInsightException(error=error_code, message=error_msg)
  209. def validate_summary_record(summary_record):
  210. """
  211. Validate summary_record.
  212. Args:
  213. summary_record (SummaryRecord): SummaryRecord is used to record
  214. the summary value, and summary_record is an instance of SummaryRecord,
  215. see mindspore.train.summary.SummaryRecord
  216. Raises:
  217. MindInsightException: If the parameters are invalid.
  218. """
  219. if not isinstance(summary_record, SummaryRecord):
  220. log.error("Invalid summary_record. It should be an instance "
  221. "of mindspore.train.summary.SummaryRecord.")
  222. raise MindInsightException(
  223. error=LineageErrors.PARAM_SUMMARY_RECORD_ERROR,
  224. message=LineageErrorMsg.PARAM_SUMMARY_RECORD_ERROR.value
  225. )
  226. def validate_raise_exception(raise_exception):
  227. """
  228. Validate raise_exception.
  229. Args:
  230. raise_exception (bool): decide raise exception or not,
  231. if True, raise exception; else, catch exception and continue.
  232. Raises:
  233. MindInsightException: If the parameters are invalid.
  234. """
  235. if not isinstance(raise_exception, bool):
  236. log.error("Invalid raise_exception. It should be True or False.")
  237. raise MindInsightException(
  238. error=LineageErrors.PARAM_RAISE_EXCEPTION_ERROR,
  239. message=LineageErrorMsg.PARAM_RAISE_EXCEPTION_ERROR.value
  240. )
  241. def validate_filter_key(keys):
  242. """
  243. Verify the keys of filtering is valid or not.
  244. Args:
  245. keys (list): The keys to get the relative lineage info.
  246. Raises:
  247. LineageParamTypeError: If keys is not list.
  248. LineageParamValueError: If the value of keys is invalid.
  249. """
  250. filter_keys = [
  251. 'metric', 'hyper_parameters', 'algorithm',
  252. 'train_dataset', 'model', 'valid_dataset',
  253. 'dataset_graph'
  254. ]
  255. if not isinstance(keys, list):
  256. log.error("Keys must be list.")
  257. raise LineageParamTypeError("Keys must be list.")
  258. for element in keys:
  259. if not isinstance(element, str):
  260. log.error("Element of keys must be str.")
  261. raise LineageParamTypeError("Element of keys must be str.")
  262. if not set(keys).issubset(filter_keys):
  263. err_msg = "Keys must be in {}.".format(filter_keys)
  264. log.error(err_msg)
  265. raise LineageParamValueError(err_msg)
  266. def validate_condition(search_condition):
  267. """
  268. Verify the param in search_condition is valid or not.
  269. Args:
  270. search_condition (dict): The search condition.
  271. Raises:
  272. LineageParamTypeError: If the type of the param in search_condition is invalid.
  273. LineageParamValueError: If the value of the param in search_condition is invalid.
  274. """
  275. if not isinstance(search_condition, dict):
  276. log.error("Invalid search_condition type, it should be dict.")
  277. raise LineageParamTypeError("Invalid search_condition type, "
  278. "it should be dict.")
  279. if "limit" in search_condition:
  280. if isinstance(search_condition.get("limit"), bool) \
  281. or not isinstance(search_condition.get("limit"), int):
  282. log.error("The limit must be int.")
  283. raise LineageParamTypeError("The limit must be int.")
  284. if "offset" in search_condition:
  285. if isinstance(search_condition.get("offset"), bool) \
  286. or not isinstance(search_condition.get("offset"), int):
  287. log.error("The offset must be int.")
  288. raise LineageParamTypeError("The offset must be int.")
  289. if "sorted_name" in search_condition:
  290. sorted_name = search_condition.get("sorted_name")
  291. err_msg = "The sorted_name must be in {} or start with " \
  292. "`metric_`.".format(list(FIELD_MAPPING.keys()))
  293. if not isinstance(sorted_name, str):
  294. log.error(err_msg)
  295. raise LineageParamValueError(err_msg)
  296. if sorted_name not in FIELD_MAPPING and not (
  297. sorted_name.startswith('metric_') and len(sorted_name) > 7):
  298. log.error(err_msg)
  299. raise LineageParamValueError(err_msg)
  300. sorted_type_param = ['ascending', 'descending', None]
  301. if "sorted_type" in search_condition:
  302. if "sorted_name" not in search_condition:
  303. log.error("The sorted_name have to exist when sorted_type exists.")
  304. raise LineageParamValueError("The sorted_name have to exist when sorted_type exists.")
  305. if search_condition.get("sorted_type") not in sorted_type_param:
  306. err_msg = "The sorted_type must be ascending or descending."
  307. log.error(err_msg)
  308. raise LineageParamValueError(err_msg)
  309. def validate_path(summary_path):
  310. """
  311. Verify the summary path is valid or not.
  312. Args:
  313. summary_path (str): The summary path which is a dir.
  314. Raises:
  315. LineageParamValueError: If the input param value is invalid.
  316. LineageDirNotExistError: If the summary path is invalid.
  317. """
  318. try:
  319. summary_path = safe_normalize_path(
  320. summary_path, "summary_path", None, check_absolute_path=True
  321. )
  322. except ValidationError:
  323. log.error("The summary path is invalid.")
  324. raise LineageParamValueError("The summary path is invalid.")
  325. if not os.path.isdir(summary_path):
  326. log.error("The summary path does not exist or is not a dir.")
  327. raise LineageDirNotExistError("The summary path does not exist or is not a dir.")
  328. return summary_path

MindInsight为MindSpore提供了简单易用的调优调试能力。在训练过程中，可以将标量、张量、图像、计算图、模型超参、训练耗时等数据记录到文件中，通过MindInsight可视化页面进行查看及分析。

Contributors (1)