You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

model.py 9.1 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago

  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """This file is used to define the model lineage python api."""
  16. import numpy as np
  17. import pandas as pd
  18. from mindinsight.lineagemgr.common.exceptions.exceptions import LineageQuerySummaryDataError, \
  19. LineageQuerierParamException, LineageSearchConditionParamError, LineageParamTypeError, LineageSummaryParseException
  20. from mindinsight.lineagemgr.common.log import logger as log
  21. from mindinsight.lineagemgr.common.validator.model_parameter import SearchModelConditionParameter
  22. from mindinsight.lineagemgr.common.validator.validate import validate_search_model_condition, validate_condition
  23. from mindinsight.lineagemgr.common.validator.validate_path import validate_and_normalize_path
  24. from mindinsight.lineagemgr.lineage_parser import LineageOrganizer
  25. from mindinsight.lineagemgr.querier.querier import Querier
  26. from mindinsight.optimizer.common.enums import ReasonCode
  27. from mindinsight.optimizer.utils.utils import is_simple_numpy_number
  28. from mindinsight.utils.exceptions import MindInsightException
# Prefix prepended to each key under a lineage's 'metric' entry when it is flattened into a column name.
METRIC_PREFIX = "[M]"
# Prefix prepended to each key under a lineage's 'user_defined' entry when it is flattened into a column name.
USER_DEFINED_PREFIX = "[U]"
# Cap on the number of distinct user-defined columns collected by get_flattened_lineage;
# further user-defined keys are logged and ignored.
USER_DEFINED_INFO_LIMIT = 100
  32. def filter_summary_lineage(data_manager=None, summary_base_dir=None, search_condition=None):
  33. """
  34. Filter summary lineage from data_manager or parsing from summaries.
  35. One of data_manager or summary_base_dir needs to be specified. Support getting
  36. super_lineage_obj from data_manager or parsing summaries by summary_base_dir.
  37. Args:
  38. data_manager (DataManager): Data manager defined as
  39. mindinsight.datavisual.data_transform.data_manager.DataManager
  40. summary_base_dir (str): The summary base directory. It contains summary
  41. directories generated by training.
  42. search_condition (dict): The search condition.
  43. """
  44. if data_manager is None and summary_base_dir is None:
  45. raise LineageParamTypeError("One of data_manager or summary_base_dir needs to be specified.")
  46. if data_manager is None:
  47. summary_base_dir = validate_and_normalize_path(summary_base_dir, 'summary_base_dir')
  48. else:
  49. summary_base_dir = data_manager.summary_base_dir
  50. search_condition = {} if search_condition is None else search_condition
  51. try:
  52. validate_condition(search_condition)
  53. validate_search_model_condition(SearchModelConditionParameter, search_condition)
  54. except MindInsightException as error:
  55. log.error(str(error))
  56. log.exception(error)
  57. raise LineageSearchConditionParamError(str(error.message))
  58. try:
  59. lineage_objects = LineageOrganizer(data_manager, summary_base_dir).super_lineage_objs
  60. result = Querier(lineage_objects).filter_summary_lineage(condition=search_condition)
  61. except LineageSummaryParseException:
  62. result = {'object': [], 'count': 0}
  63. except (LineageQuerierParamException, LineageParamTypeError) as error:
  64. log.error(str(error))
  65. log.exception(error)
  66. raise LineageQuerySummaryDataError("Filter summary lineage failed.")
  67. return result
  68. def get_flattened_lineage(data_manager=None, summary_base_dir=None, search_condition=None):
  69. """
  70. Get lineage data in a table from data manager.
  71. Args:
  72. data_manager (mindinsight.datavisual.data_manager.DataManager): An object to manage loading.
  73. summary_base_dir (str): The base directory for train jobs.
  74. search_condition (dict): The search condition.
  75. Returns:
  76. Dict[str, list]: A dict contains keys and values from lineages.
  77. """
  78. flatten_dict, user_count = {'train_id': []}, 0
  79. lineages = filter_summary_lineage(data_manager, summary_base_dir, search_condition).get("object", [])
  80. for index, lineage in enumerate(lineages):
  81. flatten_dict['train_id'].append(lineage.get("summary_dir"))
  82. for key, val in _flatten_lineage(lineage.get('model_lineage', {})):
  83. if key.startswith(USER_DEFINED_PREFIX) and key not in flatten_dict:
  84. if user_count > USER_DEFINED_INFO_LIMIT:
  85. log.warning("The user_defined_info has reached the limit %s. %r is ignored",
  86. USER_DEFINED_INFO_LIMIT, key)
  87. continue
  88. user_count += 1
  89. if key not in flatten_dict:
  90. flatten_dict[key] = [None] * index
  91. flatten_dict[key].append(_parse_value(val))
  92. for vals in flatten_dict.values():
  93. if len(vals) == index:
  94. vals.append(None)
  95. return flatten_dict
  96. def _flatten_lineage(lineage):
  97. """Flatten the lineage."""
  98. for key, val in lineage.items():
  99. if key == 'metric':
  100. for k, v in val.items():
  101. yield f'{METRIC_PREFIX}{k}', v
  102. elif key == 'user_defined':
  103. for k, v in val.items():
  104. yield f'{USER_DEFINED_PREFIX}{k}', v
  105. else:
  106. yield key, val
  107. def _parse_value(val):
  108. """Parse value."""
  109. if isinstance(val, str) and val.lower() in ['nan', 'inf']:
  110. return np.nan
  111. return val
  112. class LineageTable:
  113. """Wrap lineage data in a table."""
  114. _LOSS_NAME = "loss"
  115. _NOT_TUNABLE_NAMES = [_LOSS_NAME, "train_id", "device_num", "model_size",
  116. "test_dataset_count", "train_dataset_count"]
  117. def __init__(self, df: pd.DataFrame):
  118. self._df = df
  119. self.train_ids = self._df["train_id"].tolist()
  120. self._drop_columns_info = []
  121. self._remove_unsupported_columns()
  122. def _remove_unsupported_columns(self):
  123. """Remove unsupported columns."""
  124. columns_to_drop = []
  125. for name, data in self._df.iteritems():
  126. if not is_simple_numpy_number(data.dtype):
  127. columns_to_drop.append(name)
  128. if columns_to_drop:
  129. log.debug("Unsupported columns: %s", columns_to_drop)
  130. self._df = self._df.drop(columns=columns_to_drop)
  131. for name in columns_to_drop:
  132. if not name.startswith(USER_DEFINED_PREFIX):
  133. continue
  134. self._drop_columns_info.append({
  135. "name": name,
  136. "unselected": True,
  137. "reason_code": ReasonCode.NOT_ALL_NUMBERS.value
  138. })
  139. @property
  140. def target_names(self):
  141. """Get names for optimize targets (eg loss, accuracy)."""
  142. target_names = [name for name in self._df.columns if name.startswith(METRIC_PREFIX)]
  143. if self._LOSS_NAME in self._df.columns:
  144. target_names.append(self._LOSS_NAME)
  145. return target_names
  146. @property
  147. def hyper_param_names(self, tunable=True):
  148. """Get hyper param names."""
  149. blocked_names = self._get_blocked_names(tunable)
  150. hyper_param_names = [
  151. name for name in self._df.columns
  152. if not name.startswith(METRIC_PREFIX) and name not in blocked_names]
  153. if self._LOSS_NAME in hyper_param_names:
  154. hyper_param_names.remove(self._LOSS_NAME)
  155. return hyper_param_names
  156. def _get_blocked_names(self, tunable):
  157. if tunable:
  158. block_names = self._NOT_TUNABLE_NAMES
  159. else:
  160. block_names = []
  161. return block_names
  162. @property
  163. def user_defined_hyper_param_names(self):
  164. """Get user defined hyper param names."""
  165. names = [name for name in self._df.columns if name.startswith(USER_DEFINED_PREFIX)]
  166. return names
  167. def get_column(self, name):
  168. """
  169. Get data for specified column.
  170. Args:
  171. name (str): column name.
  172. Returns:
  173. np.ndarray, specified column.
  174. """
  175. return self._df[name]
  176. def get_column_values(self, name):
  177. """
  178. Get data for specified column.
  179. Args:
  180. name (str): column name.
  181. Returns:
  182. list, specified column data. If value is np.nan, transform to None.
  183. """
  184. return [None if np.isnan(num) else num for num in self._df[name].tolist()]
  185. @property
  186. def dataframe_data(self):
  187. """Get the DataFrame."""
  188. return self._df
  189. @property
  190. def drop_column_info(self):
  191. """Get dropped columns info."""
  192. return self._drop_columns_info
  193. def get_lineage_table(data_manager=None, summary_base_dir=None, search_condition=None):
  194. """Get lineage table from data_manager."""
  195. lineage_table = get_flattened_lineage(data_manager, summary_base_dir, search_condition)
  196. lineage_table = LineageTable(pd.DataFrame(lineage_table))
  197. return lineage_table