You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

validate.py 18 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Validate the parameters."""
  16. import os
  17. from marshmallow import ValidationError
  18. from mindinsight.lineagemgr.common.exceptions.error_code import LineageErrors, LineageErrorMsg
  19. from mindinsight.lineagemgr.common.exceptions.exceptions import LineageParamMissingError, \
  20. LineageParamTypeError, LineageParamValueError, LineageDirNotExistError
  21. from mindinsight.lineagemgr.common.log import logger as log
  22. from mindinsight.lineagemgr.common.validator.validate_path import safe_normalize_path
  23. from mindinsight.lineagemgr.querier.query_model import FIELD_MAPPING
  24. from mindinsight.utils.exceptions import MindInsightException
# MindSpore is treated as optional at import time: when it cannot be imported
# only a warning is logged and no fallback is bound, so `Cell` and
# `SummaryRecord` remain undefined in this module in that case.
# NOTE(review): validate_network and validate_summary_record reference these
# names, so they presumably fail with NameError without MindSpore - confirm.
try:
    from mindspore.nn import Cell
    from mindspore.train.summary import SummaryRecord
except (ImportError, ModuleNotFoundError):
    log.warning('MindSpore Not Found!')
# Train run-context field name -> LineageErrors member.
# validate_train_run_context looks up the failing schema field here to pick
# the error code it raises; fields that are not listed are not reported there.
TRAIN_RUN_CONTEXT_ERROR_MAPPING = {
    'optimizer': LineageErrors.PARAM_OPTIMIZER_ERROR,
    'loss_fn': LineageErrors.PARAM_LOSS_FN_ERROR,
    'net_outputs': LineageErrors.PARAM_NET_OUTPUTS_ERROR,
    'train_network': LineageErrors.PARAM_TRAIN_NETWORK_ERROR,
    'train_dataset': LineageErrors.PARAM_DATASET_ERROR,
    'epoch_num': LineageErrors.PARAM_EPOCH_NUM_ERROR,
    'batch_num': LineageErrors.PARAM_BATCH_NUM_ERROR,
    'parallel_mode': LineageErrors.PARAM_TRAIN_PARALLEL_ERROR,
    'device_number': LineageErrors.PARAM_DEVICE_NUMBER_ERROR,
    'list_callback': LineageErrors.PARAM_CALLBACK_LIST_ERROR,
    'train_dataset_size': LineageErrors.PARAM_DATASET_SIZE_ERROR,
}
# Search-condition key -> LineageErrors member.
# validate_search_model_condition uses this to choose the error code for the
# first failing key it finds; 'limit' and 'offset' share PARAM_VALUE_ERROR.
SEARCH_MODEL_ERROR_MAPPING = {
    'summary_dir': LineageErrors.LINEAGE_PARAM_SUMMARY_DIR_ERROR,
    'loss_function': LineageErrors.LINEAGE_PARAM_LOSS_FUNCTION_ERROR,
    'train_dataset_path': LineageErrors.LINEAGE_PARAM_TRAIN_DATASET_PATH_ERROR,
    'train_dataset_count': LineageErrors.LINEAGE_PARAM_TRAIN_DATASET_COUNT_ERROR,
    'test_dataset_path': LineageErrors.LINEAGE_PARAM_TEST_DATASET_PATH_ERROR,
    'test_dataset_count': LineageErrors.LINEAGE_PARAM_TEST_DATASET_COUNT_ERROR,
    'network': LineageErrors.LINEAGE_PARAM_NETWORK_ERROR,
    'optimizer': LineageErrors.LINEAGE_PARAM_OPTIMIZER_ERROR,
    'learning_rate': LineageErrors.LINEAGE_PARAM_LEARNING_RATE_ERROR,
    'epoch': LineageErrors.LINEAGE_PARAM_EPOCH_ERROR,
    'batch_size': LineageErrors.LINEAGE_PARAM_BATCH_SIZE_ERROR,
    'limit': LineageErrors.PARAM_VALUE_ERROR,
    'offset': LineageErrors.PARAM_VALUE_ERROR,
    'loss': LineageErrors.LINEAGE_PARAM_LOSS_ERROR,
    'model_size': LineageErrors.LINEAGE_PARAM_MODEL_SIZE_ERROR,
    'sorted_name': LineageErrors.LINEAGE_PARAM_SORTED_NAME_ERROR,
    'sorted_type': LineageErrors.LINEAGE_PARAM_SORTED_TYPE_ERROR,
    'dataset_mark': LineageErrors.LINEAGE_PARAM_DATASET_MARK_ERROR,
    'lineage_type': LineageErrors.LINEAGE_PARAM_LINEAGE_TYPE_ERROR
}
# Train run-context field name -> fixed error message text.
# Keys present in TRAIN_RUN_CONTEXT_ERROR_MAPPING but absent here
# ('train_dataset', 'train_dataset_size') have no fixed message, so
# validate_train_run_context reports the schema's own validation message.
TRAIN_RUN_CONTEXT_ERROR_MSG_MAPPING = {
    'optimizer': LineageErrorMsg.PARAM_OPTIMIZER_ERROR.value,
    'loss_fn': LineageErrorMsg.PARAM_LOSS_FN_ERROR.value,
    'net_outputs': LineageErrorMsg.PARAM_NET_OUTPUTS_ERROR.value,
    'train_network': LineageErrorMsg.PARAM_TRAIN_NETWORK_ERROR.value,
    'epoch_num': LineageErrorMsg.PARAM_EPOCH_NUM_ERROR.value,
    'batch_num': LineageErrorMsg.PARAM_BATCH_NUM_ERROR.value,
    'parallel_mode': LineageErrorMsg.PARAM_TRAIN_PARALLEL_ERROR.value,
    'device_number': LineageErrorMsg.PARAM_DEVICE_NUMBER_ERROR.value,
    'list_callback': LineageErrorMsg.PARAM_CALLBACK_LIST_ERROR.value
}
# Search-condition key -> default error message text.
# validate_search_model_condition uses this message unless one of the schema
# messages for the key mentions 'operation', in which case that schema message
# is reported instead.
SEARCH_MODEL_ERROR_MSG_MAPPING = {
    'summary_dir': LineageErrorMsg.LINEAGE_PARAM_SUMMARY_DIR_ERROR.value,
    'loss_function': LineageErrorMsg.LINEAGE_LOSS_FUNCTION_ERROR.value,
    'train_dataset_path': LineageErrorMsg.LINEAGE_TRAIN_DATASET_PATH_ERROR.value,
    'train_dataset_count': LineageErrorMsg.LINEAGE_TRAIN_DATASET_COUNT_ERROR.value,
    'test_dataset_path': LineageErrorMsg.LINEAGE_TEST_DATASET_PATH_ERROR.value,
    'test_dataset_count': LineageErrorMsg.LINEAGE_TEST_DATASET_COUNT_ERROR.value,
    'network': LineageErrorMsg.LINEAGE_NETWORK_ERROR.value,
    'optimizer': LineageErrorMsg.LINEAGE_OPTIMIZER_ERROR.value,
    'learning_rate': LineageErrorMsg.LINEAGE_LEARNING_RATE_ERROR.value,
    'epoch': LineageErrorMsg.PARAM_EPOCH_NUM_ERROR.value,
    'batch_size': LineageErrorMsg.PARAM_BATCH_SIZE_ERROR.value,
    'limit': LineageErrorMsg.PARAM_LIMIT_ERROR.value,
    'offset': LineageErrorMsg.PARAM_OFFSET_ERROR.value,
    'loss': LineageErrorMsg.LINEAGE_LOSS_ERROR.value,
    'model_size': LineageErrorMsg.LINEAGE_MODEL_SIZE_ERROR.value,
    'sorted_name': LineageErrorMsg.LINEAGE_PARAM_SORTED_NAME_ERROR.value,
    'sorted_type': LineageErrorMsg.LINEAGE_PARAM_SORTED_TYPE_ERROR.value,
    'dataset_mark': LineageErrorMsg.LINEAGE_PARAM_DATASET_MARK_ERROR.value,
    'lineage_type': LineageErrorMsg.LINEAGE_PARAM_LINEAGE_TYPE_ERROR.value
}
# Evaluation run-context field name -> LineageErrors member, consulted by
# validate_eval_run_context to pick the error code for a failing field.
EVAL_RUN_CONTEXT_ERROR_MAPPING = {
    'valid_dataset': LineageErrors.PARAM_DATASET_ERROR,
    'metrics': LineageErrors.PARAM_EVAL_METRICS_ERROR
}
# Evaluation run-context field name -> fixed error message text.
# 'valid_dataset' has no entry, so validate_eval_run_context reports the
# schema's own validation message for that field.
EVAL_RUN_CONTEXT_ERROR_MSG_MAPPING = {
    'metrics': LineageErrorMsg.PARAM_EVAL_METRICS_ERROR.value,
}
  103. def validate_int_params(int_param, param_name):
  104. """
  105. Verify the parameter which type is integer valid or not.
  106. Args:
  107. int_param (int): parameter that is integer,
  108. including epoch, dataset_batch_size, step_num
  109. param_name (str): the name of parameter,
  110. including epoch, dataset_batch_size, step_num
  111. Raises:
  112. MindInsightException: If the parameters are invalid.
  113. """
  114. if not isinstance(int_param, int) or int_param <= 0 or int_param > pow(2, 63) - 1:
  115. if param_name == 'step_num':
  116. log.error('Invalid step_num. The step number should be a positive integer.')
  117. raise MindInsightException(error=LineageErrors.PARAM_STEP_NUM_ERROR,
  118. message=LineageErrorMsg.PARAM_STEP_NUM_ERROR.value)
  119. if param_name == 'dataset_batch_size':
  120. log.error('Invalid dataset_batch_size. '
  121. 'The batch size should be a positive integer.')
  122. raise MindInsightException(error=LineageErrors.PARAM_BATCH_SIZE_ERROR,
  123. message=LineageErrorMsg.PARAM_BATCH_SIZE_ERROR.value)
  124. def validate_network(network):
  125. """
  126. Verify if the network is valid.
  127. Args:
  128. network (Cell): See mindspore.nn.Cell.
  129. Raises:
  130. LineageParamMissingError: If the network is None.
  131. MindInsightException: If the network is invalid.
  132. """
  133. if not network:
  134. error_msg = "The input network for TrainLineage should not be None."
  135. log.error(error_msg)
  136. raise LineageParamMissingError(error_msg)
  137. if not isinstance(network, Cell):
  138. log.error("Invalid network. Network should be an instance"
  139. "of mindspore.nn.Cell.")
  140. raise MindInsightException(
  141. error=LineageErrors.PARAM_TRAIN_NETWORK_ERROR,
  142. message=LineageErrorMsg.PARAM_TRAIN_NETWORK_ERROR.value
  143. )
  144. def validate_file_path(file_path, allow_empty=False):
  145. """
  146. Verify that the file_path is valid.
  147. Args:
  148. file_path (str): Input file path.
  149. allow_empty (bool): Whether file_path can be empty.
  150. Raises:
  151. MindInsightException: If the parameters are invalid.
  152. """
  153. try:
  154. if allow_empty and not file_path:
  155. return
  156. safe_normalize_path(file_path, raise_key='dataset_path', safe_prefixes=None)
  157. except ValidationError as error:
  158. log.error(str(error))
  159. raise MindInsightException(error=LineageErrors.PARAM_FILE_PATH_ERROR,
  160. message=str(error))
  161. def validate_train_run_context(schema, data):
  162. """
  163. Validate mindspore train run_context data according to schema.
  164. Args:
  165. schema (Schema): data schema.
  166. data (dict): data to check schema.
  167. Raises:
  168. MindInsightException: If the parameters are invalid.
  169. """
  170. errors = schema().validate(data)
  171. for error_key, error_msg in errors.items():
  172. if error_key in TRAIN_RUN_CONTEXT_ERROR_MAPPING.keys():
  173. error_code = TRAIN_RUN_CONTEXT_ERROR_MAPPING.get(error_key)
  174. if TRAIN_RUN_CONTEXT_ERROR_MSG_MAPPING.get(error_key):
  175. error_msg = TRAIN_RUN_CONTEXT_ERROR_MSG_MAPPING.get(error_key)
  176. log.error(error_msg)
  177. raise MindInsightException(error=error_code, message=error_msg)
  178. def validate_eval_run_context(schema, data):
  179. """
  180. Validate mindspore evaluation job run_context data according to schema.
  181. Args:
  182. schema (Schema): data schema.
  183. data (dict): data to check schema.
  184. Raises:
  185. MindInsightException: If the parameters are invalid.
  186. """
  187. errors = schema().validate(data)
  188. for error_key, error_msg in errors.items():
  189. if error_key in EVAL_RUN_CONTEXT_ERROR_MAPPING.keys():
  190. error_code = EVAL_RUN_CONTEXT_ERROR_MAPPING.get(error_key)
  191. if EVAL_RUN_CONTEXT_ERROR_MSG_MAPPING.get(error_key):
  192. error_msg = EVAL_RUN_CONTEXT_ERROR_MSG_MAPPING.get(error_key)
  193. log.error(error_msg)
  194. raise MindInsightException(error=error_code, message=error_msg)
  195. def validate_search_model_condition(schema, data):
  196. """
  197. Validate search model condition.
  198. Args:
  199. schema (Schema): Data schema.
  200. data (dict): Data to check schema.
  201. Raises:
  202. MindInsightException: If the parameters are invalid.
  203. """
  204. error = schema().validate(data)
  205. for (error_key, error_msgs) in error.items():
  206. if error_key in SEARCH_MODEL_ERROR_MAPPING.keys():
  207. error_code = SEARCH_MODEL_ERROR_MAPPING.get(error_key)
  208. error_msg = SEARCH_MODEL_ERROR_MSG_MAPPING.get(error_key)
  209. for err_msg in error_msgs:
  210. if 'operation' in err_msg.lower():
  211. error_msg = f'The parameter {error_key} is invalid. {err_msg}'
  212. break
  213. log.error(error_msg)
  214. raise MindInsightException(error=error_code, message=error_msg)
  215. def validate_summary_record(summary_record):
  216. """
  217. Validate summary_record.
  218. Args:
  219. summary_record (SummaryRecord): SummaryRecord is used to record
  220. the summary value, and summary_record is an instance of SummaryRecord,
  221. see mindspore.train.summary.SummaryRecord
  222. Raises:
  223. MindInsightException: If the parameters are invalid.
  224. """
  225. if not isinstance(summary_record, SummaryRecord):
  226. log.error("Invalid summary_record. It should be an instance "
  227. "of mindspore.train.summary.SummaryRecord.")
  228. raise MindInsightException(
  229. error=LineageErrors.PARAM_SUMMARY_RECORD_ERROR,
  230. message=LineageErrorMsg.PARAM_SUMMARY_RECORD_ERROR.value
  231. )
  232. def validate_raise_exception(raise_exception):
  233. """
  234. Validate raise_exception.
  235. Args:
  236. raise_exception (bool): decide raise exception or not,
  237. if True, raise exception; else, catch exception and continue.
  238. Raises:
  239. MindInsightException: If the parameters are invalid.
  240. """
  241. if not isinstance(raise_exception, bool):
  242. log.error("Invalid raise_exception. It should be True or False.")
  243. raise MindInsightException(
  244. error=LineageErrors.PARAM_RAISE_EXCEPTION_ERROR,
  245. message=LineageErrorMsg.PARAM_RAISE_EXCEPTION_ERROR.value
  246. )
  247. def validate_filter_key(keys):
  248. """
  249. Verify the keys of filtering is valid or not.
  250. Args:
  251. keys (list): The keys to get the relative lineage info.
  252. Raises:
  253. LineageParamTypeError: If keys is not list.
  254. LineageParamValueError: If the value of keys is invalid.
  255. """
  256. filter_keys = [
  257. 'metric', 'hyper_parameters', 'algorithm',
  258. 'train_dataset', 'model', 'valid_dataset',
  259. 'dataset_graph'
  260. ]
  261. if not isinstance(keys, list):
  262. log.error("Keys must be list.")
  263. raise LineageParamTypeError("Keys must be list.")
  264. for element in keys:
  265. if not isinstance(element, str):
  266. log.error("Element of keys must be str.")
  267. raise LineageParamTypeError("Element of keys must be str.")
  268. if not set(keys).issubset(filter_keys):
  269. err_msg = "Keys must be in {}.".format(filter_keys)
  270. log.error(err_msg)
  271. raise LineageParamValueError(err_msg)
  272. def validate_condition(search_condition):
  273. """
  274. Verify the param in search_condition is valid or not.
  275. Args:
  276. search_condition (dict): The search condition.
  277. Raises:
  278. LineageParamTypeError: If the type of the param in search_condition is invalid.
  279. LineageParamValueError: If the value of the param in search_condition is invalid.
  280. """
  281. if not isinstance(search_condition, dict):
  282. log.error("Invalid search_condition type, it should be dict.")
  283. raise LineageParamTypeError("Invalid search_condition type, "
  284. "it should be dict.")
  285. if "limit" in search_condition:
  286. if isinstance(search_condition.get("limit"), bool) \
  287. or not isinstance(search_condition.get("limit"), int):
  288. log.error("The limit must be int.")
  289. raise LineageParamTypeError("The limit must be int.")
  290. if "offset" in search_condition:
  291. if isinstance(search_condition.get("offset"), bool) \
  292. or not isinstance(search_condition.get("offset"), int):
  293. log.error("The offset must be int.")
  294. raise LineageParamTypeError("The offset must be int.")
  295. if "sorted_name" in search_condition:
  296. sorted_name = search_condition.get("sorted_name")
  297. err_msg = "The sorted_name must be in {} or start with " \
  298. "`metric/` or `user_defined/`.".format(list(FIELD_MAPPING.keys()))
  299. if not isinstance(sorted_name, str):
  300. log.error(err_msg)
  301. raise LineageParamValueError(err_msg)
  302. if not (sorted_name in FIELD_MAPPING
  303. or (sorted_name.startswith('metric/') and len(sorted_name) > 7)
  304. or (sorted_name.startswith('user_defined/') and len(sorted_name) > 13)):
  305. log.error(err_msg)
  306. raise LineageParamValueError(err_msg)
  307. sorted_type_param = ['ascending', 'descending', None]
  308. if "sorted_type" in search_condition:
  309. if "sorted_name" not in search_condition:
  310. log.error("The sorted_name have to exist when sorted_type exists.")
  311. raise LineageParamValueError("The sorted_name have to exist when sorted_type exists.")
  312. if search_condition.get("sorted_type") not in sorted_type_param:
  313. err_msg = "The sorted_type must be ascending or descending."
  314. log.error(err_msg)
  315. raise LineageParamValueError(err_msg)
  316. def validate_path(summary_path):
  317. """
  318. Verify the summary path is valid or not.
  319. Args:
  320. summary_path (str): The summary path which is a dir.
  321. Raises:
  322. LineageParamValueError: If the input param value is invalid.
  323. LineageDirNotExistError: If the summary path is invalid.
  324. """
  325. try:
  326. summary_path = safe_normalize_path(
  327. summary_path, "summary_path", None, check_absolute_path=True
  328. )
  329. except ValidationError:
  330. log.error("The summary path is invalid.")
  331. raise LineageParamValueError("The summary path is invalid.")
  332. if not os.path.isdir(summary_path):
  333. log.error("The summary path does not exist or is not a dir.")
  334. raise LineageDirNotExistError("The summary path does not exist or is not a dir.")
  335. return summary_path
  336. def validate_user_defined_info(user_defined_info):
  337. """
  338. Validate user defined info.
  339. Args:
  340. user_defined_info (dict): The user defined info.
  341. Raises:
  342. LineageParamTypeError: If the type of parameters is invalid.
  343. LineageParamValueError: If user defined keys have been defined in lineage.
  344. """
  345. if not isinstance(user_defined_info, dict):
  346. log.error("Invalid user defined info. It should be a dict.")
  347. raise LineageParamTypeError("Invalid user defined info. It should be dict.")
  348. for key, value in user_defined_info.items():
  349. if not isinstance(key, str):
  350. error_msg = "Dict key type {} is not supported in user defined info." \
  351. "Only str is permitted now.".format(type(key))
  352. log.error(error_msg)
  353. raise LineageParamTypeError(error_msg)
  354. if not isinstance(value, (int, str, float)):
  355. error_msg = "Dict value type {} is not supported in user defined info." \
  356. "Only str, int and float are permitted now.".format(type(value))
  357. log.error(error_msg)
  358. raise LineageParamTypeError(error_msg)
  359. field_map = set(FIELD_MAPPING.keys())
  360. user_defined_keys = set(user_defined_info.keys())
  361. all_keys = field_map | user_defined_keys
  362. if len(field_map) + len(user_defined_keys) != len(all_keys):
  363. raise LineageParamValueError("There are some keys have defined in lineage.")

MindInsight为MindSpore提供了简单易用的调优调试能力。在训练过程中,可以将标量、张量、图像、计算图、模型超参、训练耗时等数据记录到文件中,通过MindInsight可视化页面进行查看及分析。