You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

model.py 8.4 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """This file is used to define the model lineage python api."""
  16. import numpy as np
  17. import pandas as pd
  18. from mindinsight.lineagemgr.common.exceptions.exceptions import LineageQuerySummaryDataError, \
  19. LineageQuerierParamException, LineageSearchConditionParamError, LineageParamTypeError, LineageSummaryParseException
  20. from mindinsight.lineagemgr.common.log import logger as log
  21. from mindinsight.lineagemgr.common.validator.model_parameter import SearchModelConditionParameter
  22. from mindinsight.lineagemgr.common.validator.validate import validate_search_model_condition, validate_condition
  23. from mindinsight.lineagemgr.lineage_parser import LineageOrganizer
  24. from mindinsight.lineagemgr.querier.querier import Querier
  25. from mindinsight.optimizer.common.enums import ReasonCode
  26. from mindinsight.optimizer.utils.utils import is_simple_numpy_number
  27. from mindinsight.utils.exceptions import MindInsightException
# Prefix prepended to keys taken from a lineage's nested 'metric' dict (see _flatten_lineage).
METRIC_PREFIX = "[M]"
# Prefix prepended to keys taken from a lineage's nested 'user_defined' dict (see _flatten_lineage).
USER_DEFINED_PREFIX = "[U]"
# Bound used by get_flattened_lineage when collecting distinct user-defined keys;
# keys beyond it are logged with a warning and ignored.
USER_DEFINED_INFO_LIMIT = 100
  31. def filter_summary_lineage(data_manager, search_condition=None):
  32. """
  33. Filter summary lineage from data_manager or parsing from summaries.
  34. One of data_manager or summary_base_dir needs to be specified. Support getting
  35. super_lineage_obj from data_manager or parsing summaries by summary_base_dir.
  36. Args:
  37. data_manager (DataManager): Data manager defined as
  38. mindinsight.datavisual.data_transform.data_manager.DataManager
  39. search_condition (dict): The search condition.
  40. """
  41. search_condition = {} if search_condition is None else search_condition
  42. try:
  43. validate_condition(search_condition)
  44. validate_search_model_condition(SearchModelConditionParameter, search_condition)
  45. except MindInsightException as error:
  46. log.error(str(error))
  47. log.exception(error)
  48. raise LineageSearchConditionParamError(str(error.message))
  49. try:
  50. lineage_objects = LineageOrganizer(data_manager).super_lineage_objs
  51. result = Querier(lineage_objects).filter_summary_lineage(condition=search_condition)
  52. except LineageSummaryParseException:
  53. result = {'object': [], 'count': 0}
  54. except (LineageQuerierParamException, LineageParamTypeError) as error:
  55. log.error(str(error))
  56. log.exception(error)
  57. raise LineageQuerySummaryDataError("Filter summary lineage failed.")
  58. return result
  59. def get_flattened_lineage(data_manager, search_condition=None):
  60. """
  61. Get lineage data in a table from data manager.
  62. Args:
  63. data_manager (mindinsight.datavisual.data_manager.DataManager): An object to manage loading.
  64. search_condition (dict): The search condition.
  65. Returns:
  66. Dict[str, list]: A dict contains keys and values from lineages.
  67. """
  68. flatten_dict, user_count = {'train_id': []}, 0
  69. lineages = filter_summary_lineage(data_manager=data_manager, search_condition=search_condition).get("object", [])
  70. for index, lineage in enumerate(lineages):
  71. flatten_dict['train_id'].append(lineage.get("summary_dir"))
  72. for key, val in _flatten_lineage(lineage.get('model_lineage', {})):
  73. if key.startswith(USER_DEFINED_PREFIX) and key not in flatten_dict:
  74. if user_count > USER_DEFINED_INFO_LIMIT:
  75. log.warning("The user_defined_info has reached the limit %s. %r is ignored",
  76. USER_DEFINED_INFO_LIMIT, key)
  77. continue
  78. user_count += 1
  79. if key not in flatten_dict:
  80. flatten_dict[key] = [None] * index
  81. flatten_dict[key].append(_parse_value(val))
  82. for vals in flatten_dict.values():
  83. if len(vals) == index:
  84. vals.append(None)
  85. return flatten_dict
  86. def _flatten_lineage(lineage):
  87. """Flatten the lineage."""
  88. for key, val in lineage.items():
  89. if key == 'metric':
  90. for k, v in val.items():
  91. yield f'{METRIC_PREFIX}{k}', v
  92. elif key == 'user_defined':
  93. for k, v in val.items():
  94. yield f'{USER_DEFINED_PREFIX}{k}', v
  95. else:
  96. yield key, val
  97. def _parse_value(val):
  98. """Parse value."""
  99. if isinstance(val, str) and val.lower() in ['nan', 'inf']:
  100. return np.nan
  101. return val
  102. class LineageTable:
  103. """Wrap lineage data in a table."""
  104. _LOSS_NAME = "loss"
  105. _NOT_TUNABLE_NAMES = [_LOSS_NAME, "train_id", "device_num", "model_size",
  106. "test_dataset_count", "train_dataset_count"]
  107. def __init__(self, df: pd.DataFrame):
  108. self._df = df
  109. self.train_ids = self._df["train_id"].tolist()
  110. self._drop_columns_info = []
  111. self._remove_unsupported_columns()
  112. def _remove_unsupported_columns(self):
  113. """Remove unsupported columns."""
  114. columns_to_drop = []
  115. for name, data in self._df.iteritems():
  116. if not is_simple_numpy_number(data.dtype):
  117. columns_to_drop.append(name)
  118. if columns_to_drop:
  119. log.debug("Unsupported columns: %s", columns_to_drop)
  120. self._df = self._df.drop(columns=columns_to_drop)
  121. for name in columns_to_drop:
  122. if not name.startswith(USER_DEFINED_PREFIX):
  123. continue
  124. self._drop_columns_info.append({
  125. "name": name,
  126. "unselected": True,
  127. "reason_code": ReasonCode.NOT_ALL_NUMBERS.value
  128. })
  129. @property
  130. def target_names(self):
  131. """Get names for optimize targets (eg loss, accuracy)."""
  132. target_names = [name for name in self._df.columns if name.startswith(METRIC_PREFIX)]
  133. if self._LOSS_NAME in self._df.columns:
  134. target_names.append(self._LOSS_NAME)
  135. return target_names
  136. @property
  137. def hyper_param_names(self, tunable=True):
  138. """Get hyper param names."""
  139. blocked_names = self._get_blocked_names(tunable)
  140. hyper_param_names = [
  141. name for name in self._df.columns
  142. if not name.startswith(METRIC_PREFIX) and name not in blocked_names]
  143. if self._LOSS_NAME in hyper_param_names:
  144. hyper_param_names.remove(self._LOSS_NAME)
  145. return hyper_param_names
  146. def _get_blocked_names(self, tunable):
  147. if tunable:
  148. block_names = self._NOT_TUNABLE_NAMES
  149. else:
  150. block_names = []
  151. return block_names
  152. @property
  153. def user_defined_hyper_param_names(self):
  154. """Get user defined hyper param names."""
  155. names = [name for name in self._df.columns if name.startswith(USER_DEFINED_PREFIX)]
  156. return names
  157. def get_column(self, name):
  158. """
  159. Get data for specified column.
  160. Args:
  161. name (str): column name.
  162. Returns:
  163. np.ndarray, specified column.
  164. """
  165. return self._df[name]
  166. def get_column_values(self, name):
  167. """
  168. Get data for specified column.
  169. Args:
  170. name (str): column name.
  171. Returns:
  172. list, specified column data. If value is np.nan, transform to None.
  173. """
  174. return [None if np.isnan(num) else num for num in self._df[name].tolist()]
  175. @property
  176. def dataframe_data(self):
  177. """Get the DataFrame."""
  178. return self._df
  179. @property
  180. def drop_column_info(self):
  181. """Get dropped columns info."""
  182. return self._drop_columns_info
  183. def get_lineage_table(data_manager, search_condition=None):
  184. """Get lineage table from data_manager."""
  185. lineage_table = get_flattened_lineage(data_manager, search_condition)
  186. lineage_table = LineageTable(pd.DataFrame(lineage_table))
  187. return lineage_table