You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_query_model.py 7.3 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Test the query_model module."""
  16. from unittest import TestCase
  17. from mindinsight.lineagemgr.common.exceptions.exceptions import (LineageEventFieldNotExistException,
  18. LineageEventNotExistException)
  19. from mindinsight.lineagemgr.querier.query_model import LineageObj
  20. from . import event_data
  21. from .test_querier import create_filtration_result, create_lineage_info
  22. from ....utils.tools import assert_equal_lineages
  23. class TestLineageObj(TestCase):
  24. """Test the class of `LineageObj`."""
  25. def setUp(self):
  26. """Initialization before test case execution."""
  27. lineage_info = create_lineage_info(
  28. event_data.EVENT_TRAIN_DICT_0,
  29. event_data.EVENT_EVAL_DICT_0,
  30. event_data.EVENT_DATASET_DICT_0
  31. )
  32. self.summary_dir = '/path/to/summary0'
  33. self.lineage_obj = LineageObj(
  34. self.summary_dir,
  35. train_lineage=lineage_info.train_lineage,
  36. evaluation_lineage=lineage_info.eval_lineage,
  37. dataset_graph=lineage_info.dataset_graph,
  38. )
  39. lineage_info = create_lineage_info(
  40. event_data.EVENT_TRAIN_DICT_0,
  41. None, None)
  42. self.lineage_obj_no_eval = LineageObj(
  43. self.summary_dir,
  44. train_lineage=lineage_info.train_lineage,
  45. evaluation_lineage=lineage_info.eval_lineage
  46. )
  47. def test_property(self):
  48. """Test the function of getting property."""
  49. self.assertEqual(self.summary_dir, self.lineage_obj.summary_dir)
  50. assert_equal_lineages(
  51. event_data.EVENT_TRAIN_DICT_0['train_lineage']['algorithm'],
  52. self.lineage_obj.algorithm,
  53. self.assertDictEqual
  54. )
  55. assert_equal_lineages(
  56. event_data.EVENT_TRAIN_DICT_0['train_lineage']['model'],
  57. self.lineage_obj.model,
  58. self.assertDictEqual
  59. )
  60. assert_equal_lineages(
  61. event_data.EVENT_TRAIN_DICT_0['train_lineage']['train_dataset'],
  62. self.lineage_obj.train_dataset,
  63. self.assertDictEqual
  64. )
  65. assert_equal_lineages(
  66. event_data.EVENT_TRAIN_DICT_0['train_lineage']['hyper_parameters'],
  67. self.lineage_obj.hyper_parameters,
  68. self.assertDictEqual
  69. )
  70. assert_equal_lineages(
  71. event_data.METRIC_0,
  72. self.lineage_obj.metric,
  73. self.assertDictEqual
  74. )
  75. assert_equal_lineages(
  76. event_data.EVENT_EVAL_DICT_0['evaluation_lineage']['valid_dataset'],
  77. self.lineage_obj.valid_dataset,
  78. self.assertDictEqual
  79. )
  80. def test_property_eval_not_exist(self):
  81. """Test the function of getting property with no evaluation event."""
  82. self.assertEqual(self.summary_dir, self.lineage_obj.summary_dir)
  83. assert_equal_lineages(
  84. event_data.EVENT_TRAIN_DICT_0['train_lineage']['algorithm'],
  85. self.lineage_obj_no_eval.algorithm,
  86. self.assertDictEqual
  87. )
  88. assert_equal_lineages(
  89. event_data.EVENT_TRAIN_DICT_0['train_lineage']['model'],
  90. self.lineage_obj_no_eval.model,
  91. self.assertDictEqual
  92. )
  93. assert_equal_lineages(
  94. event_data.EVENT_TRAIN_DICT_0['train_lineage']['train_dataset'],
  95. self.lineage_obj_no_eval.train_dataset,
  96. self.assertDictEqual
  97. )
  98. assert_equal_lineages(
  99. event_data.EVENT_TRAIN_DICT_0['train_lineage']['hyper_parameters'],
  100. self.lineage_obj_no_eval.hyper_parameters,
  101. self.assertDictEqual
  102. )
  103. assert_equal_lineages({}, self.lineage_obj_no_eval.metric, self.assertDictEqual)
  104. assert_equal_lineages({}, self.lineage_obj_no_eval.valid_dataset, self.assertDictEqual)
  105. def test_get_summary_info(self):
  106. """Test the function of get_summary_info."""
  107. filter_keys = ['algorithm', 'model']
  108. expected_result = {
  109. 'summary_dir': self.summary_dir,
  110. 'algorithm': event_data.EVENT_TRAIN_DICT_0['train_lineage']['algorithm'],
  111. 'model': event_data.EVENT_TRAIN_DICT_0['train_lineage']['model']
  112. }
  113. result = self.lineage_obj.get_summary_info(filter_keys)
  114. assert_equal_lineages(expected_result, result, self.assertDictEqual)
  115. def test_to_model_lineage_dict(self):
  116. """Test the function of to_model_lineage_dict."""
  117. expected_result = create_filtration_result(
  118. self.summary_dir,
  119. event_data.EVENT_TRAIN_DICT_0,
  120. event_data.EVENT_EVAL_DICT_0,
  121. event_data.METRIC_0,
  122. event_data.DATASET_DICT_0
  123. )
  124. expected_result['model_lineage']['dataset_mark'] = None
  125. expected_result.pop('dataset_graph')
  126. result = self.lineage_obj.to_model_lineage_dict()
  127. assert_equal_lineages(expected_result, result, self.assertDictEqual)
  128. def test_to_dataset_lineage_dict(self):
  129. """Test the function of to_dataset_lineage_dict."""
  130. expected_result = {
  131. "summary_dir": self.summary_dir,
  132. "dataset_graph": event_data.DATASET_DICT_0
  133. }
  134. result = self.lineage_obj.to_dataset_lineage_dict()
  135. self.assertDictEqual(expected_result, result)
  136. def test_get_value_by_key(self):
  137. """Test the function of get_value_by_key."""
  138. result = self.lineage_obj.get_value_by_key('model_size')
  139. self.assertEqual(
  140. event_data.EVENT_TRAIN_DICT_0['train_lineage']['model']['size'],
  141. result
  142. )
  143. def test_init_fail(self):
  144. """Test the function of init with exception."""
  145. with self.assertRaises(LineageEventNotExistException):
  146. LineageObj(self.summary_dir)
  147. lineage_info = create_lineage_info(
  148. event_data.EVENT_TRAIN_DICT_EXCEPTION, None, None
  149. )
  150. with self.assertRaises(LineageEventFieldNotExistException):
  151. self.lineage_obj = LineageObj(
  152. self.summary_dir,
  153. train_lineage=lineage_info.train_lineage,
  154. evaluation_lineage=lineage_info.eval_lineage
  155. )
  156. lineage_info = create_lineage_info(
  157. event_data.EVENT_TRAIN_DICT_0,
  158. event_data.EVENT_EVAL_DICT_EXCEPTION,
  159. event_data.EVENT_DATASET_DICT_0
  160. )
  161. with self.assertRaises(LineageEventFieldNotExistException):
  162. self.lineage_obj = LineageObj(
  163. self.summary_dir,
  164. train_lineage=lineage_info.train_lineage,
  165. evaluation_lineage=lineage_info.eval_lineage
  166. )