You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

validate.py 20 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Validate the parameters."""
  16. import os
  17. from marshmallow import ValidationError
  18. from mindinsight.lineagemgr.common.exceptions.error_code import LineageErrors, LineageErrorMsg
  19. from mindinsight.lineagemgr.common.exceptions.exceptions import LineageParamMissingError, \
  20. LineageParamTypeError, LineageParamValueError, LineageDirNotExistError
  21. from mindinsight.lineagemgr.common.log import logger as log
  22. from mindinsight.lineagemgr.common.validator.validate_path import safe_normalize_path
  23. from mindinsight.lineagemgr.querier.query_model import FIELD_MAPPING
  24. from mindinsight.utils.exceptions import MindInsightException, ParamValueError
  25. try:
  26. from mindspore.nn import Cell
  27. from mindspore.train.summary import SummaryRecord
  28. except (ImportError, ModuleNotFoundError):
  29. log.warning('MindSpore Not Found!')
  30. TRAIN_RUN_CONTEXT_ERROR_MAPPING = {
  31. 'optimizer': LineageErrors.PARAM_OPTIMIZER_ERROR,
  32. 'loss_fn': LineageErrors.PARAM_LOSS_FN_ERROR,
  33. 'net_outputs': LineageErrors.PARAM_NET_OUTPUTS_ERROR,
  34. 'train_network': LineageErrors.PARAM_TRAIN_NETWORK_ERROR,
  35. 'train_dataset': LineageErrors.PARAM_DATASET_ERROR,
  36. 'epoch_num': LineageErrors.PARAM_EPOCH_NUM_ERROR,
  37. 'batch_num': LineageErrors.PARAM_BATCH_NUM_ERROR,
  38. 'parallel_mode': LineageErrors.PARAM_TRAIN_PARALLEL_ERROR,
  39. 'device_number': LineageErrors.PARAM_DEVICE_NUMBER_ERROR,
  40. 'list_callback': LineageErrors.PARAM_CALLBACK_LIST_ERROR,
  41. 'train_dataset_size': LineageErrors.PARAM_DATASET_SIZE_ERROR,
  42. }
  43. SEARCH_MODEL_ERROR_MAPPING = {
  44. 'summary_dir': LineageErrors.LINEAGE_PARAM_SUMMARY_DIR_ERROR,
  45. 'loss_function': LineageErrors.LINEAGE_PARAM_LOSS_FUNCTION_ERROR,
  46. 'train_dataset_path': LineageErrors.LINEAGE_PARAM_TRAIN_DATASET_PATH_ERROR,
  47. 'train_dataset_count': LineageErrors.LINEAGE_PARAM_TRAIN_DATASET_COUNT_ERROR,
  48. 'test_dataset_path': LineageErrors.LINEAGE_PARAM_TEST_DATASET_PATH_ERROR,
  49. 'test_dataset_count': LineageErrors.LINEAGE_PARAM_TEST_DATASET_COUNT_ERROR,
  50. 'network': LineageErrors.LINEAGE_PARAM_NETWORK_ERROR,
  51. 'optimizer': LineageErrors.LINEAGE_PARAM_OPTIMIZER_ERROR,
  52. 'learning_rate': LineageErrors.LINEAGE_PARAM_LEARNING_RATE_ERROR,
  53. 'epoch': LineageErrors.LINEAGE_PARAM_EPOCH_ERROR,
  54. 'batch_size': LineageErrors.LINEAGE_PARAM_BATCH_SIZE_ERROR,
  55. 'device_num': LineageErrors.LINEAGE_PARAM_DEVICE_NUM_ERROR,
  56. 'limit': LineageErrors.PARAM_VALUE_ERROR,
  57. 'offset': LineageErrors.PARAM_VALUE_ERROR,
  58. 'loss': LineageErrors.LINEAGE_PARAM_LOSS_ERROR,
  59. 'model_size': LineageErrors.LINEAGE_PARAM_MODEL_SIZE_ERROR,
  60. 'sorted_name': LineageErrors.LINEAGE_PARAM_SORTED_NAME_ERROR,
  61. 'sorted_type': LineageErrors.LINEAGE_PARAM_SORTED_TYPE_ERROR,
  62. 'dataset_mark': LineageErrors.LINEAGE_PARAM_DATASET_MARK_ERROR,
  63. 'lineage_type': LineageErrors.LINEAGE_PARAM_LINEAGE_TYPE_ERROR
  64. }
  65. TRAIN_RUN_CONTEXT_ERROR_MSG_MAPPING = {
  66. 'optimizer': LineageErrorMsg.PARAM_OPTIMIZER_ERROR.value,
  67. 'loss_fn': LineageErrorMsg.PARAM_LOSS_FN_ERROR.value,
  68. 'net_outputs': LineageErrorMsg.PARAM_NET_OUTPUTS_ERROR.value,
  69. 'train_network': LineageErrorMsg.PARAM_TRAIN_NETWORK_ERROR.value,
  70. 'epoch_num': LineageErrorMsg.PARAM_EPOCH_NUM_ERROR.value,
  71. 'batch_num': LineageErrorMsg.PARAM_BATCH_NUM_ERROR.value,
  72. 'parallel_mode': LineageErrorMsg.PARAM_TRAIN_PARALLEL_ERROR.value,
  73. 'device_number': LineageErrorMsg.PARAM_DEVICE_NUMBER_ERROR.value,
  74. 'list_callback': LineageErrorMsg.PARAM_CALLBACK_LIST_ERROR.value
  75. }
  76. SEARCH_MODEL_ERROR_MSG_MAPPING = {
  77. 'summary_dir': LineageErrorMsg.LINEAGE_PARAM_SUMMARY_DIR_ERROR.value,
  78. 'loss_function': LineageErrorMsg.LINEAGE_LOSS_FUNCTION_ERROR.value,
  79. 'train_dataset_path': LineageErrorMsg.LINEAGE_TRAIN_DATASET_PATH_ERROR.value,
  80. 'train_dataset_count': LineageErrorMsg.LINEAGE_TRAIN_DATASET_COUNT_ERROR.value,
  81. 'test_dataset_path': LineageErrorMsg.LINEAGE_TEST_DATASET_PATH_ERROR.value,
  82. 'test_dataset_count': LineageErrorMsg.LINEAGE_TEST_DATASET_COUNT_ERROR.value,
  83. 'network': LineageErrorMsg.LINEAGE_NETWORK_ERROR.value,
  84. 'optimizer': LineageErrorMsg.LINEAGE_OPTIMIZER_ERROR.value,
  85. 'learning_rate': LineageErrorMsg.LINEAGE_LEARNING_RATE_ERROR.value,
  86. 'epoch': LineageErrorMsg.PARAM_EPOCH_NUM_ERROR.value,
  87. 'batch_size': LineageErrorMsg.PARAM_BATCH_SIZE_ERROR.value,
  88. 'device_num': LineageErrorMsg.PARAM_DEVICE_NUM_ERROR.value,
  89. 'limit': LineageErrorMsg.PARAM_LIMIT_ERROR.value,
  90. 'offset': LineageErrorMsg.PARAM_OFFSET_ERROR.value,
  91. 'loss': LineageErrorMsg.LINEAGE_LOSS_ERROR.value,
  92. 'model_size': LineageErrorMsg.LINEAGE_MODEL_SIZE_ERROR.value,
  93. 'sorted_name': LineageErrorMsg.LINEAGE_PARAM_SORTED_NAME_ERROR.value,
  94. 'sorted_type': LineageErrorMsg.LINEAGE_PARAM_SORTED_TYPE_ERROR.value,
  95. 'dataset_mark': LineageErrorMsg.LINEAGE_PARAM_DATASET_MARK_ERROR.value,
  96. 'lineage_type': LineageErrorMsg.LINEAGE_PARAM_LINEAGE_TYPE_ERROR.value
  97. }
  98. EVAL_RUN_CONTEXT_ERROR_MAPPING = {
  99. 'valid_dataset': LineageErrors.PARAM_DATASET_ERROR,
  100. 'metrics': LineageErrors.PARAM_EVAL_METRICS_ERROR
  101. }
  102. EVAL_RUN_CONTEXT_ERROR_MSG_MAPPING = {
  103. 'metrics': LineageErrorMsg.PARAM_EVAL_METRICS_ERROR.value,
  104. }
  105. def validate_int_params(int_param, param_name):
  106. """
  107. Verify the parameter which type is integer valid or not.
  108. Args:
  109. int_param (int): parameter that is integer,
  110. including epoch, dataset_batch_size, step_num
  111. param_name (str): the name of parameter,
  112. including epoch, dataset_batch_size, step_num
  113. Raises:
  114. MindInsightException: If the parameters are invalid.
  115. """
  116. if not isinstance(int_param, int) or int_param <= 0 or int_param > pow(2, 63) - 1:
  117. if param_name == 'step_num':
  118. log.error('Invalid step_num. The step number should be a positive integer.')
  119. raise MindInsightException(error=LineageErrors.PARAM_STEP_NUM_ERROR,
  120. message=LineageErrorMsg.PARAM_STEP_NUM_ERROR.value)
  121. if param_name == 'dataset_batch_size':
  122. log.error('Invalid dataset_batch_size. '
  123. 'The batch size should be a positive integer.')
  124. raise MindInsightException(error=LineageErrors.PARAM_BATCH_SIZE_ERROR,
  125. message=LineageErrorMsg.PARAM_BATCH_SIZE_ERROR.value)
  126. def validate_network(network):
  127. """
  128. Verify if the network is valid.
  129. Args:
  130. network (Cell): See mindspore.nn.Cell.
  131. Raises:
  132. LineageParamMissingError: If the network is None.
  133. MindInsightException: If the network is invalid.
  134. """
  135. if not network:
  136. error_msg = "The input network for TrainLineage should not be None."
  137. log.error(error_msg)
  138. raise LineageParamMissingError(error_msg)
  139. if not isinstance(network, Cell):
  140. log.error("Invalid network. Network should be an instance"
  141. "of mindspore.nn.Cell.")
  142. raise MindInsightException(
  143. error=LineageErrors.PARAM_TRAIN_NETWORK_ERROR,
  144. message=LineageErrorMsg.PARAM_TRAIN_NETWORK_ERROR.value
  145. )
  146. def validate_file_path(file_path, allow_empty=False):
  147. """
  148. Verify that the file_path is valid.
  149. Args:
  150. file_path (str): Input file path.
  151. allow_empty (bool): Whether file_path can be empty.
  152. Raises:
  153. MindInsightException: If the parameters are invalid.
  154. """
  155. try:
  156. if allow_empty and not file_path:
  157. return
  158. safe_normalize_path(file_path, raise_key='dataset_path', safe_prefixes=None)
  159. except ValidationError as error:
  160. log.error(str(error))
  161. raise MindInsightException(error=LineageErrors.PARAM_FILE_PATH_ERROR,
  162. message=str(error))
  163. def validate_train_run_context(schema, data):
  164. """
  165. Validate mindspore train run_context data according to schema.
  166. Args:
  167. schema (Schema): data schema.
  168. data (dict): data to check schema.
  169. Raises:
  170. MindInsightException: If the parameters are invalid.
  171. """
  172. errors = schema().validate(data)
  173. for error_key, error_msg in errors.items():
  174. if error_key in TRAIN_RUN_CONTEXT_ERROR_MAPPING.keys():
  175. error_code = TRAIN_RUN_CONTEXT_ERROR_MAPPING.get(error_key)
  176. if TRAIN_RUN_CONTEXT_ERROR_MSG_MAPPING.get(error_key):
  177. error_msg = TRAIN_RUN_CONTEXT_ERROR_MSG_MAPPING.get(error_key)
  178. log.error(error_msg)
  179. raise MindInsightException(error=error_code, message=error_msg)
  180. def validate_eval_run_context(schema, data):
  181. """
  182. Validate mindspore evaluation job run_context data according to schema.
  183. Args:
  184. schema (Schema): data schema.
  185. data (dict): data to check schema.
  186. Raises:
  187. MindInsightException: If the parameters are invalid.
  188. """
  189. errors = schema().validate(data)
  190. for error_key, error_msg in errors.items():
  191. if error_key in EVAL_RUN_CONTEXT_ERROR_MAPPING.keys():
  192. error_code = EVAL_RUN_CONTEXT_ERROR_MAPPING.get(error_key)
  193. if EVAL_RUN_CONTEXT_ERROR_MSG_MAPPING.get(error_key):
  194. error_msg = EVAL_RUN_CONTEXT_ERROR_MSG_MAPPING.get(error_key)
  195. log.error(error_msg)
  196. raise MindInsightException(error=error_code, message=error_msg)
  197. def validate_search_model_condition(schema, data):
  198. """
  199. Validate search model condition.
  200. Args:
  201. schema (Schema): Data schema.
  202. data (dict): Data to check schema.
  203. Raises:
  204. MindInsightException: If the parameters are invalid.
  205. """
  206. error = schema().validate(data)
  207. for (error_key, error_msgs) in error.items():
  208. if error_key in SEARCH_MODEL_ERROR_MAPPING.keys():
  209. error_code = SEARCH_MODEL_ERROR_MAPPING.get(error_key)
  210. error_msg = SEARCH_MODEL_ERROR_MSG_MAPPING.get(error_key)
  211. for err_msg in error_msgs:
  212. if 'operation' in err_msg.lower():
  213. error_msg = f'The parameter {error_key} is invalid. {err_msg}'
  214. break
  215. log.error(error_msg)
  216. raise MindInsightException(error=error_code, message=error_msg)
  217. def validate_summary_record(summary_record):
  218. """
  219. Validate summary_record.
  220. Args:
  221. summary_record (SummaryRecord): SummaryRecord is used to record
  222. the summary value, and summary_record is an instance of SummaryRecord,
  223. see mindspore.train.summary.SummaryRecord
  224. Raises:
  225. MindInsightException: If the parameters are invalid.
  226. """
  227. if not isinstance(summary_record, SummaryRecord):
  228. log.error("Invalid summary_record. It should be an instance "
  229. "of mindspore.train.summary.SummaryRecord.")
  230. raise MindInsightException(
  231. error=LineageErrors.PARAM_SUMMARY_RECORD_ERROR,
  232. message=LineageErrorMsg.PARAM_SUMMARY_RECORD_ERROR.value
  233. )
  234. def validate_raise_exception(raise_exception):
  235. """
  236. Validate raise_exception.
  237. Args:
  238. raise_exception (bool): decide raise exception or not,
  239. if True, raise exception; else, catch exception and continue.
  240. Raises:
  241. MindInsightException: If the parameters are invalid.
  242. """
  243. if not isinstance(raise_exception, bool):
  244. log.error("Invalid raise_exception. It should be True or False.")
  245. raise MindInsightException(
  246. error=LineageErrors.PARAM_RAISE_EXCEPTION_ERROR,
  247. message=LineageErrorMsg.PARAM_RAISE_EXCEPTION_ERROR.value
  248. )
  249. def validate_filter_key(keys):
  250. """
  251. Verify the keys of filtering is valid or not.
  252. Args:
  253. keys (list): The keys to get the relative lineage info.
  254. Raises:
  255. LineageParamTypeError: If keys is not list.
  256. LineageParamValueError: If the value of keys is invalid.
  257. """
  258. filter_keys = [
  259. 'metric', 'hyper_parameters', 'algorithm',
  260. 'train_dataset', 'model', 'valid_dataset',
  261. 'dataset_graph'
  262. ]
  263. if not isinstance(keys, list):
  264. log.error("Keys must be list.")
  265. raise LineageParamTypeError("Keys must be list.")
  266. for element in keys:
  267. if not isinstance(element, str):
  268. log.error("Element of keys must be str.")
  269. raise LineageParamTypeError("Element of keys must be str.")
  270. if not set(keys).issubset(filter_keys):
  271. err_msg = "Keys must be in {}.".format(filter_keys)
  272. log.error(err_msg)
  273. raise LineageParamValueError(err_msg)
  274. def validate_condition(search_condition):
  275. """
  276. Verify the param in search_condition is valid or not.
  277. Args:
  278. search_condition (dict): The search condition.
  279. Raises:
  280. LineageParamTypeError: If the type of the param in search_condition is invalid.
  281. LineageParamValueError: If the value of the param in search_condition is invalid.
  282. """
  283. if not isinstance(search_condition, dict):
  284. log.error("Invalid search_condition type, it should be dict.")
  285. raise LineageParamTypeError("Invalid search_condition type, "
  286. "it should be dict.")
  287. if "limit" in search_condition:
  288. if isinstance(search_condition.get("limit"), bool) \
  289. or not isinstance(search_condition.get("limit"), int):
  290. log.error("The limit must be int.")
  291. raise LineageParamTypeError("The limit must be int.")
  292. if "offset" in search_condition:
  293. if isinstance(search_condition.get("offset"), bool) \
  294. or not isinstance(search_condition.get("offset"), int):
  295. log.error("The offset must be int.")
  296. raise LineageParamTypeError("The offset must be int.")
  297. if "sorted_name" in search_condition:
  298. sorted_name = search_condition.get("sorted_name")
  299. err_msg = "The sorted_name must be in {} or start with " \
  300. "`metric/` or `user_defined/`.".format(list(FIELD_MAPPING.keys()))
  301. if not isinstance(sorted_name, str):
  302. log.error(err_msg)
  303. raise LineageParamValueError(err_msg)
  304. if not (sorted_name in FIELD_MAPPING
  305. or (sorted_name.startswith('metric/') and len(sorted_name) > len('metric/'))
  306. or (sorted_name.startswith('user_defined/') and len(sorted_name) > len('user_defined/'))
  307. or sorted_name in ['tag']):
  308. log.error(err_msg)
  309. raise LineageParamValueError(err_msg)
  310. sorted_type_param = ['ascending', 'descending', None]
  311. if "sorted_type" in search_condition:
  312. if "sorted_name" not in search_condition:
  313. log.error("The sorted_name have to exist when sorted_type exists.")
  314. raise LineageParamValueError("The sorted_name have to exist when sorted_type exists.")
  315. if search_condition.get("sorted_type") not in sorted_type_param:
  316. err_msg = "The sorted_type must be ascending or descending."
  317. log.error(err_msg)
  318. raise LineageParamValueError(err_msg)
  319. def validate_path(summary_path):
  320. """
  321. Verify the summary path is valid or not.
  322. Args:
  323. summary_path (str): The summary path which is a dir.
  324. Raises:
  325. LineageParamValueError: If the input param value is invalid.
  326. LineageDirNotExistError: If the summary path is invalid.
  327. """
  328. try:
  329. summary_path = safe_normalize_path(
  330. summary_path, "summary_path", None, check_absolute_path=True
  331. )
  332. except ValidationError:
  333. log.error("The summary path is invalid.")
  334. raise LineageParamValueError("The summary path is invalid.")
  335. if not os.path.isdir(summary_path):
  336. log.error("The summary path does not exist or is not a dir.")
  337. raise LineageDirNotExistError("The summary path does not exist or is not a dir.")
  338. return summary_path
  339. def validate_user_defined_info(user_defined_info):
  340. """
  341. Validate user defined info.
  342. Args:
  343. user_defined_info (dict): The user defined info.
  344. Raises:
  345. LineageParamTypeError: If the type of parameters is invalid.
  346. LineageParamValueError: If user defined keys have been defined in lineage.
  347. """
  348. if not isinstance(user_defined_info, dict):
  349. log.error("Invalid user defined info. It should be a dict.")
  350. raise LineageParamTypeError("Invalid user defined info. It should be dict.")
  351. for key, value in user_defined_info.items():
  352. if not isinstance(key, str):
  353. error_msg = "Dict key type {} is not supported in user defined info." \
  354. "Only str is permitted now.".format(type(key))
  355. log.error(error_msg)
  356. raise LineageParamTypeError(error_msg)
  357. if not isinstance(value, (int, str, float)):
  358. error_msg = "Dict value type {} is not supported in user defined info." \
  359. "Only str, int and float are permitted now.".format(type(value))
  360. log.error(error_msg)
  361. raise LineageParamTypeError(error_msg)
  362. field_map = set(FIELD_MAPPING.keys())
  363. user_defined_keys = set(user_defined_info.keys())
  364. all_keys = field_map | user_defined_keys
  365. if len(field_map) + len(user_defined_keys) != len(all_keys):
  366. raise LineageParamValueError("There are some keys have defined in lineage.")
  367. def validate_train_id(relative_path):
  368. """
  369. Check if train_id is valid.
  370. Args:
  371. relative_path (str): Train ID of a summary directory, e.g. './log1'.
  372. Returns:
  373. bool, if train id is valid, return True.
  374. """
  375. if not relative_path.startswith('./'):
  376. log.warning("The relative_path does not start with './'.")
  377. raise ParamValueError(
  378. "Summary dir should be relative path starting with './'."
  379. )
  380. if len(relative_path.split("/")) > 2:
  381. log.warning("The relative_path contains multiple '/'.")
  382. raise ParamValueError(
  383. "Summary dir should be relative path starting with './'."
  384. )
  385. def validate_range(name, value, min_value, max_value):
  386. """
  387. Check if value is in [min_value, max_value].
  388. Args:
  389. name (str): Value name.
  390. value (Union[int, float]): Value to be check.
  391. min_value (Union[int, float]): Min value.
  392. max_value (Union[int, float]): Max value.
  393. Raises:
  394. LineageParamValueError, if value type is invalid or value is out of [min_value, max_value].
  395. """
  396. if not isinstance(value, (int, float)):
  397. raise LineageParamValueError("Value should be int or float.")
  398. if value < min_value or value > max_value:
  399. raise LineageParamValueError("The %s should in [%d, %d]." % (name, min_value, max_value))
  400. def validate_added_info(added_info: dict):
  401. """
  402. Check if added_info is valid.
  403. Args:
  404. added_info (dict): The added info.
  405. Raises:
  406. bool, if added_info is valid, return True.
  407. """
  408. added_info_keys = ["tag", "remark"]
  409. if not set(added_info.keys()).issubset(added_info_keys):
  410. err_msg = "Keys must be in {}.".format(added_info_keys)
  411. log.error(err_msg)
  412. raise LineageParamValueError(err_msg)
  413. for key, value in added_info.items():
  414. if key == "tag":
  415. if not isinstance(value, int):
  416. raise LineageParamValueError("'tag' must be int.")
  417. # tag should be in [0, 10].
  418. validate_range("tag", value, min_value=0, max_value=10)
  419. elif key == "remark":
  420. if not isinstance(value, str):
  421. raise LineageParamValueError("'remark' must be str.")
  422. # length of remark should be in [0, 128].
  423. validate_range("length of remark", len(value), min_value=0, max_value=128)

MindInsight为MindSpore提供了简单易用的调优调试能力。在训练过程中，可以将标量、张量、图像、计算图、模型超参、训练耗时等数据记录到文件中，通过MindInsight可视化页面进行查看及分析。