| @@ -22,7 +22,7 @@ EVENT_TRAIN_DICT_0 = { | |||||
| 'train_lineage': { | 'train_lineage': { | ||||
| 'hyper_parameters': { | 'hyper_parameters': { | ||||
| 'optimizer': 'ApplyMomentum0', | 'optimizer': 'ApplyMomentum0', | ||||
| 'learning_rate': 0.10000000149011612, | |||||
| 'learning_rate': 0.11, | |||||
| 'loss_function': '', | 'loss_function': '', | ||||
| 'epoch': 1, | 'epoch': 1, | ||||
| 'parallel_mode': 'stand_alone0', | 'parallel_mode': 'stand_alone0', | ||||
| @@ -31,7 +31,7 @@ EVENT_TRAIN_DICT_0 = { | |||||
| }, | }, | ||||
| 'algorithm': { | 'algorithm': { | ||||
| 'network': 'TrainOneStepCell0', | 'network': 'TrainOneStepCell0', | ||||
| 'loss': 2.3025848865509033 | |||||
| 'loss': 2.3025841 | |||||
| }, | }, | ||||
| 'train_dataset': { | 'train_dataset': { | ||||
| 'train_dataset_path': '', | 'train_dataset_path': '', | ||||
| @@ -49,7 +49,7 @@ EVENT_TRAIN_DICT_1 = { | |||||
| 'train_lineage': { | 'train_lineage': { | ||||
| 'hyper_parameters': { | 'hyper_parameters': { | ||||
| 'optimizer': 'ApplyMomentum1', | 'optimizer': 'ApplyMomentum1', | ||||
| 'learning_rate': 0.20000000298023224, | |||||
| 'learning_rate': 0.2100001, | |||||
| 'loss_function': 'loss_function1', | 'loss_function': 'loss_function1', | ||||
| 'epoch': 1, | 'epoch': 1, | ||||
| 'parallel_mode': 'stand_alone1', | 'parallel_mode': 'stand_alone1', | ||||
| @@ -58,7 +58,7 @@ EVENT_TRAIN_DICT_1 = { | |||||
| }, | }, | ||||
| 'algorithm': { | 'algorithm': { | ||||
| 'network': 'TrainOneStepCell1', | 'network': 'TrainOneStepCell1', | ||||
| 'loss': 2.4025847911834717 | |||||
| 'loss': 2.4025841 | |||||
| }, | }, | ||||
| 'train_dataset': { | 'train_dataset': { | ||||
| 'train_dataset_path': '/path/to/train_dataset1', | 'train_dataset_path': '/path/to/train_dataset1', | ||||
| @@ -76,7 +76,7 @@ EVENT_TRAIN_DICT_2 = { | |||||
| 'train_lineage': { | 'train_lineage': { | ||||
| 'hyper_parameters': { | 'hyper_parameters': { | ||||
| 'optimizer': 'ApplyMomentum2', | 'optimizer': 'ApplyMomentum2', | ||||
| 'learning_rate': 0.30000001192092896, | |||||
| 'learning_rate': 0.3100001, | |||||
| 'loss_function': 'loss_function2', | 'loss_function': 'loss_function2', | ||||
| 'epoch': 2, | 'epoch': 2, | ||||
| 'parallel_mode': 'stand_alone2', | 'parallel_mode': 'stand_alone2', | ||||
| @@ -85,7 +85,7 @@ EVENT_TRAIN_DICT_2 = { | |||||
| }, | }, | ||||
| 'algorithm': { | 'algorithm': { | ||||
| 'network': 'TrainOneStepCell2', | 'network': 'TrainOneStepCell2', | ||||
| 'loss': 2.502584934234619 | |||||
| 'loss': 2.5025841 | |||||
| }, | }, | ||||
| 'train_dataset': { | 'train_dataset': { | ||||
| 'train_dataset_path': '/path/to/train_dataset2', | 'train_dataset_path': '/path/to/train_dataset2', | ||||
| @@ -103,7 +103,7 @@ EVENT_TRAIN_DICT_3 = { | |||||
| 'train_lineage': { | 'train_lineage': { | ||||
| 'hyper_parameters': { | 'hyper_parameters': { | ||||
| 'optimizer': 'ApplyMomentum3', | 'optimizer': 'ApplyMomentum3', | ||||
| 'learning_rate': 0.4000000059604645, | |||||
| 'learning_rate': 0.4, | |||||
| 'loss_function': 'loss_function3', | 'loss_function': 'loss_function3', | ||||
| 'epoch': 2, | 'epoch': 2, | ||||
| 'parallel_mode': 'stand_alone3', | 'parallel_mode': 'stand_alone3', | ||||
| @@ -112,7 +112,7 @@ EVENT_TRAIN_DICT_3 = { | |||||
| }, | }, | ||||
| 'algorithm': { | 'algorithm': { | ||||
| 'network': 'TrainOneStepCell3', | 'network': 'TrainOneStepCell3', | ||||
| 'loss': 2.6025848388671875 | |||||
| 'loss': 2.6025841 | |||||
| }, | }, | ||||
| 'train_dataset': { | 'train_dataset': { | ||||
| 'train_dataset_path': '/path/to/train_dataset3', | 'train_dataset_path': '/path/to/train_dataset3', | ||||
| @@ -139,7 +139,7 @@ EVENT_TRAIN_DICT_4 = { | |||||
| }, | }, | ||||
| 'algorithm': { | 'algorithm': { | ||||
| 'network': 'TrainOneStepCell4', | 'network': 'TrainOneStepCell4', | ||||
| 'loss': 2.702584981918335 | |||||
| 'loss': 2.7025841 | |||||
| }, | }, | ||||
| 'train_dataset': { | 'train_dataset': { | ||||
| 'train_dataset_path': '/path/to/train_dataset4', | 'train_dataset_path': '/path/to/train_dataset4', | ||||
| @@ -166,7 +166,7 @@ EVENT_TRAIN_DICT_5 = { | |||||
| }, | }, | ||||
| 'algorithm': { | 'algorithm': { | ||||
| 'network': 'TrainOneStepCell5', | 'network': 'TrainOneStepCell5', | ||||
| 'loss': 2.702584981918335 | |||||
| 'loss': 2.7025841 | |||||
| }, | }, | ||||
| 'train_dataset': { | 'train_dataset': { | ||||
| 'train_dataset_size': 35 | 'train_dataset_size': 35 | ||||
| @@ -211,33 +211,33 @@ CUSTOMIZED_2 = { | |||||
| } | } | ||||
| METRIC_1 = { | METRIC_1 = { | ||||
| 'accuracy': 1.0000002, | |||||
| 'accuracy': 1.2000002, | |||||
| 'mae': 2.00000002, | 'mae': 2.00000002, | ||||
| 'mse': 3.00000002 | 'mse': 3.00000002 | ||||
| } | } | ||||
| METRIC_2 = { | METRIC_2 = { | ||||
| 'accuracy': 1.0000003, | |||||
| 'mae': 2.00000003, | |||||
| 'mse': 3.00000003 | |||||
| 'accuracy': 1.3000003, | |||||
| 'mae': 2.30000003, | |||||
| 'mse': 3.30000003 | |||||
| } | } | ||||
| METRIC_3 = { | METRIC_3 = { | ||||
| 'accuracy': 1.0000004, | |||||
| 'mae': 2.00000004, | |||||
| 'mse': 3.00000004 | |||||
| 'accuracy': 1.4000004, | |||||
| 'mae': 2.40000004, | |||||
| 'mse': 3.40000004 | |||||
| } | } | ||||
| METRIC_4 = { | METRIC_4 = { | ||||
| 'accuracy': 1.0000005, | |||||
| 'mae': 2.00000005, | |||||
| 'mse': 3.00000005 | |||||
| 'accuracy': 1.5000005, | |||||
| 'mae': 2.50000005, | |||||
| 'mse': 3.50000005 | |||||
| } | } | ||||
| METRIC_5 = { | METRIC_5 = { | ||||
| 'accuracy': 1.0000006, | |||||
| 'mae': 2.00000006, | |||||
| 'mse': 3.00000006 | |||||
| 'accuracy': 1.7000006, | |||||
| 'mae': 2.60000006, | |||||
| 'mse': 3.60000006 | |||||
| } | } | ||||
| EVENT_EVAL_DICT_0 = { | EVENT_EVAL_DICT_0 = { | ||||
| @@ -27,6 +27,7 @@ from mindinsight.lineagemgr.querier.querier import Querier | |||||
| from mindinsight.lineagemgr.summary.lineage_summary_analyzer import LineageInfo | from mindinsight.lineagemgr.summary.lineage_summary_analyzer import LineageInfo | ||||
| from . import event_data | from . import event_data | ||||
| from ....utils.tools import deal_float_for_dict | |||||
| def create_lineage_info(train_event_dict, eval_event_dict, dataset_event_dict): | def create_lineage_info(train_event_dict, eval_event_dict, dataset_event_dict): | ||||
| @@ -266,7 +267,6 @@ class TestQuerier(TestCase): | |||||
| mock_file_handler = MagicMock() | mock_file_handler = MagicMock() | ||||
| mock_file_handler.size = 1 | mock_file_handler.size = 1 | ||||
| args[2].return_value = [{'relative_path': './', 'update_time': 1}] | args[2].return_value = [{'relative_path': './', 'update_time': 1}] | ||||
| single_summary_path = '/path/to/summary0' | single_summary_path = '/path/to/summary0' | ||||
| lineage_objects = LineageOrganizer(summary_base_dir=single_summary_path).super_lineage_objs | lineage_objects = LineageOrganizer(summary_base_dir=single_summary_path).super_lineage_objs | ||||
| @@ -282,17 +282,31 @@ class TestQuerier(TestCase): | |||||
| lineage_objects = LineageOrganizer(summary_base_dir=summary_base_dir).super_lineage_objs | lineage_objects = LineageOrganizer(summary_base_dir=summary_base_dir).super_lineage_objs | ||||
| self.multi_querier = Querier(lineage_objects) | self.multi_querier = Querier(lineage_objects) | ||||
| def _deal_float_for_list(self, list1, list2): | |||||
| index = 0 | |||||
| for _ in list1: | |||||
| deal_float_for_dict(list1[index], list2[index]) | |||||
| index += 1 | |||||
| def _assert_list_equal(self, list1, list2): | |||||
| self._deal_float_for_list(list1, list2) | |||||
| self.assertListEqual(list1, list2) | |||||
| def _assert_lineages_equal(self, lineages1, lineages2): | |||||
| self._deal_float_for_list(lineages1['object'], lineages2['object']) | |||||
| self.assertDictEqual(lineages1, lineages2) | |||||
| def test_get_summary_lineage_success_1(self): | def test_get_summary_lineage_success_1(self): | ||||
| """Test the success of get_summary_lineage.""" | """Test the success of get_summary_lineage.""" | ||||
| expected_result = [LINEAGE_INFO_0] | expected_result = [LINEAGE_INFO_0] | ||||
| result = self.single_querier.get_summary_lineage() | result = self.single_querier.get_summary_lineage() | ||||
| self.assertListEqual(expected_result, result) | |||||
| self._assert_list_equal(expected_result, result) | |||||
| def test_get_summary_lineage_success_2(self): | def test_get_summary_lineage_success_2(self): | ||||
| """Test the success of get_summary_lineage.""" | """Test the success of get_summary_lineage.""" | ||||
| expected_result = [LINEAGE_INFO_0] | expected_result = [LINEAGE_INFO_0] | ||||
| result = self.single_querier.get_summary_lineage() | result = self.single_querier.get_summary_lineage() | ||||
| self.assertListEqual(expected_result, result) | |||||
| self._assert_list_equal(expected_result, result) | |||||
| def test_get_summary_lineage_success_3(self): | def test_get_summary_lineage_success_3(self): | ||||
| """Test the success of get_summary_lineage.""" | """Test the success of get_summary_lineage.""" | ||||
| @@ -306,7 +320,7 @@ class TestQuerier(TestCase): | |||||
| result = self.single_querier.get_summary_lineage( | result = self.single_querier.get_summary_lineage( | ||||
| filter_keys=['model', 'algorithm'] | filter_keys=['model', 'algorithm'] | ||||
| ) | ) | ||||
| self.assertListEqual(expected_result, result) | |||||
| self._assert_list_equal(expected_result, result) | |||||
| def test_get_summary_lineage_success_4(self): | def test_get_summary_lineage_success_4(self): | ||||
| """Test the success of get_summary_lineage.""" | """Test the success of get_summary_lineage.""" | ||||
| @@ -353,7 +367,7 @@ class TestQuerier(TestCase): | |||||
| } | } | ||||
| ] | ] | ||||
| result = self.multi_querier.get_summary_lineage() | result = self.multi_querier.get_summary_lineage() | ||||
| self.assertListEqual(expected_result, result) | |||||
| self._assert_list_equal(expected_result, result) | |||||
| def test_get_summary_lineage_success_5(self): | def test_get_summary_lineage_success_5(self): | ||||
| """Test the success of get_summary_lineage.""" | """Test the success of get_summary_lineage.""" | ||||
| @@ -361,7 +375,7 @@ class TestQuerier(TestCase): | |||||
| result = self.multi_querier.get_summary_lineage( | result = self.multi_querier.get_summary_lineage( | ||||
| summary_dir='/path/to/summary1' | summary_dir='/path/to/summary1' | ||||
| ) | ) | ||||
| self.assertListEqual(expected_result, result) | |||||
| self._assert_list_equal(expected_result, result) | |||||
| def test_get_summary_lineage_success_6(self): | def test_get_summary_lineage_success_6(self): | ||||
| """Test the success of get_summary_lineage.""" | """Test the success of get_summary_lineage.""" | ||||
| @@ -380,7 +394,7 @@ class TestQuerier(TestCase): | |||||
| result = self.multi_querier.get_summary_lineage( | result = self.multi_querier.get_summary_lineage( | ||||
| summary_dir='/path/to/summary0', filter_keys=filter_keys | summary_dir='/path/to/summary0', filter_keys=filter_keys | ||||
| ) | ) | ||||
| self.assertListEqual(expected_result, result) | |||||
| self._assert_list_equal(expected_result, result) | |||||
| def test_get_summary_lineage_fail(self): | def test_get_summary_lineage_fail(self): | ||||
| """Test the function of get_summary_lineage with exception.""" | """Test the function of get_summary_lineage with exception.""" | ||||
| @@ -423,7 +437,7 @@ class TestQuerier(TestCase): | |||||
| 'count': 2, | 'count': 2, | ||||
| } | } | ||||
| result = self.multi_querier.filter_summary_lineage(condition=condition) | result = self.multi_querier.filter_summary_lineage(condition=condition) | ||||
| self.assertDictEqual(expected_result, result) | |||||
| self._assert_lineages_equal(expected_result, result) | |||||
| def test_filter_summary_lineage_success_2(self): | def test_filter_summary_lineage_success_2(self): | ||||
| """Test the success of filter_summary_lineage.""" | """Test the success of filter_summary_lineage.""" | ||||
| @@ -448,7 +462,7 @@ class TestQuerier(TestCase): | |||||
| 'count': 2, | 'count': 2, | ||||
| } | } | ||||
| result = self.multi_querier.filter_summary_lineage(condition=condition) | result = self.multi_querier.filter_summary_lineage(condition=condition) | ||||
| self.assertDictEqual(expected_result, result) | |||||
| self._assert_lineages_equal(expected_result, result) | |||||
| def test_filter_summary_lineage_success_3(self): | def test_filter_summary_lineage_success_3(self): | ||||
| """Test the success of filter_summary_lineage.""" | """Test the success of filter_summary_lineage.""" | ||||
| @@ -465,7 +479,7 @@ class TestQuerier(TestCase): | |||||
| 'count': 7, | 'count': 7, | ||||
| } | } | ||||
| result = self.multi_querier.filter_summary_lineage(condition=condition) | result = self.multi_querier.filter_summary_lineage(condition=condition) | ||||
| self.assertDictEqual(expected_result, result) | |||||
| self._assert_lineages_equal(expected_result, result) | |||||
| def test_filter_summary_lineage_success_4(self): | def test_filter_summary_lineage_success_4(self): | ||||
| """Test the success of filter_summary_lineage.""" | """Test the success of filter_summary_lineage.""" | ||||
| @@ -483,7 +497,7 @@ class TestQuerier(TestCase): | |||||
| 'count': 7, | 'count': 7, | ||||
| } | } | ||||
| result = self.multi_querier.filter_summary_lineage() | result = self.multi_querier.filter_summary_lineage() | ||||
| self.assertDictEqual(expected_result, result) | |||||
| self._assert_lineages_equal(expected_result, result) | |||||
| def test_filter_summary_lineage_success_5(self): | def test_filter_summary_lineage_success_5(self): | ||||
| """Test the success of filter_summary_lineage.""" | """Test the success of filter_summary_lineage.""" | ||||
| @@ -498,7 +512,7 @@ class TestQuerier(TestCase): | |||||
| 'count': 1, | 'count': 1, | ||||
| } | } | ||||
| result = self.multi_querier.filter_summary_lineage(condition=condition) | result = self.multi_querier.filter_summary_lineage(condition=condition) | ||||
| self.assertDictEqual(expected_result, result) | |||||
| self._assert_lineages_equal(expected_result, result) | |||||
| def test_filter_summary_lineage_success_6(self): | def test_filter_summary_lineage_success_6(self): | ||||
| """Test the success of filter_summary_lineage.""" | """Test the success of filter_summary_lineage.""" | ||||
| @@ -520,7 +534,7 @@ class TestQuerier(TestCase): | |||||
| 'count': 7, | 'count': 7, | ||||
| } | } | ||||
| result = self.multi_querier.filter_summary_lineage(condition=condition) | result = self.multi_querier.filter_summary_lineage(condition=condition) | ||||
| self.assertDictEqual(expected_result, result) | |||||
| self._assert_lineages_equal(expected_result, result) | |||||
| def test_filter_summary_lineage_success_7(self): | def test_filter_summary_lineage_success_7(self): | ||||
| """Test the success of filter_summary_lineage.""" | """Test the success of filter_summary_lineage.""" | ||||
| @@ -542,14 +556,14 @@ class TestQuerier(TestCase): | |||||
| 'count': 7, | 'count': 7, | ||||
| } | } | ||||
| result = self.multi_querier.filter_summary_lineage(condition=condition) | result = self.multi_querier.filter_summary_lineage(condition=condition) | ||||
| self.assertDictEqual(expected_result, result) | |||||
| self._assert_lineages_equal(expected_result, result) | |||||
| def test_filter_summary_lineage_success_8(self): | def test_filter_summary_lineage_success_8(self): | ||||
| """Test the success of filter_summary_lineage.""" | """Test the success of filter_summary_lineage.""" | ||||
| condition = { | condition = { | ||||
| 'metric/accuracy': { | 'metric/accuracy': { | ||||
| 'lt': 1.0000006, | |||||
| 'gt': 1.0000004 | |||||
| 'lt': 1.6000006, | |||||
| 'gt': 1.4000004 | |||||
| } | } | ||||
| } | } | ||||
| expected_result = { | expected_result = { | ||||
| @@ -558,7 +572,7 @@ class TestQuerier(TestCase): | |||||
| 'count': 1, | 'count': 1, | ||||
| } | } | ||||
| result = self.multi_querier.filter_summary_lineage(condition=condition) | result = self.multi_querier.filter_summary_lineage(condition=condition) | ||||
| self.assertDictEqual(expected_result, result) | |||||
| self._assert_lineages_equal(expected_result, result) | |||||
| def test_filter_summary_lineage_success_9(self): | def test_filter_summary_lineage_success_9(self): | ||||
| """Test the success of filter_summary_lineage.""" | """Test the success of filter_summary_lineage.""" | ||||
| @@ -572,14 +586,14 @@ class TestQuerier(TestCase): | |||||
| 'count': 7, | 'count': 7, | ||||
| } | } | ||||
| result = self.multi_querier.filter_summary_lineage(condition=condition) | result = self.multi_querier.filter_summary_lineage(condition=condition) | ||||
| self.assertDictEqual(expected_result, result) | |||||
| self._assert_lineages_equal(expected_result, result) | |||||
| def test_filter_summary_lineage_fail(self): | def test_filter_summary_lineage_fail(self): | ||||
| """Test the function of filter_summary_lineage with exception.""" | """Test the function of filter_summary_lineage with exception.""" | ||||
| condition = { | condition = { | ||||
| 'xxx': { | 'xxx': { | ||||
| 'lt': 1.0000006, | |||||
| 'gt': 1.0000004 | |||||
| 'lt': 1.6000006, | |||||
| 'gt': 1.4000004 | |||||
| } | } | ||||
| } | } | ||||
| self.assertRaises( | self.assertRaises( | ||||
| @@ -21,6 +21,7 @@ from mindinsight.lineagemgr.querier.query_model import LineageObj | |||||
| from . import event_data | from . import event_data | ||||
| from .test_querier import create_filtration_result, create_lineage_info | from .test_querier import create_filtration_result, create_lineage_info | ||||
| from ....utils.tools import deal_float_for_dict | |||||
| class TestLineageObj(TestCase): | class TestLineageObj(TestCase): | ||||
| @@ -50,27 +51,31 @@ class TestLineageObj(TestCase): | |||||
| evaluation_lineage=lineage_info.eval_lineage | evaluation_lineage=lineage_info.eval_lineage | ||||
| ) | ) | ||||
| def _assert_dict_equal(self, dict1, dict2): | |||||
| deal_float_for_dict(dict1, dict2) | |||||
| self.assertDictEqual(dict1, dict2) | |||||
| def test_property(self): | def test_property(self): | ||||
| """Test the function of getting property.""" | """Test the function of getting property.""" | ||||
| self.assertEqual(self.summary_dir, self.lineage_obj.summary_dir) | self.assertEqual(self.summary_dir, self.lineage_obj.summary_dir) | ||||
| self.assertDictEqual( | |||||
| self._assert_dict_equal( | |||||
| event_data.EVENT_TRAIN_DICT_0['train_lineage']['algorithm'], | event_data.EVENT_TRAIN_DICT_0['train_lineage']['algorithm'], | ||||
| self.lineage_obj.algorithm | self.lineage_obj.algorithm | ||||
| ) | ) | ||||
| self.assertDictEqual( | |||||
| self._assert_dict_equal( | |||||
| event_data.EVENT_TRAIN_DICT_0['train_lineage']['model'], | event_data.EVENT_TRAIN_DICT_0['train_lineage']['model'], | ||||
| self.lineage_obj.model | self.lineage_obj.model | ||||
| ) | ) | ||||
| self.assertDictEqual( | |||||
| self._assert_dict_equal( | |||||
| event_data.EVENT_TRAIN_DICT_0['train_lineage']['train_dataset'], | event_data.EVENT_TRAIN_DICT_0['train_lineage']['train_dataset'], | ||||
| self.lineage_obj.train_dataset | self.lineage_obj.train_dataset | ||||
| ) | ) | ||||
| self.assertDictEqual( | |||||
| self._assert_dict_equal( | |||||
| event_data.EVENT_TRAIN_DICT_0['train_lineage']['hyper_parameters'], | event_data.EVENT_TRAIN_DICT_0['train_lineage']['hyper_parameters'], | ||||
| self.lineage_obj.hyper_parameters | self.lineage_obj.hyper_parameters | ||||
| ) | ) | ||||
| self.assertDictEqual(event_data.METRIC_0, self.lineage_obj.metric) | |||||
| self.assertDictEqual( | |||||
| self._assert_dict_equal(event_data.METRIC_0, self.lineage_obj.metric) | |||||
| self._assert_dict_equal( | |||||
| event_data.EVENT_EVAL_DICT_0['evaluation_lineage']['valid_dataset'], | event_data.EVENT_EVAL_DICT_0['evaluation_lineage']['valid_dataset'], | ||||
| self.lineage_obj.valid_dataset | self.lineage_obj.valid_dataset | ||||
| ) | ) | ||||
| @@ -78,24 +83,24 @@ class TestLineageObj(TestCase): | |||||
| def test_property_eval_not_exist(self): | def test_property_eval_not_exist(self): | ||||
| """Test the function of getting property with no evaluation event.""" | """Test the function of getting property with no evaluation event.""" | ||||
| self.assertEqual(self.summary_dir, self.lineage_obj.summary_dir) | self.assertEqual(self.summary_dir, self.lineage_obj.summary_dir) | ||||
| self.assertDictEqual( | |||||
| self._assert_dict_equal( | |||||
| event_data.EVENT_TRAIN_DICT_0['train_lineage']['algorithm'], | event_data.EVENT_TRAIN_DICT_0['train_lineage']['algorithm'], | ||||
| self.lineage_obj_no_eval.algorithm | self.lineage_obj_no_eval.algorithm | ||||
| ) | ) | ||||
| self.assertDictEqual( | |||||
| self._assert_dict_equal( | |||||
| event_data.EVENT_TRAIN_DICT_0['train_lineage']['model'], | event_data.EVENT_TRAIN_DICT_0['train_lineage']['model'], | ||||
| self.lineage_obj_no_eval.model | self.lineage_obj_no_eval.model | ||||
| ) | ) | ||||
| self.assertDictEqual( | |||||
| self._assert_dict_equal( | |||||
| event_data.EVENT_TRAIN_DICT_0['train_lineage']['train_dataset'], | event_data.EVENT_TRAIN_DICT_0['train_lineage']['train_dataset'], | ||||
| self.lineage_obj_no_eval.train_dataset | self.lineage_obj_no_eval.train_dataset | ||||
| ) | ) | ||||
| self.assertDictEqual( | |||||
| self._assert_dict_equal( | |||||
| event_data.EVENT_TRAIN_DICT_0['train_lineage']['hyper_parameters'], | event_data.EVENT_TRAIN_DICT_0['train_lineage']['hyper_parameters'], | ||||
| self.lineage_obj_no_eval.hyper_parameters | self.lineage_obj_no_eval.hyper_parameters | ||||
| ) | ) | ||||
| self.assertDictEqual({}, self.lineage_obj_no_eval.metric) | |||||
| self.assertDictEqual({}, self.lineage_obj_no_eval.valid_dataset) | |||||
| self._assert_dict_equal({}, self.lineage_obj_no_eval.metric) | |||||
| self._assert_dict_equal({}, self.lineage_obj_no_eval.valid_dataset) | |||||
| def test_get_summary_info(self): | def test_get_summary_info(self): | ||||
| """Test the function of get_summary_info.""" | """Test the function of get_summary_info.""" | ||||
| @@ -106,7 +111,7 @@ class TestLineageObj(TestCase): | |||||
| 'model': event_data.EVENT_TRAIN_DICT_0['train_lineage']['model'] | 'model': event_data.EVENT_TRAIN_DICT_0['train_lineage']['model'] | ||||
| } | } | ||||
| result = self.lineage_obj.get_summary_info(filter_keys) | result = self.lineage_obj.get_summary_info(filter_keys) | ||||
| self.assertDictEqual(expected_result, result) | |||||
| self._assert_dict_equal(expected_result, result) | |||||
| def test_to_model_lineage_dict(self): | def test_to_model_lineage_dict(self): | ||||
| """Test the function of to_model_lineage_dict.""" | """Test the function of to_model_lineage_dict.""" | ||||
| @@ -120,7 +125,7 @@ class TestLineageObj(TestCase): | |||||
| expected_result['model_lineage']['dataset_mark'] = None | expected_result['model_lineage']['dataset_mark'] = None | ||||
| expected_result.pop('dataset_graph') | expected_result.pop('dataset_graph') | ||||
| result = self.lineage_obj.to_model_lineage_dict() | result = self.lineage_obj.to_model_lineage_dict() | ||||
| self.assertDictEqual(expected_result, result) | |||||
| self._assert_dict_equal(expected_result, result) | |||||
| def test_to_dataset_lineage_dict(self): | def test_to_dataset_lineage_dict(self): | ||||
| """Test the function of to_dataset_lineage_dict.""" | """Test the function of to_dataset_lineage_dict.""" | ||||
| @@ -81,3 +81,59 @@ def compare_result_with_file(result, expected_file_path): | |||||
| with open(expected_file_path, 'r') as file: | with open(expected_file_path, 'r') as file: | ||||
| expected_results = json.load(file) | expected_results = json.load(file) | ||||
| assert result == expected_results | assert result == expected_results | ||||
| def deal_float_for_dict(res: dict, expected_res: dict): | |||||
| """ | |||||
| Deal float rounded to five decimals in dict. | |||||
| For example: | |||||
| res:{ | |||||
| "model_lineages": { | |||||
| "metric": {"acc": 0.1234561} | |||||
| } | |||||
| } | |||||
| expected_res: | |||||
| { | |||||
| "model_lineages": { | |||||
| "metric": {"acc": 0.1234562} | |||||
| } | |||||
| } | |||||
| After: | |||||
| res:{ | |||||
| "model_lineages": { | |||||
| "metric": {"acc": 0.12346} | |||||
| } | |||||
| } | |||||
| expected_res: | |||||
| { | |||||
| "model_lineages": { | |||||
| "metric": {"acc": 0.12346} | |||||
| } | |||||
| } | |||||
| Args: | |||||
| res (dict): e.g. | |||||
| { | |||||
| "model_lineages": { | |||||
| "metric": {"acc": 0.1234561} | |||||
| } | |||||
| } | |||||
| expected_res (dict): | |||||
| { | |||||
| "model_lineages": { | |||||
| "metric": {"acc": 0.1234562} | |||||
| } | |||||
| } | |||||
| """ | |||||
| decimal_num = 5 | |||||
| for key in res: | |||||
| value = res[key] | |||||
| expected_value = expected_res[key] | |||||
| if isinstance(value, dict): | |||||
| deal_float_for_dict(value, expected_value) | |||||
| elif isinstance(value, float): | |||||
| res[key] = round(value, decimal_num) | |||||
| expected_res[key] = round(expected_value, decimal_num) | |||||