You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

model.py 14 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """This file is used to define the model lineage python api."""
  16. import os
  17. from mindinsight.lineagemgr.common.exceptions.exceptions import LineageParamValueError, \
  18. LineageQuerySummaryDataError, LineageParamSummaryPathError, \
  19. LineageQuerierParamException, LineageDirNotExistError, LineageSearchConditionParamError, \
  20. LineageParamTypeError, LineageSummaryParseException
  21. from mindinsight.lineagemgr.common.log import logger as log
  22. from mindinsight.lineagemgr.common.utils import normalize_summary_dir
  23. from mindinsight.lineagemgr.common.validator.model_parameter import SearchModelConditionParameter
  24. from mindinsight.lineagemgr.common.validator.validate import validate_filter_key, validate_search_model_condition, \
  25. validate_condition, validate_path, validate_train_id
  26. from mindinsight.lineagemgr.lineage_parser import LineageParser, LineageOrganizer
  27. from mindinsight.lineagemgr.querier.querier import Querier
  28. from mindinsight.utils.exceptions import MindInsightException
  29. def get_summary_lineage(summary_dir, keys=None):
  30. """
  31. Get the lineage information according to summary directory and keys.
  32. The function queries lineage information of single train process
  33. corresponding to the given summary directory. Users can query the
  34. information according to `keys`.
  35. Args:
  36. summary_dir (str): The summary directory. It contains summary logs for
  37. one training.
  38. keys (list[str]): The filter keys of lineage information. The acceptable
  39. keys are `metric`, `user_defined`, `hyper_parameters`, `algorithm`,
  40. `train_dataset`, `model`, `valid_dataset` and `dataset_graph`.
  41. If it is `None`, all information will be returned. Default: None.
  42. Returns:
  43. dict, the lineage information for one training.
  44. Raises:
  45. LineageParamSummaryPathError: If summary path is invalid.
  46. LineageQuerySummaryDataError: If querying summary data fails.
  47. LineageFileNotFoundError: If the summary log file is not found.
  48. Examples:
  49. >>> summary_dir = "/path/to/summary"
  50. >>> summary_lineage_info = get_summary_lineage(summary_dir)
  51. >>> hyper_parameters = get_summary_lineage(summary_dir, keys=["hyper_parameters"])
  52. """
  53. return general_get_summary_lineage(summary_dir=summary_dir, keys=keys)
  54. def general_get_summary_lineage(data_manager=None, summary_dir=None, keys=None):
  55. """
  56. Get summary lineage from data_manager or parsing from summaries.
  57. One of data_manager or summary_dir needs to be specified. Support getting
  58. super_lineage_obj from data_manager or parsing summaries by summary_dir.
  59. Args:
  60. data_manager (DataManager): Data manager defined as
  61. mindinsight.datavisual.data_transform.data_manager.DataManager
  62. summary_dir (str): The summary directory. It contains summary logs for
  63. one training.
  64. keys (list[str]): The filter keys of lineage information. The acceptable
  65. keys are `metric`, `user_defined`, `hyper_parameters`, `algorithm`,
  66. `train_dataset`, `model`, `valid_dataset` and `dataset_graph`.
  67. If it is `None`, all information will be returned. Default: None.
  68. Returns:
  69. dict, the lineage information for one training.
  70. Raises:
  71. LineageParamSummaryPathError: If summary path is invalid.
  72. LineageQuerySummaryDataError: If querying summary data fails.
  73. LineageFileNotFoundError: If the summary log file is not found.
  74. """
  75. default_result = {}
  76. if data_manager is None and summary_dir is None:
  77. raise LineageParamTypeError("One of data_manager or summary_dir needs to be specified.")
  78. if keys is not None:
  79. validate_filter_key(keys)
  80. if data_manager is None:
  81. normalize_summary_dir(summary_dir)
  82. super_lineage_obj = LineageParser(summary_dir).super_lineage_obj
  83. else:
  84. validate_train_id(summary_dir)
  85. super_lineage_obj = LineageOrganizer(data_manager=data_manager).get_super_lineage_obj(summary_dir)
  86. if super_lineage_obj is None:
  87. return default_result
  88. try:
  89. result = Querier({summary_dir: super_lineage_obj}).get_summary_lineage(summary_dir, keys)
  90. except (LineageQuerierParamException, LineageParamTypeError) as error:
  91. log.error(str(error))
  92. log.exception(error)
  93. raise LineageQuerySummaryDataError("Get summary lineage failed.")
  94. return result[0]
  95. def filter_summary_lineage(summary_base_dir, search_condition=None):
  96. """
  97. Filter the lineage information under summary base directory according to search condition.
  98. Users can filter and sort all lineage information according to the search
  99. condition. The supported filter fields include `summary_dir`, `network`,
  100. etc. The filter conditions include `eq`, `lt`, `gt`, `le`, `ge` and `in`.
  101. If the value type of filter condition is `str`, such as summary_dir and
  102. lineage_type, then its key can only be `in` and `eq`. At the same time,
  103. the combined use of these fields and conditions is supported. If you want
  104. to sort based on filter fields, the field of `sorted_name` and `sorted_type`
  105. should be specified.
  106. Users can use `lineage_type` to decide what kind of lineage information to
  107. query. If the `lineage_type` is not defined, the query result is all lineage
  108. information.
  109. Users can paginate query result based on `offset` and `limit`. The `offset`
  110. refers to page number. The `limit` refers to the number in one page.
  111. Args:
  112. summary_base_dir (str): The summary base directory. It contains summary
  113. directories generated by training.
  114. search_condition (dict): The search condition. When filtering and
  115. sorting, in addition to the following supported fields, fields
  116. prefixed with `metric/` and `user_defined/` are also supported.
  117. For example, the field should be `metric/accuracy` if the key
  118. of `metrics` parameter is `accuracy`. The fields prefixed with
  119. `metric/` and `user_defined/` are related to the `metrics`
  120. parameter in the training script and user defined information in
  121. TrainLineage/EvalLineage callback, respectively. Default: None.
  122. - summary_dir (dict): The filter condition of summary directory.
  123. - loss_function (dict): The filter condition of loss function.
  124. - train_dataset_path (dict): The filter condition of train dataset path.
  125. - train_dataset_count (dict): The filter condition of train dataset count.
  126. - test_dataset_path (dict): The filter condition of test dataset path.
  127. - test_dataset_count (dict): The filter condition of test dataset count.
  128. - network (dict): The filter condition of network.
  129. - optimizer (dict): The filter condition of optimizer.
  130. - learning_rate (dict): The filter condition of learning rate.
  131. - epoch (dict): The filter condition of epoch.
  132. - batch_size (dict): The filter condition of batch size.
  133. - device_num (dict): The filter condition of device num.
  134. - loss (dict): The filter condition of loss.
  135. - model_size (dict): The filter condition of model size.
  136. - dataset_mark (dict): The filter condition of dataset mark.
  137. - lineage_type (dict): The filter condition of lineage type. It decides
  138. what kind of lineage information to query. Its value can be `dataset`
  139. or `model`, e.g., {'in': ['dataset', 'model']}, {'eq': 'model'}, etc.
  140. If its values contain `dataset`, the query result will contain the
  141. lineage information related to data augmentation. If its values contain
  142. `model`, the query result will contain model lineage information.
  143. If it is not defined or it is a dict like {'in': ['dataset', 'model']},
  144. the query result is all lineage information.
  145. - offset (int): Page number, the value range is [0, 100000].
  146. - limit (int): The number in one page, the value range is [1, 100].
  147. - sorted_name (str): Specify which field to sort by.
  148. - sorted_type (str): Specify sort order. It can be `ascending` or
  149. `descending`.
  150. Returns:
  151. dict, lineage information under summary base directory according to
  152. search condition.
  153. Raises:
  154. LineageSearchConditionParamError: If search_condition param is invalid.
  155. LineageParamSummaryPathError: If summary path is invalid.
  156. LineageFileNotFoundError: If the summary log file is not found.
  157. LineageQuerySummaryDataError: If querying summary log file data fails.
  158. Examples:
  159. >>> summary_base_dir = "/path/to/summary_base"
  160. >>> search_condition = {
  161. >>> 'summary_dir': {
  162. >>> 'in': [
  163. >>> os.path.join(summary_base_dir, 'summary_1'),
  164. >>> os.path.join(summary_base_dir, 'summary_2'),
  165. >>> os.path.join(summary_base_dir, 'summary_3')
  166. >>> ]
  167. >>> },
  168. >>> 'loss': {
  169. >>> 'gt': 2.0
  170. >>> },
  171. >>> 'batch_size': {
  172. >>> 'ge': 128,
  173. >>> 'le': 256
  174. >>> },
  175. >>> 'metric/accuracy': {
  176. >>> 'lt': 0.1
  177. >>> },
  178. >>> 'sorted_name': 'summary_dir',
  179. >>> 'sorted_type': 'descending',
  180. >>> 'limit': 3,
  181. >>> 'offset': 0,
  182. >>> 'lineage_type': {
  183. >>> 'eq': 'model'
  184. >>> }
  185. >>> }
  186. >>> summary_lineage = filter_summary_lineage(summary_base_dir)
  187. >>> summary_lineage_filter = filter_summary_lineage(summary_base_dir, search_condition)
  188. """
  189. return general_filter_summary_lineage(summary_base_dir=summary_base_dir, search_condition=search_condition)
  190. def general_filter_summary_lineage(data_manager=None, summary_base_dir=None, search_condition=None, added=False):
  191. """
  192. Filter summary lineage from data_manager or parsing from summaries.
  193. One of data_manager or summary_base_dir needs to be specified. Support getting
  194. super_lineage_obj from data_manager or parsing summaries by summary_base_dir.
  195. Args:
  196. data_manager (DataManager): Data manager defined as
  197. mindinsight.datavisual.data_transform.data_manager.DataManager
  198. summary_base_dir (str): The summary base directory. It contains summary
  199. directories generated by training.
  200. search_condition (dict): The search condition.
  201. """
  202. if data_manager is None and summary_base_dir is None:
  203. raise LineageParamTypeError("One of data_manager or summary_base_dir needs to be specified.")
  204. if data_manager is None:
  205. summary_base_dir = normalize_summary_dir(summary_base_dir)
  206. else:
  207. summary_base_dir = data_manager.summary_base_dir
  208. search_condition = {} if search_condition is None else search_condition
  209. try:
  210. validate_condition(search_condition)
  211. validate_search_model_condition(SearchModelConditionParameter, search_condition)
  212. except MindInsightException as error:
  213. log.error(str(error))
  214. log.exception(error)
  215. raise LineageSearchConditionParamError(str(error.message))
  216. try:
  217. search_condition = _convert_relative_path_to_abspath(summary_base_dir, search_condition)
  218. except (LineageParamValueError, LineageDirNotExistError) as error:
  219. log.error(str(error))
  220. log.exception(error)
  221. raise LineageParamSummaryPathError(str(error.message))
  222. try:
  223. lineage_objects = LineageOrganizer(data_manager, summary_base_dir).super_lineage_objs
  224. result = Querier(lineage_objects).filter_summary_lineage(
  225. condition=search_condition,
  226. added=added
  227. )
  228. except LineageSummaryParseException:
  229. result = {'object': [], 'count': 0}
  230. except (LineageQuerierParamException, LineageParamTypeError) as error:
  231. log.error(str(error))
  232. log.exception(error)
  233. raise LineageQuerySummaryDataError("Filter summary lineage failed.")
  234. return result
  235. def _convert_relative_path_to_abspath(summary_base_dir, search_condition):
  236. """
  237. Convert relative path to absolute path.
  238. Args:
  239. summary_base_dir (str): The summary base directory.
  240. search_condition (dict): The search condition.
  241. Returns:
  242. dict, the updated search_condition.
  243. Raises:
  244. LineageParamValueError: If the value of input_name is invalid.
  245. """
  246. if ("summary_dir" not in search_condition) or (not search_condition.get("summary_dir")):
  247. return search_condition
  248. summary_dir_condition = search_condition.get("summary_dir")
  249. if 'in' in summary_dir_condition:
  250. summary_paths = []
  251. for summary_dir in summary_dir_condition.get('in'):
  252. if summary_dir.startswith('./'):
  253. abs_dir = os.path.join(
  254. summary_base_dir, summary_dir[2:]
  255. )
  256. abs_dir = validate_path(abs_dir)
  257. else:
  258. abs_dir = validate_path(summary_dir)
  259. summary_paths.append(abs_dir)
  260. search_condition.get('summary_dir')['in'] = summary_paths
  261. if 'eq' in summary_dir_condition:
  262. summary_dir = summary_dir_condition.get('eq')
  263. if summary_dir.startswith('./'):
  264. abs_dir = os.path.join(
  265. summary_base_dir, summary_dir[2:]
  266. )
  267. abs_dir = validate_path(abs_dir)
  268. else:
  269. abs_dir = validate_path(summary_dir)
  270. search_condition.get('summary_dir')['eq'] = abs_dir
  271. return search_condition

MindInsight为MindSpore提供了简单易用的调优调试能力。在训练过程中,可以将标量、张量、图像、计算图、模型超参、训练耗时等数据记录到文件中,通过MindInsight可视化页面进行查看及分析。