You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

query_model.py 13 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """This file is used to define lineage info model."""
  16. import json
  17. from collections import namedtuple
  18. from google.protobuf.json_format import MessageToDict
  19. from mindinsight.lineagemgr.common.exceptions.exceptions import \
  20. LineageEventFieldNotExistException, LineageEventNotExistException
  21. from mindinsight.lineagemgr.summary._summary_adapter import organize_graph
  22. Field = namedtuple('Field', ['base_name', 'sub_name'])
  23. FIELD_MAPPING = {
  24. "summary_dir": Field('summary_dir', None),
  25. "loss_function": Field("hyper_parameters", 'loss_function'),
  26. "train_dataset_path": Field('train_dataset', 'train_dataset_path'),
  27. "train_dataset_count": Field("train_dataset", 'train_dataset_size'),
  28. "test_dataset_path": Field('valid_dataset', 'valid_dataset_path'),
  29. "test_dataset_count": Field('valid_dataset', 'valid_dataset_size'),
  30. "network": Field('algorithm', 'network'),
  31. "optimizer": Field('hyper_parameters', 'optimizer'),
  32. "learning_rate": Field('hyper_parameters', 'learning_rate'),
  33. "epoch": Field('hyper_parameters', 'epoch'),
  34. "batch_size": Field('hyper_parameters', 'batch_size'),
  35. "device_num": Field('hyper_parameters', 'device_num'),
  36. "loss": Field('algorithm', 'loss'),
  37. "model_size": Field('model', 'size'),
  38. "dataset_mark": Field('dataset_mark', None)
  39. }
  40. class LineageObj:
  41. """
  42. Lineage information class.
  43. An instance of the class hold lineage information for a training session.
  44. Args:
  45. summary_dir (str): Summary log dir.
  46. kwargs (dict): Params to init the instance.
  47. - train_lineage (Event): Train lineage object.
  48. - evaluation_lineage (Event): Evaluation lineage object.
  49. - dataset_graph (Event): Dataset graph object.
  50. - user_defined_info (Event): User defined info object.
  51. Raises:
  52. LineageEventNotExistException: If train and evaluation event not exist.
  53. LineageEventFieldNotExistException: If the special event field not exist.
  54. """
  55. _name_train_lineage = 'train_lineage'
  56. _name_evaluation_lineage = 'evaluation_lineage'
  57. _name_summary_dir = 'summary_dir'
  58. _name_metric = 'metric'
  59. _name_hyper_parameters = 'hyper_parameters'
  60. _name_algorithm = 'algorithm'
  61. _name_train_dataset = 'train_dataset'
  62. _name_model = 'model'
  63. _name_valid_dataset = 'valid_dataset'
  64. _name_dataset_graph = 'dataset_graph'
  65. _name_dataset_mark = 'dataset_mark'
  66. _name_user_defined = 'user_defined'
  67. _name_model_lineage = 'model_lineage'
  68. def __init__(self, summary_dir, **kwargs):
  69. self._lineage_info = {
  70. self._name_summary_dir: summary_dir
  71. }
  72. self._init_lineage()
  73. self.parse_and_update_lineage(**kwargs)
  74. def _init_lineage(self):
  75. """Init lineage info."""
  76. # train
  77. self._lineage_info[self._name_model] = {}
  78. self._lineage_info[self._name_algorithm] = {}
  79. self._lineage_info[self._name_hyper_parameters] = {}
  80. self._lineage_info[self._name_train_dataset] = {}
  81. # eval
  82. self._lineage_info[self._name_metric] = {}
  83. self._lineage_info[self._name_valid_dataset] = {}
  84. # dataset graph
  85. self._lineage_info[self._name_dataset_graph] = {}
  86. # user defined
  87. self._lineage_info[self._name_user_defined] = {}
  88. def parse_and_update_lineage(self, **kwargs):
  89. """Parse and update lineage."""
  90. user_defined_info_list = kwargs.get('user_defined_info', [])
  91. train_lineage = kwargs.get('train_lineage')
  92. evaluation_lineage = kwargs.get('evaluation_lineage')
  93. dataset_graph = kwargs.get('dataset_graph')
  94. if not any([train_lineage, evaluation_lineage, dataset_graph]):
  95. raise LineageEventNotExistException()
  96. # If new train lineage, will clean the lineage saved before.
  97. if train_lineage is not None or dataset_graph is not None:
  98. self._init_lineage()
  99. self._parse_user_defined_info(user_defined_info_list)
  100. self._parse_train_lineage(train_lineage)
  101. self._parse_evaluation_lineage(evaluation_lineage)
  102. self._parse_dataset_graph(dataset_graph)
  103. self._filtration_result = self._organize_filtration_result()
  104. @property
  105. def summary_dir(self):
  106. """
  107. Get summary log dir.
  108. Returns:
  109. str, the summary log dir.
  110. """
  111. return self._lineage_info.get(self._name_summary_dir)
  112. @property
  113. def metric(self):
  114. """
  115. Get metric information.
  116. Returns:
  117. dict, the metric information.
  118. """
  119. return self._lineage_info.get(self._name_metric)
  120. @property
  121. def user_defined(self):
  122. """
  123. Get user defined information.
  124. Returns:
  125. dict, the user defined information.
  126. """
  127. return self._lineage_info.get(self._name_user_defined)
  128. @property
  129. def hyper_parameters(self):
  130. """
  131. Get hyperparameters.
  132. Returns:
  133. dict, the hyperparameters.
  134. """
  135. return self._lineage_info.get(self._name_hyper_parameters)
  136. @property
  137. def algorithm(self):
  138. """
  139. Get algorithm.
  140. Returns:
  141. dict, the algorithm.
  142. """
  143. return self._lineage_info.get(self._name_algorithm)
  144. @property
  145. def train_dataset(self):
  146. """
  147. Get train dataset information.
  148. Returns:
  149. dict, the train dataset information.
  150. """
  151. return self._lineage_info.get(self._name_train_dataset)
  152. @property
  153. def model(self):
  154. """
  155. Get model information.
  156. Returns:
  157. dict, the model information.
  158. """
  159. return self._lineage_info.get(self._name_model)
  160. @property
  161. def valid_dataset(self):
  162. """
  163. Get valid dataset information.
  164. Returns:
  165. dict, the valid dataset information.
  166. """
  167. return self._lineage_info.get(self._name_valid_dataset)
  168. @property
  169. def dataset_graph(self):
  170. """
  171. Get dataset_graph.
  172. Returns:
  173. dict, the dataset graph information.
  174. """
  175. return self._lineage_info.get(self._name_dataset_graph)
  176. @property
  177. def dataset_mark(self):
  178. """
  179. Get dataset_mark.
  180. Returns:
  181. dict, the dataset mark information.
  182. """
  183. return self._lineage_info.get(self._name_dataset_mark)
  184. @dataset_mark.setter
  185. def dataset_mark(self, dataset_mark):
  186. """
  187. Set dataset mark.
  188. Args:
  189. dataset_mark (int): Dataset mark.
  190. """
  191. self._lineage_info[self._name_dataset_mark] = dataset_mark
  192. # update dataset_mark into filtration result
  193. self._filtration_result[self._name_dataset_mark] = dataset_mark
  194. def get_summary_info(self, filter_keys: list):
  195. """
  196. Get the summary lineage information.
  197. Returns the content corresponding to the specified field in the filter
  198. key. The contents of the filter key include `metric`, `hyper_parameters`,
  199. `algorithm`, `train_dataset`, `valid_dataset` and `model`. You can
  200. specify multiple filter keys in the `filter_keys`
  201. Args:
  202. filter_keys (list): Filter keys.
  203. Returns:
  204. dict, the summary lineage information.
  205. """
  206. result = {
  207. self._name_summary_dir: self.summary_dir,
  208. }
  209. for key in filter_keys:
  210. result[key] = getattr(self, key)
  211. return result
  212. def to_dataset_lineage_dict(self):
  213. """
  214. Returns the dataset part lineage information.
  215. Returns:
  216. dict, the dataset lineage information.
  217. """
  218. dataset_lineage = {
  219. key: self._filtration_result.get(key)
  220. for key in [self._name_summary_dir, self._name_dataset_graph]
  221. }
  222. return dataset_lineage
  223. def to_model_lineage_dict(self):
  224. """
  225. Returns the model part lineage information.
  226. Returns:
  227. dict, the model lineage information.
  228. """
  229. filtration_result = dict(self._filtration_result)
  230. filtration_result.pop(self._name_dataset_graph)
  231. model_lineage = dict()
  232. model_lineage.update({self._name_summary_dir: filtration_result.pop(self._name_summary_dir)})
  233. model_lineage.update({self._name_model_lineage: filtration_result})
  234. return model_lineage
  235. def get_value_by_key(self, key):
  236. """
  237. Get the value based on the key in `FIELD_MAPPING` or
  238. the key prefixed with `metric/` or `user_defined/`.
  239. Args:
  240. key (str): The key in `FIELD_MAPPING`
  241. or prefixed with `metric/` or `user_defined/`.
  242. Returns:
  243. object, the value.
  244. """
  245. if key.startswith(('metric/', 'user_defined/')):
  246. key_name, sub_key = key.split('/', 1)
  247. sub_value_name = self._name_metric if key_name == 'metric' else self._name_user_defined
  248. sub_value = self._filtration_result.get(sub_value_name)
  249. if sub_value:
  250. return sub_value.get(sub_key)
  251. return self._filtration_result.get(key)
  252. def _organize_filtration_result(self):
  253. """
  254. Organize filtration result.
  255. Returns:
  256. dict, the filtration result.
  257. """
  258. result = {}
  259. for key, field in FIELD_MAPPING.items():
  260. if field.base_name is not None:
  261. base_attr = getattr(self, field.base_name)
  262. result[key] = base_attr.get(field.sub_name) \
  263. if field.sub_name else base_attr
  264. # add metric into filtration result
  265. result[self._name_metric] = self.metric
  266. result[self._name_user_defined] = self.user_defined
  267. # add dataset_graph into filtration result
  268. result[self._name_dataset_graph] = getattr(self, self._name_dataset_graph)
  269. return result
  270. def _parse_train_lineage(self, train_lineage):
  271. """
  272. Parse train lineage.
  273. Args:
  274. train_lineage (Event): Train lineage.
  275. """
  276. if train_lineage is None:
  277. return
  278. event_dict = MessageToDict(
  279. train_lineage, preserving_proto_field_name=True
  280. )
  281. train_dict = event_dict.get(self._name_train_lineage)
  282. if train_dict is None:
  283. raise LineageEventFieldNotExistException(
  284. self._name_train_lineage
  285. )
  286. # when MessageToDict is converted to dict, int64 type is converted
  287. # to string, so we convert it to an int in python
  288. if train_dict.get(self._name_model):
  289. model_size = train_dict.get(self._name_model).get('size')
  290. if model_size:
  291. train_dict[self._name_model]['size'] = int(model_size)
  292. self._lineage_info.update(**train_dict)
  293. def _parse_evaluation_lineage(self, evaluation_lineage):
  294. """
  295. Parse evaluation lineage.
  296. Args:
  297. evaluation_lineage (Event): Evaluation lineage.
  298. """
  299. if evaluation_lineage is None:
  300. return
  301. event_dict = MessageToDict(
  302. evaluation_lineage, preserving_proto_field_name=True
  303. )
  304. evaluation_dict = event_dict.get(self._name_evaluation_lineage)
  305. if evaluation_dict is None:
  306. raise LineageEventFieldNotExistException(
  307. self._name_evaluation_lineage
  308. )
  309. self._lineage_info.update(**evaluation_dict)
  310. metric = self._lineage_info.get(self._name_metric)
  311. self._lineage_info[self._name_metric] = json.loads(metric) if metric else {}
  312. def _parse_dataset_graph(self, dataset_graph):
  313. """
  314. Parse dataset graph.
  315. Args:
  316. dataset_graph (Event): Dataset graph.
  317. """
  318. if dataset_graph is not None:
  319. # convert message to dict
  320. event_dict = organize_graph(dataset_graph.dataset_graph)
  321. if event_dict is None:
  322. raise LineageEventFieldNotExistException(self._name_evaluation_lineage)
  323. self._lineage_info[self._name_dataset_graph] = event_dict if event_dict else {}
  324. def _parse_user_defined_info(self, user_defined_info_list):
  325. """
  326. Parse user defined info.
  327. Args:
  328. user_defined_info_list (list): user defined info list.
  329. """
  330. if not user_defined_info_list:
  331. return
  332. for user_defined_info in user_defined_info_list:
  333. self._lineage_info[self._name_user_defined].update(user_defined_info)