| @@ -22,7 +22,7 @@ EVENT_TRAIN_DICT_0 = { | |||
| 'train_lineage': { | |||
| 'hyper_parameters': { | |||
| 'optimizer': 'ApplyMomentum0', | |||
| 'learning_rate': 0.10000000149011612, | |||
| 'learning_rate': 0.11, | |||
| 'loss_function': '', | |||
| 'epoch': 1, | |||
| 'parallel_mode': 'stand_alone0', | |||
| @@ -31,7 +31,7 @@ EVENT_TRAIN_DICT_0 = { | |||
| }, | |||
| 'algorithm': { | |||
| 'network': 'TrainOneStepCell0', | |||
| 'loss': 2.3025848865509033 | |||
| 'loss': 2.3025841 | |||
| }, | |||
| 'train_dataset': { | |||
| 'train_dataset_path': '', | |||
| @@ -49,7 +49,7 @@ EVENT_TRAIN_DICT_1 = { | |||
| 'train_lineage': { | |||
| 'hyper_parameters': { | |||
| 'optimizer': 'ApplyMomentum1', | |||
| 'learning_rate': 0.20000000298023224, | |||
| 'learning_rate': 0.2100001, | |||
| 'loss_function': 'loss_function1', | |||
| 'epoch': 1, | |||
| 'parallel_mode': 'stand_alone1', | |||
| @@ -58,7 +58,7 @@ EVENT_TRAIN_DICT_1 = { | |||
| }, | |||
| 'algorithm': { | |||
| 'network': 'TrainOneStepCell1', | |||
| 'loss': 2.4025847911834717 | |||
| 'loss': 2.4025841 | |||
| }, | |||
| 'train_dataset': { | |||
| 'train_dataset_path': '/path/to/train_dataset1', | |||
| @@ -76,7 +76,7 @@ EVENT_TRAIN_DICT_2 = { | |||
| 'train_lineage': { | |||
| 'hyper_parameters': { | |||
| 'optimizer': 'ApplyMomentum2', | |||
| 'learning_rate': 0.30000001192092896, | |||
| 'learning_rate': 0.3100001, | |||
| 'loss_function': 'loss_function2', | |||
| 'epoch': 2, | |||
| 'parallel_mode': 'stand_alone2', | |||
| @@ -85,7 +85,7 @@ EVENT_TRAIN_DICT_2 = { | |||
| }, | |||
| 'algorithm': { | |||
| 'network': 'TrainOneStepCell2', | |||
| 'loss': 2.502584934234619 | |||
| 'loss': 2.5025841 | |||
| }, | |||
| 'train_dataset': { | |||
| 'train_dataset_path': '/path/to/train_dataset2', | |||
| @@ -103,7 +103,7 @@ EVENT_TRAIN_DICT_3 = { | |||
| 'train_lineage': { | |||
| 'hyper_parameters': { | |||
| 'optimizer': 'ApplyMomentum3', | |||
| 'learning_rate': 0.4000000059604645, | |||
| 'learning_rate': 0.4, | |||
| 'loss_function': 'loss_function3', | |||
| 'epoch': 2, | |||
| 'parallel_mode': 'stand_alone3', | |||
| @@ -112,7 +112,7 @@ EVENT_TRAIN_DICT_3 = { | |||
| }, | |||
| 'algorithm': { | |||
| 'network': 'TrainOneStepCell3', | |||
| 'loss': 2.6025848388671875 | |||
| 'loss': 2.6025841 | |||
| }, | |||
| 'train_dataset': { | |||
| 'train_dataset_path': '/path/to/train_dataset3', | |||
| @@ -139,7 +139,7 @@ EVENT_TRAIN_DICT_4 = { | |||
| }, | |||
| 'algorithm': { | |||
| 'network': 'TrainOneStepCell4', | |||
| 'loss': 2.702584981918335 | |||
| 'loss': 2.7025841 | |||
| }, | |||
| 'train_dataset': { | |||
| 'train_dataset_path': '/path/to/train_dataset4', | |||
| @@ -166,7 +166,7 @@ EVENT_TRAIN_DICT_5 = { | |||
| }, | |||
| 'algorithm': { | |||
| 'network': 'TrainOneStepCell5', | |||
| 'loss': 2.702584981918335 | |||
| 'loss': 2.7025841 | |||
| }, | |||
| 'train_dataset': { | |||
| 'train_dataset_size': 35 | |||
| @@ -211,33 +211,33 @@ CUSTOMIZED_2 = { | |||
| } | |||
| METRIC_1 = { | |||
| 'accuracy': 1.0000002, | |||
| 'accuracy': 1.2000002, | |||
| 'mae': 2.00000002, | |||
| 'mse': 3.00000002 | |||
| } | |||
| METRIC_2 = { | |||
| 'accuracy': 1.0000003, | |||
| 'mae': 2.00000003, | |||
| 'mse': 3.00000003 | |||
| 'accuracy': 1.3000003, | |||
| 'mae': 2.30000003, | |||
| 'mse': 3.30000003 | |||
| } | |||
| METRIC_3 = { | |||
| 'accuracy': 1.0000004, | |||
| 'mae': 2.00000004, | |||
| 'mse': 3.00000004 | |||
| 'accuracy': 1.4000004, | |||
| 'mae': 2.40000004, | |||
| 'mse': 3.40000004 | |||
| } | |||
| METRIC_4 = { | |||
| 'accuracy': 1.0000005, | |||
| 'mae': 2.00000005, | |||
| 'mse': 3.00000005 | |||
| 'accuracy': 1.5000005, | |||
| 'mae': 2.50000005, | |||
| 'mse': 3.50000005 | |||
| } | |||
| METRIC_5 = { | |||
| 'accuracy': 1.0000006, | |||
| 'mae': 2.00000006, | |||
| 'mse': 3.00000006 | |||
| 'accuracy': 1.7000006, | |||
| 'mae': 2.60000006, | |||
| 'mse': 3.60000006 | |||
| } | |||
| EVENT_EVAL_DICT_0 = { | |||
| @@ -27,6 +27,7 @@ from mindinsight.lineagemgr.querier.querier import Querier | |||
| from mindinsight.lineagemgr.summary.lineage_summary_analyzer import LineageInfo | |||
| from . import event_data | |||
| from ....utils.tools import deal_float_for_dict | |||
| def create_lineage_info(train_event_dict, eval_event_dict, dataset_event_dict): | |||
| @@ -266,7 +267,6 @@ class TestQuerier(TestCase): | |||
| mock_file_handler = MagicMock() | |||
| mock_file_handler.size = 1 | |||
| args[2].return_value = [{'relative_path': './', 'update_time': 1}] | |||
| single_summary_path = '/path/to/summary0' | |||
| lineage_objects = LineageOrganizer(summary_base_dir=single_summary_path).super_lineage_objs | |||
| @@ -282,17 +282,31 @@ class TestQuerier(TestCase): | |||
| lineage_objects = LineageOrganizer(summary_base_dir=summary_base_dir).super_lineage_objs | |||
| self.multi_querier = Querier(lineage_objects) | |||
| def _deal_float_for_list(self, list1, list2): | |||
| index = 0 | |||
| for _ in list1: | |||
| deal_float_for_dict(list1[index], list2[index]) | |||
| index += 1 | |||
| def _assert_list_equal(self, list1, list2): | |||
| self._deal_float_for_list(list1, list2) | |||
| self.assertListEqual(list1, list2) | |||
| def _assert_lineages_equal(self, lineages1, lineages2): | |||
| self._deal_float_for_list(lineages1['object'], lineages2['object']) | |||
| self.assertDictEqual(lineages1, lineages2) | |||
| def test_get_summary_lineage_success_1(self): | |||
| """Test the success of get_summary_lineage.""" | |||
| expected_result = [LINEAGE_INFO_0] | |||
| result = self.single_querier.get_summary_lineage() | |||
| self.assertListEqual(expected_result, result) | |||
| self._assert_list_equal(expected_result, result) | |||
| def test_get_summary_lineage_success_2(self): | |||
| """Test the success of get_summary_lineage.""" | |||
| expected_result = [LINEAGE_INFO_0] | |||
| result = self.single_querier.get_summary_lineage() | |||
| self.assertListEqual(expected_result, result) | |||
| self._assert_list_equal(expected_result, result) | |||
| def test_get_summary_lineage_success_3(self): | |||
| """Test the success of get_summary_lineage.""" | |||
| @@ -306,7 +320,7 @@ class TestQuerier(TestCase): | |||
| result = self.single_querier.get_summary_lineage( | |||
| filter_keys=['model', 'algorithm'] | |||
| ) | |||
| self.assertListEqual(expected_result, result) | |||
| self._assert_list_equal(expected_result, result) | |||
| def test_get_summary_lineage_success_4(self): | |||
| """Test the success of get_summary_lineage.""" | |||
| @@ -353,7 +367,7 @@ class TestQuerier(TestCase): | |||
| } | |||
| ] | |||
| result = self.multi_querier.get_summary_lineage() | |||
| self.assertListEqual(expected_result, result) | |||
| self._assert_list_equal(expected_result, result) | |||
| def test_get_summary_lineage_success_5(self): | |||
| """Test the success of get_summary_lineage.""" | |||
| @@ -361,7 +375,7 @@ class TestQuerier(TestCase): | |||
| result = self.multi_querier.get_summary_lineage( | |||
| summary_dir='/path/to/summary1' | |||
| ) | |||
| self.assertListEqual(expected_result, result) | |||
| self._assert_list_equal(expected_result, result) | |||
| def test_get_summary_lineage_success_6(self): | |||
| """Test the success of get_summary_lineage.""" | |||
| @@ -380,7 +394,7 @@ class TestQuerier(TestCase): | |||
| result = self.multi_querier.get_summary_lineage( | |||
| summary_dir='/path/to/summary0', filter_keys=filter_keys | |||
| ) | |||
| self.assertListEqual(expected_result, result) | |||
| self._assert_list_equal(expected_result, result) | |||
| def test_get_summary_lineage_fail(self): | |||
| """Test the function of get_summary_lineage with exception.""" | |||
| @@ -423,7 +437,7 @@ class TestQuerier(TestCase): | |||
| 'count': 2, | |||
| } | |||
| result = self.multi_querier.filter_summary_lineage(condition=condition) | |||
| self.assertDictEqual(expected_result, result) | |||
| self._assert_lineages_equal(expected_result, result) | |||
| def test_filter_summary_lineage_success_2(self): | |||
| """Test the success of filter_summary_lineage.""" | |||
| @@ -448,7 +462,7 @@ class TestQuerier(TestCase): | |||
| 'count': 2, | |||
| } | |||
| result = self.multi_querier.filter_summary_lineage(condition=condition) | |||
| self.assertDictEqual(expected_result, result) | |||
| self._assert_lineages_equal(expected_result, result) | |||
| def test_filter_summary_lineage_success_3(self): | |||
| """Test the success of filter_summary_lineage.""" | |||
| @@ -465,7 +479,7 @@ class TestQuerier(TestCase): | |||
| 'count': 7, | |||
| } | |||
| result = self.multi_querier.filter_summary_lineage(condition=condition) | |||
| self.assertDictEqual(expected_result, result) | |||
| self._assert_lineages_equal(expected_result, result) | |||
| def test_filter_summary_lineage_success_4(self): | |||
| """Test the success of filter_summary_lineage.""" | |||
| @@ -483,7 +497,7 @@ class TestQuerier(TestCase): | |||
| 'count': 7, | |||
| } | |||
| result = self.multi_querier.filter_summary_lineage() | |||
| self.assertDictEqual(expected_result, result) | |||
| self._assert_lineages_equal(expected_result, result) | |||
| def test_filter_summary_lineage_success_5(self): | |||
| """Test the success of filter_summary_lineage.""" | |||
| @@ -498,7 +512,7 @@ class TestQuerier(TestCase): | |||
| 'count': 1, | |||
| } | |||
| result = self.multi_querier.filter_summary_lineage(condition=condition) | |||
| self.assertDictEqual(expected_result, result) | |||
| self._assert_lineages_equal(expected_result, result) | |||
| def test_filter_summary_lineage_success_6(self): | |||
| """Test the success of filter_summary_lineage.""" | |||
| @@ -520,7 +534,7 @@ class TestQuerier(TestCase): | |||
| 'count': 7, | |||
| } | |||
| result = self.multi_querier.filter_summary_lineage(condition=condition) | |||
| self.assertDictEqual(expected_result, result) | |||
| self._assert_lineages_equal(expected_result, result) | |||
| def test_filter_summary_lineage_success_7(self): | |||
| """Test the success of filter_summary_lineage.""" | |||
| @@ -542,14 +556,14 @@ class TestQuerier(TestCase): | |||
| 'count': 7, | |||
| } | |||
| result = self.multi_querier.filter_summary_lineage(condition=condition) | |||
| self.assertDictEqual(expected_result, result) | |||
| self._assert_lineages_equal(expected_result, result) | |||
| def test_filter_summary_lineage_success_8(self): | |||
| """Test the success of filter_summary_lineage.""" | |||
| condition = { | |||
| 'metric/accuracy': { | |||
| 'lt': 1.0000006, | |||
| 'gt': 1.0000004 | |||
| 'lt': 1.6000006, | |||
| 'gt': 1.4000004 | |||
| } | |||
| } | |||
| expected_result = { | |||
| @@ -558,7 +572,7 @@ class TestQuerier(TestCase): | |||
| 'count': 1, | |||
| } | |||
| result = self.multi_querier.filter_summary_lineage(condition=condition) | |||
| self.assertDictEqual(expected_result, result) | |||
| self._assert_lineages_equal(expected_result, result) | |||
| def test_filter_summary_lineage_success_9(self): | |||
| """Test the success of filter_summary_lineage.""" | |||
| @@ -572,14 +586,14 @@ class TestQuerier(TestCase): | |||
| 'count': 7, | |||
| } | |||
| result = self.multi_querier.filter_summary_lineage(condition=condition) | |||
| self.assertDictEqual(expected_result, result) | |||
| self._assert_lineages_equal(expected_result, result) | |||
| def test_filter_summary_lineage_fail(self): | |||
| """Test the function of filter_summary_lineage with exception.""" | |||
| condition = { | |||
| 'xxx': { | |||
| 'lt': 1.0000006, | |||
| 'gt': 1.0000004 | |||
| 'lt': 1.6000006, | |||
| 'gt': 1.4000004 | |||
| } | |||
| } | |||
| self.assertRaises( | |||
| @@ -21,6 +21,7 @@ from mindinsight.lineagemgr.querier.query_model import LineageObj | |||
| from . import event_data | |||
| from .test_querier import create_filtration_result, create_lineage_info | |||
| from ....utils.tools import deal_float_for_dict | |||
| class TestLineageObj(TestCase): | |||
| @@ -50,27 +51,31 @@ class TestLineageObj(TestCase): | |||
| evaluation_lineage=lineage_info.eval_lineage | |||
| ) | |||
| def _assert_dict_equal(self, dict1, dict2): | |||
| deal_float_for_dict(dict1, dict2) | |||
| self.assertDictEqual(dict1, dict2) | |||
| def test_property(self): | |||
| """Test the function of getting property.""" | |||
| self.assertEqual(self.summary_dir, self.lineage_obj.summary_dir) | |||
| self.assertDictEqual( | |||
| self._assert_dict_equal( | |||
| event_data.EVENT_TRAIN_DICT_0['train_lineage']['algorithm'], | |||
| self.lineage_obj.algorithm | |||
| ) | |||
| self.assertDictEqual( | |||
| self._assert_dict_equal( | |||
| event_data.EVENT_TRAIN_DICT_0['train_lineage']['model'], | |||
| self.lineage_obj.model | |||
| ) | |||
| self.assertDictEqual( | |||
| self._assert_dict_equal( | |||
| event_data.EVENT_TRAIN_DICT_0['train_lineage']['train_dataset'], | |||
| self.lineage_obj.train_dataset | |||
| ) | |||
| self.assertDictEqual( | |||
| self._assert_dict_equal( | |||
| event_data.EVENT_TRAIN_DICT_0['train_lineage']['hyper_parameters'], | |||
| self.lineage_obj.hyper_parameters | |||
| ) | |||
| self.assertDictEqual(event_data.METRIC_0, self.lineage_obj.metric) | |||
| self.assertDictEqual( | |||
| self._assert_dict_equal(event_data.METRIC_0, self.lineage_obj.metric) | |||
| self._assert_dict_equal( | |||
| event_data.EVENT_EVAL_DICT_0['evaluation_lineage']['valid_dataset'], | |||
| self.lineage_obj.valid_dataset | |||
| ) | |||
| @@ -78,24 +83,24 @@ class TestLineageObj(TestCase): | |||
| def test_property_eval_not_exist(self): | |||
| """Test the function of getting property with no evaluation event.""" | |||
| self.assertEqual(self.summary_dir, self.lineage_obj.summary_dir) | |||
| self.assertDictEqual( | |||
| self._assert_dict_equal( | |||
| event_data.EVENT_TRAIN_DICT_0['train_lineage']['algorithm'], | |||
| self.lineage_obj_no_eval.algorithm | |||
| ) | |||
| self.assertDictEqual( | |||
| self._assert_dict_equal( | |||
| event_data.EVENT_TRAIN_DICT_0['train_lineage']['model'], | |||
| self.lineage_obj_no_eval.model | |||
| ) | |||
| self.assertDictEqual( | |||
| self._assert_dict_equal( | |||
| event_data.EVENT_TRAIN_DICT_0['train_lineage']['train_dataset'], | |||
| self.lineage_obj_no_eval.train_dataset | |||
| ) | |||
| self.assertDictEqual( | |||
| self._assert_dict_equal( | |||
| event_data.EVENT_TRAIN_DICT_0['train_lineage']['hyper_parameters'], | |||
| self.lineage_obj_no_eval.hyper_parameters | |||
| ) | |||
| self.assertDictEqual({}, self.lineage_obj_no_eval.metric) | |||
| self.assertDictEqual({}, self.lineage_obj_no_eval.valid_dataset) | |||
| self._assert_dict_equal({}, self.lineage_obj_no_eval.metric) | |||
| self._assert_dict_equal({}, self.lineage_obj_no_eval.valid_dataset) | |||
| def test_get_summary_info(self): | |||
| """Test the function of get_summary_info.""" | |||
| @@ -106,7 +111,7 @@ class TestLineageObj(TestCase): | |||
| 'model': event_data.EVENT_TRAIN_DICT_0['train_lineage']['model'] | |||
| } | |||
| result = self.lineage_obj.get_summary_info(filter_keys) | |||
| self.assertDictEqual(expected_result, result) | |||
| self._assert_dict_equal(expected_result, result) | |||
| def test_to_model_lineage_dict(self): | |||
| """Test the function of to_model_lineage_dict.""" | |||
| @@ -120,7 +125,7 @@ class TestLineageObj(TestCase): | |||
| expected_result['model_lineage']['dataset_mark'] = None | |||
| expected_result.pop('dataset_graph') | |||
| result = self.lineage_obj.to_model_lineage_dict() | |||
| self.assertDictEqual(expected_result, result) | |||
| self._assert_dict_equal(expected_result, result) | |||
| def test_to_dataset_lineage_dict(self): | |||
| """Test the function of to_dataset_lineage_dict.""" | |||
| @@ -81,3 +81,59 @@ def compare_result_with_file(result, expected_file_path): | |||
| with open(expected_file_path, 'r') as file: | |||
| expected_results = json.load(file) | |||
| assert result == expected_results | |||
| def deal_float_for_dict(res: dict, expected_res: dict): | |||
| """ | |||
| Deal float rounded to five decimals in dict. | |||
| For example: | |||
| res:{ | |||
| "model_lineages": { | |||
| "metric": {"acc": 0.1234561} | |||
| } | |||
| } | |||
| expected_res: | |||
| { | |||
| "model_lineages": { | |||
| "metric": {"acc": 0.1234562} | |||
| } | |||
| } | |||
| After: | |||
| res:{ | |||
| "model_lineages": { | |||
| "metric": {"acc": 0.12346} | |||
| } | |||
| } | |||
| expected_res: | |||
| { | |||
| "model_lineages": { | |||
| "metric": {"acc": 0.12346} | |||
| } | |||
| } | |||
| Args: | |||
| res (dict): e.g. | |||
| { | |||
| "model_lineages": { | |||
| "metric": {"acc": 0.1234561} | |||
| } | |||
| } | |||
| expected_res (dict): | |||
| { | |||
| "model_lineages": { | |||
| "metric": {"acc": 0.1234562} | |||
| } | |||
| } | |||
| """ | |||
| decimal_num = 5 | |||
| for key in res: | |||
| value = res[key] | |||
| expected_value = expected_res[key] | |||
| if isinstance(value, dict): | |||
| deal_float_for_dict(value, expected_value) | |||
| elif isinstance(value, float): | |||
| res[key] = round(value, decimal_num) | |||
| expected_res[key] = round(expected_value, decimal_num) | |||