You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

query_model.py 13 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """This file is used to define lineage info model."""
  16. import json
  17. from collections import namedtuple
  18. from google.protobuf.json_format import MessageToDict
  19. from mindinsight.lineagemgr.common.exceptions.exceptions import \
  20. LineageEventFieldNotExistException, LineageEventNotExistException
  21. from mindinsight.lineagemgr.summary._summary_adapter import organize_graph
  22. Field = namedtuple('Field', ['base_name', 'sub_name'])
  23. FIELD_MAPPING = {
  24. "summary_dir": Field('summary_dir', None),
  25. "loss_function": Field("hyper_parameters", 'loss_function'),
  26. "train_dataset_path": Field('train_dataset', 'train_dataset_path'),
  27. "train_dataset_count": Field("train_dataset", 'train_dataset_size'),
  28. "test_dataset_path": Field('valid_dataset', 'valid_dataset_path'),
  29. "test_dataset_count": Field('valid_dataset', 'valid_dataset_size'),
  30. "network": Field('algorithm', 'network'),
  31. "optimizer": Field('hyper_parameters', 'optimizer'),
  32. "learning_rate": Field('hyper_parameters', 'learning_rate'),
  33. "epoch": Field('hyper_parameters', 'epoch'),
  34. "batch_size": Field('hyper_parameters', 'batch_size'),
  35. "device_num": Field('hyper_parameters', 'device_num'),
  36. "loss": Field('algorithm', 'loss'),
  37. "model_size": Field('model', 'size'),
  38. "dataset_mark": Field('dataset_mark', None)
  39. }
  40. class LineageObj:
  41. """
  42. Lineage information class.
  43. An instance of the class hold lineage information for a training session.
  44. Args:
  45. summary_dir (str): Summary log dir.
  46. kwargs (dict): Params to init the instance.
  47. - train_lineage (Event): Train lineage object.
  48. - evaluation_lineage (Event): Evaluation lineage object.
  49. - dataset_graph (Event): Dataset graph object.
  50. - user_defined_info (Event): User defined info object.
  51. Raises:
  52. LineageEventNotExistException: If train and evaluation event not exist.
  53. LineageEventFieldNotExistException: If the special event field not exist.
  54. """
  55. _name_train_lineage = 'train_lineage'
  56. _name_evaluation_lineage = 'evaluation_lineage'
  57. _name_summary_dir = 'summary_dir'
  58. _name_metric = 'metric'
  59. _name_hyper_parameters = 'hyper_parameters'
  60. _name_algorithm = 'algorithm'
  61. _name_train_dataset = 'train_dataset'
  62. _name_model = 'model'
  63. _name_valid_dataset = 'valid_dataset'
  64. _name_dataset_graph = 'dataset_graph'
  65. _name_dataset_mark = 'dataset_mark'
  66. _name_user_defined = 'user_defined'
  67. _name_model_lineage = 'model_lineage'
  68. def __init__(self, summary_dir, **kwargs):
  69. self._lineage_info = {
  70. self._name_summary_dir: summary_dir
  71. }
  72. user_defined_info_list = kwargs.get('user_defined_info', [])
  73. train_lineage = kwargs.get('train_lineage')
  74. evaluation_lineage = kwargs.get('evaluation_lineage')
  75. dataset_graph = kwargs.get('dataset_graph')
  76. if not any([train_lineage, evaluation_lineage, dataset_graph]):
  77. raise LineageEventNotExistException()
  78. self._parse_user_defined_info(user_defined_info_list)
  79. self._parse_train_lineage(train_lineage)
  80. self._parse_evaluation_lineage(evaluation_lineage)
  81. self._parse_dataset_graph(dataset_graph)
  82. self._filtration_result = self._organize_filtration_result()
  83. @property
  84. def summary_dir(self):
  85. """
  86. Get summary log dir.
  87. Returns:
  88. str, the summary log dir.
  89. """
  90. return self._lineage_info.get(self._name_summary_dir)
  91. @property
  92. def metric(self):
  93. """
  94. Get metric information.
  95. Returns:
  96. dict, the metric information.
  97. """
  98. return self._lineage_info.get(self._name_metric)
  99. @property
  100. def user_defined(self):
  101. """
  102. Get user defined information.
  103. Returns:
  104. dict, the user defined information.
  105. """
  106. return self._lineage_info.get(self._name_user_defined)
  107. @property
  108. def hyper_parameters(self):
  109. """
  110. Get hyperparameters.
  111. Returns:
  112. dict, the hyperparameters.
  113. """
  114. return self._lineage_info.get(self._name_hyper_parameters)
  115. @property
  116. def algorithm(self):
  117. """
  118. Get algorithm.
  119. Returns:
  120. dict, the algorithm.
  121. """
  122. return self._lineage_info.get(self._name_algorithm)
  123. @property
  124. def train_dataset(self):
  125. """
  126. Get train dataset information.
  127. Returns:
  128. dict, the train dataset information.
  129. """
  130. return self._lineage_info.get(self._name_train_dataset)
  131. @property
  132. def model(self):
  133. """
  134. Get model information.
  135. Returns:
  136. dict, the model information.
  137. """
  138. return self._lineage_info.get(self._name_model)
  139. @property
  140. def valid_dataset(self):
  141. """
  142. Get valid dataset information.
  143. Returns:
  144. dict, the valid dataset information.
  145. """
  146. return self._lineage_info.get(self._name_valid_dataset)
  147. @property
  148. def dataset_graph(self):
  149. """
  150. Get dataset_graph.
  151. Returns:
  152. dict, the dataset graph information.
  153. """
  154. return self._lineage_info.get(self._name_dataset_graph)
  155. @property
  156. def dataset_mark(self):
  157. """
  158. Get dataset_mark.
  159. Returns:
  160. dict, the dataset mark information.
  161. """
  162. return self._lineage_info.get(self._name_dataset_mark)
  163. @dataset_mark.setter
  164. def dataset_mark(self, dataset_mark):
  165. """
  166. Set dataset mark.
  167. Args:
  168. dataset_mark (int): Dataset mark.
  169. """
  170. self._lineage_info[self._name_dataset_mark] = dataset_mark
  171. # update dataset_mark into filtration result
  172. self._filtration_result[self._name_dataset_mark] = dataset_mark
  173. def get_summary_info(self, filter_keys: list):
  174. """
  175. Get the summary lineage information.
  176. Returns the content corresponding to the specified field in the filter
  177. key. The contents of the filter key include `metric`, `hyper_parameters`,
  178. `algorithm`, `train_dataset`, `valid_dataset` and `model`. You can
  179. specify multiple filter keys in the `filter_keys`
  180. Args:
  181. filter_keys (list): Filter keys.
  182. Returns:
  183. dict, the summary lineage information.
  184. """
  185. result = {
  186. self._name_summary_dir: self.summary_dir,
  187. }
  188. for key in filter_keys:
  189. result[key] = getattr(self, key)
  190. return result
  191. def to_dataset_lineage_dict(self):
  192. """
  193. Returns the dataset part lineage information.
  194. Returns:
  195. dict, the dataset lineage information.
  196. """
  197. dataset_lineage = {
  198. key: self._filtration_result.get(key)
  199. for key in [self._name_summary_dir, self._name_dataset_graph]
  200. }
  201. return dataset_lineage
  202. def to_model_lineage_dict(self):
  203. """
  204. Returns the model part lineage information.
  205. Returns:
  206. dict, the model lineage information.
  207. """
  208. filtration_result = dict(self._filtration_result)
  209. filtration_result.pop(self._name_dataset_graph)
  210. model_lineage = dict()
  211. model_lineage.update({self._name_summary_dir: filtration_result.pop(self._name_summary_dir)})
  212. model_lineage.update({self._name_model_lineage: filtration_result})
  213. return model_lineage
  214. def get_value_by_key(self, key):
  215. """
  216. Get the value based on the key in `FIELD_MAPPING` or
  217. the key prefixed with `metric/` or `user_defined/`.
  218. Args:
  219. key (str): The key in `FIELD_MAPPING`
  220. or prefixed with `metric/` or `user_defined/`.
  221. Returns:
  222. object, the value.
  223. """
  224. if key.startswith(('metric/', 'user_defined/')):
  225. key_name, sub_key = key.split('/', 1)
  226. sub_value_name = self._name_metric if key_name == 'metric' else self._name_user_defined
  227. sub_value = self._filtration_result.get(sub_value_name)
  228. if sub_value:
  229. return sub_value.get(sub_key)
  230. return self._filtration_result.get(key)
  231. def _organize_filtration_result(self):
  232. """
  233. Organize filtration result.
  234. Returns:
  235. dict, the filtration result.
  236. """
  237. result = {}
  238. for key, field in FIELD_MAPPING.items():
  239. if field.base_name is not None:
  240. base_attr = getattr(self, field.base_name)
  241. result[key] = base_attr.get(field.sub_name) \
  242. if field.sub_name else base_attr
  243. # add metric into filtration result
  244. result[self._name_metric] = self.metric
  245. result[self._name_user_defined] = self.user_defined
  246. # add dataset_graph into filtration result
  247. result[self._name_dataset_graph] = getattr(self, self._name_dataset_graph)
  248. return result
  249. def _parse_train_lineage(self, train_lineage):
  250. """
  251. Parse train lineage.
  252. Args:
  253. train_lineage (Event): Train lineage.
  254. """
  255. if train_lineage is None:
  256. self._lineage_info[self._name_model] = {}
  257. self._lineage_info[self._name_algorithm] = {}
  258. self._lineage_info[self._name_hyper_parameters] = {}
  259. self._lineage_info[self._name_train_dataset] = {}
  260. return
  261. event_dict = MessageToDict(
  262. train_lineage, preserving_proto_field_name=True
  263. )
  264. train_dict = event_dict.get(self._name_train_lineage)
  265. if train_dict is None:
  266. raise LineageEventFieldNotExistException(
  267. self._name_train_lineage
  268. )
  269. # when MessageToDict is converted to dict, int64 type is converted
  270. # to string, so we convert it to an int in python
  271. if train_dict.get(self._name_model):
  272. model_size = train_dict.get(self._name_model).get('size')
  273. if model_size:
  274. train_dict[self._name_model]['size'] = int(model_size)
  275. self._lineage_info.update(**train_dict)
  276. def _parse_evaluation_lineage(self, evaluation_lineage):
  277. """
  278. Parse evaluation lineage.
  279. Args:
  280. evaluation_lineage (Event): Evaluation lineage.
  281. """
  282. if evaluation_lineage is None:
  283. self._lineage_info[self._name_metric] = {}
  284. self._lineage_info[self._name_valid_dataset] = {}
  285. return
  286. event_dict = MessageToDict(
  287. evaluation_lineage, preserving_proto_field_name=True
  288. )
  289. evaluation_dict = event_dict.get(self._name_evaluation_lineage)
  290. if evaluation_dict is None:
  291. raise LineageEventFieldNotExistException(
  292. self._name_evaluation_lineage
  293. )
  294. self._lineage_info.update(**evaluation_dict)
  295. metric = self._lineage_info.get(self._name_metric)
  296. self._lineage_info[self._name_metric] = json.loads(metric) if metric else {}
  297. def _parse_dataset_graph(self, dataset_graph):
  298. """
  299. Parse dataset graph.
  300. Args:
  301. dataset_graph (Event): Dataset graph.
  302. """
  303. if dataset_graph is None:
  304. self._lineage_info[self._name_dataset_graph] = {}
  305. else:
  306. # convert message to dict
  307. event_dict = organize_graph(dataset_graph.dataset_graph)
  308. if event_dict is None:
  309. raise LineageEventFieldNotExistException(self._name_evaluation_lineage)
  310. self._lineage_info[self._name_dataset_graph] = event_dict if event_dict else {}
  311. def _parse_user_defined_info(self, user_defined_info_list):
  312. """
  313. Parse user defined info.
  314. Args:
  315. user_defined_info_list (list): user defined info list.
  316. """
  317. user_defined_infos = dict()
  318. for user_defined_info in user_defined_info_list:
  319. user_defined_infos.update(user_defined_info)
  320. self._lineage_info[self._name_user_defined] = user_defined_infos

MindInsight为MindSpore提供了简单易用的调优调试能力。在训练过程中,可以将标量、张量、图像、计算图、模型超参、训练耗时等数据记录到文件中,通过MindInsight可视化页面进行查看及分析。