You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

query_model.py 11 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """This file is used to define lineage info model."""
  16. import json
  17. from collections import namedtuple
  18. from google.protobuf.json_format import MessageToDict
  19. from mindinsight.lineagemgr.common.exceptions.exceptions import \
  20. LineageEventFieldNotExistException, LineageEventNotExistException
  21. from mindinsight.lineagemgr.summary._summary_adapter import organize_graph
  22. Field = namedtuple('Field', ['base_name', 'sub_name'])
  23. FIELD_MAPPING = {
  24. "summary_dir": Field('summary_dir', None),
  25. "loss_function": Field("hyper_parameters", 'loss_function'),
  26. "train_dataset_path": Field('train_dataset', 'train_dataset_path'),
  27. "train_dataset_count": Field("train_dataset", 'train_dataset_size'),
  28. "test_dataset_path": Field('valid_dataset', 'valid_dataset_path'),
  29. "test_dataset_count": Field('valid_dataset', 'valid_dataset_size'),
  30. "network": Field('algorithm', 'network'),
  31. "optimizer": Field('hyper_parameters', 'optimizer'),
  32. "learning_rate": Field('hyper_parameters', 'learning_rate'),
  33. "epoch": Field('hyper_parameters', 'epoch'),
  34. "batch_size": Field('hyper_parameters', 'batch_size'),
  35. "loss": Field('algorithm', 'loss'),
  36. "model_size": Field('model', 'size'),
  37. "dataset_mark": Field('dataset_mark', None),
  38. }
  39. class LineageObj:
  40. """
  41. Lineage information class.
  42. An instance of the class hold lineage information for a training session.
  43. Args:
  44. summary_dir (str): Summary log dir.
  45. kwargs (dict): Params to init the instance.
  46. - train_lineage (Event): Train lineage object.
  47. - evaluation_lineage (Event): Evaluation lineage object.
  48. - dataset_graph (Event): Dataset graph object.
  49. Raises:
  50. LineageEventNotExistException: If train and evaluation event not exist.
  51. LineageEventFieldNotExistException: If the special event field not exist.
  52. """
  53. _name_train_lineage = 'train_lineage'
  54. _name_evaluation_lineage = 'evaluation_lineage'
  55. _name_summary_dir = 'summary_dir'
  56. _name_metric = 'metric'
  57. _name_hyper_parameters = 'hyper_parameters'
  58. _name_algorithm = 'algorithm'
  59. _name_train_dataset = 'train_dataset'
  60. _name_model = 'model'
  61. _name_valid_dataset = 'valid_dataset'
  62. _name_dataset_graph = 'dataset_graph'
  63. _name_dataset_mark = 'dataset_mark'
  64. def __init__(self, summary_dir, **kwargs):
  65. self._lineage_info = {
  66. self._name_summary_dir: summary_dir
  67. }
  68. train_lineage = kwargs.get('train_lineage')
  69. evaluation_lineage = kwargs.get('evaluation_lineage')
  70. dataset_graph = kwargs.get('dataset_graph')
  71. if not any([train_lineage, evaluation_lineage, dataset_graph]):
  72. raise LineageEventNotExistException()
  73. self._parse_train_lineage(train_lineage)
  74. self._parse_evaluation_lineage(evaluation_lineage)
  75. self._parse_dataset_graph(dataset_graph)
  76. self._filtration_result = self._organize_filtration_result()
  77. @property
  78. def summary_dir(self):
  79. """
  80. Get summary log dir.
  81. Returns:
  82. str, the summary log dir.
  83. """
  84. return self._lineage_info.get(self._name_summary_dir)
  85. @property
  86. def metric(self):
  87. """
  88. Get metric information.
  89. Returns:
  90. dict, the metric information.
  91. """
  92. return self._lineage_info.get(self._name_metric)
  93. @property
  94. def hyper_parameters(self):
  95. """
  96. Get hyperparameters.
  97. Returns:
  98. dict, the hyperparameters.
  99. """
  100. return self._lineage_info.get(self._name_hyper_parameters)
  101. @property
  102. def algorithm(self):
  103. """
  104. Get algorithm.
  105. Returns:
  106. dict, the algorithm.
  107. """
  108. return self._lineage_info.get(self._name_algorithm)
  109. @property
  110. def train_dataset(self):
  111. """
  112. Get train dataset information.
  113. Returns:
  114. dict, the train dataset information.
  115. """
  116. return self._lineage_info.get(self._name_train_dataset)
  117. @property
  118. def model(self):
  119. """
  120. Get model information.
  121. Returns:
  122. dict, the model information.
  123. """
  124. return self._lineage_info.get(self._name_model)
  125. @property
  126. def valid_dataset(self):
  127. """
  128. Get valid dataset information.
  129. Returns:
  130. dict, the valid dataset information.
  131. """
  132. return self._lineage_info.get(self._name_valid_dataset)
  133. @property
  134. def dataset_graph(self):
  135. """
  136. Get dataset_graph.
  137. Returns:
  138. dict, the dataset graph information.
  139. """
  140. return self._lineage_info.get(self._name_dataset_graph)
  141. @property
  142. def dataset_mark(self):
  143. """
  144. Get dataset_mark.
  145. Returns:
  146. dict, the dataset mark information.
  147. """
  148. return self._lineage_info.get(self._name_dataset_mark)
  149. @dataset_mark.setter
  150. def dataset_mark(self, dataset_mark):
  151. """
  152. Set dataset mark.
  153. Args:
  154. dataset_mark (int): Dataset mark.
  155. """
  156. self._lineage_info[self._name_dataset_mark] = dataset_mark
  157. # update dataset_mark into filtration result
  158. self._filtration_result[self._name_dataset_mark] = dataset_mark
  159. def get_summary_info(self, filter_keys: list):
  160. """
  161. Get the summary lineage information.
  162. Returns the content corresponding to the specified field in the filter
  163. key. The contents of the filter key include `metric`, `hyper_parameters`,
  164. `algorithm`, `train_dataset`, `valid_dataset` and `model`. You can
  165. specify multiple filter keys in the `filter_keys`
  166. Args:
  167. filter_keys (list): Filter keys.
  168. Returns:
  169. dict, the summary lineage information.
  170. """
  171. result = {
  172. self._name_summary_dir: self.summary_dir,
  173. }
  174. for key in filter_keys:
  175. result[key] = getattr(self, key)
  176. return result
  177. def to_filtration_dict(self):
  178. """
  179. Returns the lineage information required by filtering interface.
  180. Returns:
  181. dict, the lineage information required by filtering interface.
  182. """
  183. return self._filtration_result
  184. def to_dataset_lineage_dict(self):
  185. """
  186. Returns the dataset part lineage information.
  187. Returns:
  188. dict, the dataset lineage information.
  189. """
  190. dataset_lineage = {
  191. key: self._filtration_result.get(key)
  192. for key in [self._name_summary_dir, self._name_dataset_graph]
  193. }
  194. return dataset_lineage
  195. def get_value_by_key(self, key):
  196. """
  197. Get the value based on the key in `FIELD_MAPPING` or the key prefixed with `metric_`.
  198. Args:
  199. key (str): The key in `FIELD_MAPPING` or prefixed with `metric_`.
  200. Returns:
  201. object, the value.
  202. """
  203. if key.startswith('metric_'):
  204. metric_key = key.split('_', 1)[1]
  205. metric = self._filtration_result.get(self._name_metric)
  206. if metric:
  207. return metric.get(metric_key)
  208. return self._filtration_result.get(key)
  209. def _organize_filtration_result(self):
  210. """
  211. Organize filtration result.
  212. Returns:
  213. dict, the filtration result.
  214. """
  215. result = {}
  216. for key, field in FIELD_MAPPING.items():
  217. if field.base_name is not None:
  218. base_attr = getattr(self, field.base_name)
  219. result[key] = base_attr.get(field.sub_name) \
  220. if field.sub_name else base_attr
  221. # add metric into filtration result
  222. result[self._name_metric] = self.metric
  223. # add dataset_graph into filtration result
  224. result[self._name_dataset_graph] = getattr(self, self._name_dataset_graph)
  225. return result
  226. def _parse_train_lineage(self, train_lineage):
  227. """
  228. Parse train lineage.
  229. Args:
  230. train_lineage (Event): Train lineage.
  231. """
  232. if train_lineage is None:
  233. self._lineage_info[self._name_model] = {}
  234. self._lineage_info[self._name_algorithm] = {}
  235. self._lineage_info[self._name_hyper_parameters] = {}
  236. self._lineage_info[self._name_train_dataset] = {}
  237. return
  238. event_dict = MessageToDict(
  239. train_lineage, preserving_proto_field_name=True
  240. )
  241. train_dict = event_dict.get(self._name_train_lineage)
  242. if train_dict is None:
  243. raise LineageEventFieldNotExistException(
  244. self._name_train_lineage
  245. )
  246. # when MessageToDict is converted to dict, int64 type is converted
  247. # to string, so we convert it to an int in python
  248. if train_dict.get(self._name_model):
  249. model_size = train_dict.get(self._name_model).get('size')
  250. if model_size:
  251. train_dict[self._name_model]['size'] = int(model_size)
  252. self._lineage_info.update(**train_dict)
  253. def _parse_evaluation_lineage(self, evaluation_lineage):
  254. """
  255. Parse evaluation lineage.
  256. Args:
  257. evaluation_lineage (Event): Evaluation lineage.
  258. """
  259. if evaluation_lineage is None:
  260. self._lineage_info[self._name_metric] = {}
  261. self._lineage_info[self._name_valid_dataset] = {}
  262. return
  263. event_dict = MessageToDict(
  264. evaluation_lineage, preserving_proto_field_name=True
  265. )
  266. evaluation_dict = event_dict.get(self._name_evaluation_lineage)
  267. if evaluation_dict is None:
  268. raise LineageEventFieldNotExistException(
  269. self._name_evaluation_lineage
  270. )
  271. self._lineage_info.update(**evaluation_dict)
  272. metric = self._lineage_info.get(self._name_metric)
  273. self._lineage_info[self._name_metric] = json.loads(metric) if metric else {}
  274. def _parse_dataset_graph(self, dataset_graph):
  275. """
  276. Parse dataset graph.
  277. Args:
  278. dataset_graph (Event): Dataset graph.
  279. """
  280. if dataset_graph is None:
  281. self._lineage_info[self._name_dataset_graph] = {}
  282. else:
  283. # convert message to dict
  284. event_dict = organize_graph(dataset_graph.dataset_graph)
  285. if event_dict is None:
  286. raise LineageEventFieldNotExistException(self._name_evaluation_lineage)
  287. self._lineage_info[self._name_dataset_graph] = event_dict if event_dict else {}

MindInsight为MindSpore提供了简单易用的调优调试能力。在训练过程中，可以将标量、张量、图像、计算图、模型超参、训练耗时等数据记录到文件中，通过MindInsight可视化页面进行查看及分析。

Contributors (1)