You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

model.py 13 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """This file is used to define the model lineage python api."""
  16. import os
  17. import numpy as np
  18. import pandas as pd
  19. from mindinsight.lineagemgr.common.exceptions.exceptions import LineageParamValueError, \
  20. LineageQuerySummaryDataError, LineageParamSummaryPathError, \
  21. LineageQuerierParamException, LineageDirNotExistError, LineageSearchConditionParamError, \
  22. LineageParamTypeError, LineageSummaryParseException
  23. from mindinsight.lineagemgr.common.log import logger as log
  24. from mindinsight.lineagemgr.common.utils import normalize_summary_dir, get_relative_path
  25. from mindinsight.lineagemgr.common.validator.model_parameter import SearchModelConditionParameter
  26. from mindinsight.lineagemgr.common.validator.validate import validate_filter_key, validate_search_model_condition, \
  27. validate_condition, validate_path, validate_train_id
  28. from mindinsight.lineagemgr.lineage_parser import LineageParser, LineageOrganizer
  29. from mindinsight.lineagemgr.querier.querier import Querier
  30. from mindinsight.optimizer.common.enums import ReasonCode
  31. from mindinsight.optimizer.utils.utils import is_simple_numpy_number
  32. from mindinsight.utils.exceptions import MindInsightException
# Prefix that _flatten_lineage puts in front of every key found under 'metric'.
_METRIC_PREFIX = "[M]"
# Prefix that _flatten_lineage puts in front of every key found under 'user_defined'.
_USER_DEFINED_PREFIX = "[U]"
# Threshold get_flattened_lineage compares its count of distinct user-defined
# columns against before it starts ignoring new ones.
USER_DEFINED_INFO_LIMIT = 100
  36. def get_summary_lineage(data_manager=None, summary_dir=None, keys=None):
  37. """
  38. Get summary lineage from data_manager or parsing from summaries.
  39. One of data_manager or summary_dir needs to be specified. Support getting
  40. super_lineage_obj from data_manager or parsing summaries by summary_dir.
  41. Args:
  42. data_manager (DataManager): Data manager defined as
  43. mindinsight.datavisual.data_transform.data_manager.DataManager
  44. summary_dir (str): The summary directory. It contains summary logs for
  45. one training.
  46. keys (list[str]): The filter keys of lineage information. The acceptable
  47. keys are `metric`, `user_defined`, `hyper_parameters`, `algorithm`,
  48. `train_dataset`, `model`, `valid_dataset` and `dataset_graph`.
  49. If it is `None`, all information will be returned. Default: None.
  50. Returns:
  51. dict, the lineage information for one training.
  52. Raises:
  53. LineageParamSummaryPathError: If summary path is invalid.
  54. LineageQuerySummaryDataError: If querying summary data fails.
  55. LineageFileNotFoundError: If the summary log file is not found.
  56. """
  57. default_result = {}
  58. if data_manager is None and summary_dir is None:
  59. raise LineageParamTypeError("One of data_manager or summary_dir needs to be specified.")
  60. if data_manager is not None and summary_dir is None:
  61. raise LineageParamTypeError("If data_manager is specified, the summary_dir needs to be "
  62. "specified as relative path.")
  63. if keys is not None:
  64. validate_filter_key(keys)
  65. if data_manager is None:
  66. normalize_summary_dir(summary_dir)
  67. super_lineage_obj = LineageParser(summary_dir).super_lineage_obj
  68. else:
  69. validate_train_id(summary_dir)
  70. super_lineage_obj = LineageOrganizer(data_manager=data_manager).get_super_lineage_obj(summary_dir)
  71. if super_lineage_obj is None:
  72. return default_result
  73. try:
  74. result = Querier({summary_dir: super_lineage_obj}).get_summary_lineage(summary_dir, keys)
  75. except (LineageQuerierParamException, LineageParamTypeError) as error:
  76. log.error(str(error))
  77. log.exception(error)
  78. raise LineageQuerySummaryDataError("Get summary lineage failed.")
  79. return result[0]
  80. def filter_summary_lineage(data_manager=None, summary_base_dir=None, search_condition=None, added=False):
  81. """
  82. Filter summary lineage from data_manager or parsing from summaries.
  83. One of data_manager or summary_base_dir needs to be specified. Support getting
  84. super_lineage_obj from data_manager or parsing summaries by summary_base_dir.
  85. Args:
  86. data_manager (DataManager): Data manager defined as
  87. mindinsight.datavisual.data_transform.data_manager.DataManager
  88. summary_base_dir (str): The summary base directory. It contains summary
  89. directories generated by training.
  90. search_condition (dict): The search condition.
  91. """
  92. if data_manager is None and summary_base_dir is None:
  93. raise LineageParamTypeError("One of data_manager or summary_base_dir needs to be specified.")
  94. if data_manager is None:
  95. summary_base_dir = normalize_summary_dir(summary_base_dir)
  96. else:
  97. summary_base_dir = data_manager.summary_base_dir
  98. search_condition = {} if search_condition is None else search_condition
  99. try:
  100. validate_condition(search_condition)
  101. validate_search_model_condition(SearchModelConditionParameter, search_condition)
  102. except MindInsightException as error:
  103. log.error(str(error))
  104. log.exception(error)
  105. raise LineageSearchConditionParamError(str(error.message))
  106. try:
  107. search_condition = _convert_relative_path_to_abspath(summary_base_dir, search_condition)
  108. except (LineageParamValueError, LineageDirNotExistError) as error:
  109. log.error(str(error))
  110. log.exception(error)
  111. raise LineageParamSummaryPathError(str(error.message))
  112. try:
  113. lineage_objects = LineageOrganizer(data_manager, summary_base_dir).super_lineage_objs
  114. result = Querier(lineage_objects).filter_summary_lineage(
  115. condition=search_condition,
  116. added=added
  117. )
  118. except LineageSummaryParseException:
  119. result = {'object': [], 'count': 0}
  120. except (LineageQuerierParamException, LineageParamTypeError) as error:
  121. log.error(str(error))
  122. log.exception(error)
  123. raise LineageQuerySummaryDataError("Filter summary lineage failed.")
  124. return result
  125. def _convert_relative_path_to_abspath(summary_base_dir, search_condition):
  126. """
  127. Convert relative path to absolute path.
  128. Args:
  129. summary_base_dir (str): The summary base directory.
  130. search_condition (dict): The search condition.
  131. Returns:
  132. dict, the updated search_condition.
  133. Raises:
  134. LineageParamValueError: If the value of input_name is invalid.
  135. """
  136. if ("summary_dir" not in search_condition) or (not search_condition.get("summary_dir")):
  137. return search_condition
  138. summary_dir_condition = search_condition.get("summary_dir")
  139. for key in ['in', 'not_in']:
  140. if key in summary_dir_condition:
  141. summary_paths = []
  142. for summary_dir in summary_dir_condition.get(key):
  143. if summary_dir.startswith('./'):
  144. abs_dir = os.path.join(
  145. summary_base_dir, summary_dir[2:]
  146. )
  147. abs_dir = validate_path(abs_dir)
  148. else:
  149. abs_dir = validate_path(summary_dir)
  150. summary_paths.append(abs_dir)
  151. search_condition.get('summary_dir')[key] = summary_paths
  152. if 'eq' in summary_dir_condition:
  153. summary_dir = summary_dir_condition.get('eq')
  154. if summary_dir.startswith('./'):
  155. abs_dir = os.path.join(
  156. summary_base_dir, summary_dir[2:]
  157. )
  158. abs_dir = validate_path(abs_dir)
  159. else:
  160. abs_dir = validate_path(summary_dir)
  161. search_condition.get('summary_dir')['eq'] = abs_dir
  162. return search_condition
  163. def get_flattened_lineage(data_manager, search_condition=None):
  164. """
  165. Get lineage data in a table from data manager.
  166. Args:
  167. data_manager (mindinsight.datavisual.data_manager.DataManager): An object to manage loading.
  168. search_condition (dict): The search condition.
  169. Returns:
  170. Dict[str, list]: A dict contains keys and values from lineages.
  171. """
  172. summary_base_dir, flatten_dict, user_count = data_manager.summary_base_dir, {'train_id': []}, 0
  173. lineages = filter_summary_lineage(data_manager=data_manager, search_condition=search_condition).get("object", [])
  174. for index, lineage in enumerate(lineages):
  175. flatten_dict['train_id'].append(get_relative_path(lineage.get("summary_dir"), summary_base_dir))
  176. for key, val in _flatten_lineage(lineage.get('model_lineage', {})):
  177. if key.startswith(_USER_DEFINED_PREFIX) and key not in flatten_dict:
  178. if user_count > USER_DEFINED_INFO_LIMIT:
  179. log.warning("The user_defined_info has reached the limit %s. %r is ignored",
  180. USER_DEFINED_INFO_LIMIT, key)
  181. continue
  182. user_count += 1
  183. if key not in flatten_dict:
  184. flatten_dict[key] = [None] * index
  185. flatten_dict[key].append(_parse_value(val))
  186. for vals in flatten_dict.values():
  187. if len(vals) == index:
  188. vals.append(None)
  189. return flatten_dict
  190. def _flatten_lineage(lineage):
  191. """Flatten the lineage."""
  192. for key, val in lineage.items():
  193. if key == 'metric':
  194. for k, v in val.items():
  195. yield f'{_METRIC_PREFIX}{k}', v
  196. elif key == 'user_defined':
  197. for k, v in val.items():
  198. yield f'{_USER_DEFINED_PREFIX}{k}', v
  199. else:
  200. yield key, val
  201. def _parse_value(val):
  202. """Parse value."""
  203. if isinstance(val, str) and val.lower() in ['nan', 'inf']:
  204. return np.nan
  205. return val
  206. class LineageTable:
  207. """Wrap lineage data in a table."""
  208. _LOSS_NAME = "loss"
  209. _NOT_TUNABLE_NAMES = [_LOSS_NAME, "train_id", "device_num", "model_size",
  210. "test_dataset_count", "train_dataset_count"]
  211. def __init__(self, df: pd.DataFrame):
  212. self._df = df
  213. self.train_ids = self._df["train_id"].tolist()
  214. self._drop_columns_info = []
  215. self._remove_unsupported_columns()
  216. def _remove_unsupported_columns(self):
  217. """Remove unsupported columns."""
  218. columns_to_drop = []
  219. for name, data in self._df.iteritems():
  220. if not is_simple_numpy_number(data.dtype):
  221. columns_to_drop.append(name)
  222. if columns_to_drop:
  223. log.debug("Unsupported columns: %s", columns_to_drop)
  224. self._df = self._df.drop(columns=columns_to_drop)
  225. for name in columns_to_drop:
  226. if not name.startswith(_USER_DEFINED_PREFIX):
  227. continue
  228. self._drop_columns_info.append({
  229. "name": name,
  230. "unselected": True,
  231. "reason_code": ReasonCode.NOT_ALL_NUMBERS.value
  232. })
  233. @property
  234. def target_names(self):
  235. """Get names for optimize targets (eg loss, accuracy)."""
  236. target_names = [name for name in self._df.columns if name.startswith(_METRIC_PREFIX)]
  237. if self._LOSS_NAME in self._df.columns:
  238. target_names.append(self._LOSS_NAME)
  239. return target_names
  240. @property
  241. def hyper_param_names(self, tunable=True):
  242. """Get hyper param names."""
  243. blocked_names = self._get_blocked_names(tunable)
  244. hyper_param_names = [
  245. name for name in self._df.columns
  246. if not name.startswith(_METRIC_PREFIX) and name not in blocked_names]
  247. if self._LOSS_NAME in hyper_param_names:
  248. hyper_param_names.remove(self._LOSS_NAME)
  249. return hyper_param_names
  250. def _get_blocked_names(self, tunable):
  251. if tunable:
  252. block_names = self._NOT_TUNABLE_NAMES
  253. else:
  254. block_names = []
  255. return block_names
  256. @property
  257. def user_defined_hyper_param_names(self):
  258. """Get user defined hyper param names."""
  259. names = [name for name in self._df.columns if name.startswith(_USER_DEFINED_PREFIX)]
  260. return names
  261. def get_column(self, name):
  262. """
  263. Get data for specified column.
  264. Args:
  265. name (str): column name.
  266. Returns:
  267. np.ndarray, specified column.
  268. """
  269. return self._df[name]
  270. def get_column_values(self, name):
  271. """
  272. Get data for specified column.
  273. Args:
  274. name (str): column name.
  275. Returns:
  276. list, specified column data. If value is np.nan, transform to None.
  277. """
  278. return [None if np.isnan(num) else num for num in self._df[name].tolist()]
  279. @property
  280. def dataframe_data(self):
  281. """Get the DataFrame."""
  282. return self._df
  283. @property
  284. def drop_column_info(self):
  285. """Get dropped columns info."""
  286. return self._drop_columns_info