Browse Source

fix the float compare because ci env update

pull/176/head
luopengting 5 years ago
parent
commit
618f8c8ccf
4 changed files with 132 additions and 57 deletions
  1. +23
    -23
      tests/ut/lineagemgr/querier/event_data.py
  2. +34
    -20
      tests/ut/lineagemgr/querier/test_querier.py
  3. +19
    -14
      tests/ut/lineagemgr/querier/test_query_model.py
  4. +56
    -0
      tests/utils/tools.py

+ 23
- 23
tests/ut/lineagemgr/querier/event_data.py View File

@@ -22,7 +22,7 @@ EVENT_TRAIN_DICT_0 = {
'train_lineage': { 'train_lineage': {
'hyper_parameters': { 'hyper_parameters': {
'optimizer': 'ApplyMomentum0', 'optimizer': 'ApplyMomentum0',
'learning_rate': 0.10000000149011612,
'learning_rate': 0.11,
'loss_function': '', 'loss_function': '',
'epoch': 1, 'epoch': 1,
'parallel_mode': 'stand_alone0', 'parallel_mode': 'stand_alone0',
@@ -31,7 +31,7 @@ EVENT_TRAIN_DICT_0 = {
}, },
'algorithm': { 'algorithm': {
'network': 'TrainOneStepCell0', 'network': 'TrainOneStepCell0',
'loss': 2.3025848865509033
'loss': 2.3025841
}, },
'train_dataset': { 'train_dataset': {
'train_dataset_path': '', 'train_dataset_path': '',
@@ -49,7 +49,7 @@ EVENT_TRAIN_DICT_1 = {
'train_lineage': { 'train_lineage': {
'hyper_parameters': { 'hyper_parameters': {
'optimizer': 'ApplyMomentum1', 'optimizer': 'ApplyMomentum1',
'learning_rate': 0.20000000298023224,
'learning_rate': 0.2100001,
'loss_function': 'loss_function1', 'loss_function': 'loss_function1',
'epoch': 1, 'epoch': 1,
'parallel_mode': 'stand_alone1', 'parallel_mode': 'stand_alone1',
@@ -58,7 +58,7 @@ EVENT_TRAIN_DICT_1 = {
}, },
'algorithm': { 'algorithm': {
'network': 'TrainOneStepCell1', 'network': 'TrainOneStepCell1',
'loss': 2.4025847911834717
'loss': 2.4025841
}, },
'train_dataset': { 'train_dataset': {
'train_dataset_path': '/path/to/train_dataset1', 'train_dataset_path': '/path/to/train_dataset1',
@@ -76,7 +76,7 @@ EVENT_TRAIN_DICT_2 = {
'train_lineage': { 'train_lineage': {
'hyper_parameters': { 'hyper_parameters': {
'optimizer': 'ApplyMomentum2', 'optimizer': 'ApplyMomentum2',
'learning_rate': 0.30000001192092896,
'learning_rate': 0.3100001,
'loss_function': 'loss_function2', 'loss_function': 'loss_function2',
'epoch': 2, 'epoch': 2,
'parallel_mode': 'stand_alone2', 'parallel_mode': 'stand_alone2',
@@ -85,7 +85,7 @@ EVENT_TRAIN_DICT_2 = {
}, },
'algorithm': { 'algorithm': {
'network': 'TrainOneStepCell2', 'network': 'TrainOneStepCell2',
'loss': 2.502584934234619
'loss': 2.5025841
}, },
'train_dataset': { 'train_dataset': {
'train_dataset_path': '/path/to/train_dataset2', 'train_dataset_path': '/path/to/train_dataset2',
@@ -103,7 +103,7 @@ EVENT_TRAIN_DICT_3 = {
'train_lineage': { 'train_lineage': {
'hyper_parameters': { 'hyper_parameters': {
'optimizer': 'ApplyMomentum3', 'optimizer': 'ApplyMomentum3',
'learning_rate': 0.4000000059604645,
'learning_rate': 0.4,
'loss_function': 'loss_function3', 'loss_function': 'loss_function3',
'epoch': 2, 'epoch': 2,
'parallel_mode': 'stand_alone3', 'parallel_mode': 'stand_alone3',
@@ -112,7 +112,7 @@ EVENT_TRAIN_DICT_3 = {
}, },
'algorithm': { 'algorithm': {
'network': 'TrainOneStepCell3', 'network': 'TrainOneStepCell3',
'loss': 2.6025848388671875
'loss': 2.6025841
}, },
'train_dataset': { 'train_dataset': {
'train_dataset_path': '/path/to/train_dataset3', 'train_dataset_path': '/path/to/train_dataset3',
@@ -139,7 +139,7 @@ EVENT_TRAIN_DICT_4 = {
}, },
'algorithm': { 'algorithm': {
'network': 'TrainOneStepCell4', 'network': 'TrainOneStepCell4',
'loss': 2.702584981918335
'loss': 2.7025841
}, },
'train_dataset': { 'train_dataset': {
'train_dataset_path': '/path/to/train_dataset4', 'train_dataset_path': '/path/to/train_dataset4',
@@ -166,7 +166,7 @@ EVENT_TRAIN_DICT_5 = {
}, },
'algorithm': { 'algorithm': {
'network': 'TrainOneStepCell5', 'network': 'TrainOneStepCell5',
'loss': 2.702584981918335
'loss': 2.7025841
}, },
'train_dataset': { 'train_dataset': {
'train_dataset_size': 35 'train_dataset_size': 35
@@ -211,33 +211,33 @@ CUSTOMIZED_2 = {
} }


METRIC_1 = { METRIC_1 = {
'accuracy': 1.0000002,
'accuracy': 1.2000002,
'mae': 2.00000002, 'mae': 2.00000002,
'mse': 3.00000002 'mse': 3.00000002
} }


METRIC_2 = { METRIC_2 = {
'accuracy': 1.0000003,
'mae': 2.00000003,
'mse': 3.00000003
'accuracy': 1.3000003,
'mae': 2.30000003,
'mse': 3.30000003
} }


METRIC_3 = { METRIC_3 = {
'accuracy': 1.0000004,
'mae': 2.00000004,
'mse': 3.00000004
'accuracy': 1.4000004,
'mae': 2.40000004,
'mse': 3.40000004
} }


METRIC_4 = { METRIC_4 = {
'accuracy': 1.0000005,
'mae': 2.00000005,
'mse': 3.00000005
'accuracy': 1.5000005,
'mae': 2.50000005,
'mse': 3.50000005
} }


METRIC_5 = { METRIC_5 = {
'accuracy': 1.0000006,
'mae': 2.00000006,
'mse': 3.00000006
'accuracy': 1.7000006,
'mae': 2.60000006,
'mse': 3.60000006
} }


EVENT_EVAL_DICT_0 = { EVENT_EVAL_DICT_0 = {


+ 34
- 20
tests/ut/lineagemgr/querier/test_querier.py View File

@@ -27,6 +27,7 @@ from mindinsight.lineagemgr.querier.querier import Querier
from mindinsight.lineagemgr.summary.lineage_summary_analyzer import LineageInfo from mindinsight.lineagemgr.summary.lineage_summary_analyzer import LineageInfo


from . import event_data from . import event_data
from ....utils.tools import deal_float_for_dict




def create_lineage_info(train_event_dict, eval_event_dict, dataset_event_dict): def create_lineage_info(train_event_dict, eval_event_dict, dataset_event_dict):
@@ -266,7 +267,6 @@ class TestQuerier(TestCase):
mock_file_handler = MagicMock() mock_file_handler = MagicMock()
mock_file_handler.size = 1 mock_file_handler.size = 1



args[2].return_value = [{'relative_path': './', 'update_time': 1}] args[2].return_value = [{'relative_path': './', 'update_time': 1}]
single_summary_path = '/path/to/summary0' single_summary_path = '/path/to/summary0'
lineage_objects = LineageOrganizer(summary_base_dir=single_summary_path).super_lineage_objs lineage_objects = LineageOrganizer(summary_base_dir=single_summary_path).super_lineage_objs
@@ -282,17 +282,31 @@ class TestQuerier(TestCase):
lineage_objects = LineageOrganizer(summary_base_dir=summary_base_dir).super_lineage_objs lineage_objects = LineageOrganizer(summary_base_dir=summary_base_dir).super_lineage_objs
self.multi_querier = Querier(lineage_objects) self.multi_querier = Querier(lineage_objects)


def _deal_float_for_list(self, list1, list2):
index = 0
for _ in list1:
deal_float_for_dict(list1[index], list2[index])
index += 1

def _assert_list_equal(self, list1, list2):
self._deal_float_for_list(list1, list2)
self.assertListEqual(list1, list2)

def _assert_lineages_equal(self, lineages1, lineages2):
self._deal_float_for_list(lineages1['object'], lineages2['object'])
self.assertDictEqual(lineages1, lineages2)

def test_get_summary_lineage_success_1(self): def test_get_summary_lineage_success_1(self):
"""Test the success of get_summary_lineage.""" """Test the success of get_summary_lineage."""
expected_result = [LINEAGE_INFO_0] expected_result = [LINEAGE_INFO_0]
result = self.single_querier.get_summary_lineage() result = self.single_querier.get_summary_lineage()
self.assertListEqual(expected_result, result)
self._assert_list_equal(expected_result, result)


def test_get_summary_lineage_success_2(self): def test_get_summary_lineage_success_2(self):
"""Test the success of get_summary_lineage.""" """Test the success of get_summary_lineage."""
expected_result = [LINEAGE_INFO_0] expected_result = [LINEAGE_INFO_0]
result = self.single_querier.get_summary_lineage() result = self.single_querier.get_summary_lineage()
self.assertListEqual(expected_result, result)
self._assert_list_equal(expected_result, result)


def test_get_summary_lineage_success_3(self): def test_get_summary_lineage_success_3(self):
"""Test the success of get_summary_lineage.""" """Test the success of get_summary_lineage."""
@@ -306,7 +320,7 @@ class TestQuerier(TestCase):
result = self.single_querier.get_summary_lineage( result = self.single_querier.get_summary_lineage(
filter_keys=['model', 'algorithm'] filter_keys=['model', 'algorithm']
) )
self.assertListEqual(expected_result, result)
self._assert_list_equal(expected_result, result)


def test_get_summary_lineage_success_4(self): def test_get_summary_lineage_success_4(self):
"""Test the success of get_summary_lineage.""" """Test the success of get_summary_lineage."""
@@ -353,7 +367,7 @@ class TestQuerier(TestCase):
} }
] ]
result = self.multi_querier.get_summary_lineage() result = self.multi_querier.get_summary_lineage()
self.assertListEqual(expected_result, result)
self._assert_list_equal(expected_result, result)


def test_get_summary_lineage_success_5(self): def test_get_summary_lineage_success_5(self):
"""Test the success of get_summary_lineage.""" """Test the success of get_summary_lineage."""
@@ -361,7 +375,7 @@ class TestQuerier(TestCase):
result = self.multi_querier.get_summary_lineage( result = self.multi_querier.get_summary_lineage(
summary_dir='/path/to/summary1' summary_dir='/path/to/summary1'
) )
self.assertListEqual(expected_result, result)
self._assert_list_equal(expected_result, result)


def test_get_summary_lineage_success_6(self): def test_get_summary_lineage_success_6(self):
"""Test the success of get_summary_lineage.""" """Test the success of get_summary_lineage."""
@@ -380,7 +394,7 @@ class TestQuerier(TestCase):
result = self.multi_querier.get_summary_lineage( result = self.multi_querier.get_summary_lineage(
summary_dir='/path/to/summary0', filter_keys=filter_keys summary_dir='/path/to/summary0', filter_keys=filter_keys
) )
self.assertListEqual(expected_result, result)
self._assert_list_equal(expected_result, result)


def test_get_summary_lineage_fail(self): def test_get_summary_lineage_fail(self):
"""Test the function of get_summary_lineage with exception.""" """Test the function of get_summary_lineage with exception."""
@@ -423,7 +437,7 @@ class TestQuerier(TestCase):
'count': 2, 'count': 2,
} }
result = self.multi_querier.filter_summary_lineage(condition=condition) result = self.multi_querier.filter_summary_lineage(condition=condition)
self.assertDictEqual(expected_result, result)
self._assert_lineages_equal(expected_result, result)


def test_filter_summary_lineage_success_2(self): def test_filter_summary_lineage_success_2(self):
"""Test the success of filter_summary_lineage.""" """Test the success of filter_summary_lineage."""
@@ -448,7 +462,7 @@ class TestQuerier(TestCase):
'count': 2, 'count': 2,
} }
result = self.multi_querier.filter_summary_lineage(condition=condition) result = self.multi_querier.filter_summary_lineage(condition=condition)
self.assertDictEqual(expected_result, result)
self._assert_lineages_equal(expected_result, result)


def test_filter_summary_lineage_success_3(self): def test_filter_summary_lineage_success_3(self):
"""Test the success of filter_summary_lineage.""" """Test the success of filter_summary_lineage."""
@@ -465,7 +479,7 @@ class TestQuerier(TestCase):
'count': 7, 'count': 7,
} }
result = self.multi_querier.filter_summary_lineage(condition=condition) result = self.multi_querier.filter_summary_lineage(condition=condition)
self.assertDictEqual(expected_result, result)
self._assert_lineages_equal(expected_result, result)


def test_filter_summary_lineage_success_4(self): def test_filter_summary_lineage_success_4(self):
"""Test the success of filter_summary_lineage.""" """Test the success of filter_summary_lineage."""
@@ -483,7 +497,7 @@ class TestQuerier(TestCase):
'count': 7, 'count': 7,
} }
result = self.multi_querier.filter_summary_lineage() result = self.multi_querier.filter_summary_lineage()
self.assertDictEqual(expected_result, result)
self._assert_lineages_equal(expected_result, result)


def test_filter_summary_lineage_success_5(self): def test_filter_summary_lineage_success_5(self):
"""Test the success of filter_summary_lineage.""" """Test the success of filter_summary_lineage."""
@@ -498,7 +512,7 @@ class TestQuerier(TestCase):
'count': 1, 'count': 1,
} }
result = self.multi_querier.filter_summary_lineage(condition=condition) result = self.multi_querier.filter_summary_lineage(condition=condition)
self.assertDictEqual(expected_result, result)
self._assert_lineages_equal(expected_result, result)


def test_filter_summary_lineage_success_6(self): def test_filter_summary_lineage_success_6(self):
"""Test the success of filter_summary_lineage.""" """Test the success of filter_summary_lineage."""
@@ -520,7 +534,7 @@ class TestQuerier(TestCase):
'count': 7, 'count': 7,
} }
result = self.multi_querier.filter_summary_lineage(condition=condition) result = self.multi_querier.filter_summary_lineage(condition=condition)
self.assertDictEqual(expected_result, result)
self._assert_lineages_equal(expected_result, result)


def test_filter_summary_lineage_success_7(self): def test_filter_summary_lineage_success_7(self):
"""Test the success of filter_summary_lineage.""" """Test the success of filter_summary_lineage."""
@@ -542,14 +556,14 @@ class TestQuerier(TestCase):
'count': 7, 'count': 7,
} }
result = self.multi_querier.filter_summary_lineage(condition=condition) result = self.multi_querier.filter_summary_lineage(condition=condition)
self.assertDictEqual(expected_result, result)
self._assert_lineages_equal(expected_result, result)


def test_filter_summary_lineage_success_8(self): def test_filter_summary_lineage_success_8(self):
"""Test the success of filter_summary_lineage.""" """Test the success of filter_summary_lineage."""
condition = { condition = {
'metric/accuracy': { 'metric/accuracy': {
'lt': 1.0000006,
'gt': 1.0000004
'lt': 1.6000006,
'gt': 1.4000004
} }
} }
expected_result = { expected_result = {
@@ -558,7 +572,7 @@ class TestQuerier(TestCase):
'count': 1, 'count': 1,
} }
result = self.multi_querier.filter_summary_lineage(condition=condition) result = self.multi_querier.filter_summary_lineage(condition=condition)
self.assertDictEqual(expected_result, result)
self._assert_lineages_equal(expected_result, result)


def test_filter_summary_lineage_success_9(self): def test_filter_summary_lineage_success_9(self):
"""Test the success of filter_summary_lineage.""" """Test the success of filter_summary_lineage."""
@@ -572,14 +586,14 @@ class TestQuerier(TestCase):
'count': 7, 'count': 7,
} }
result = self.multi_querier.filter_summary_lineage(condition=condition) result = self.multi_querier.filter_summary_lineage(condition=condition)
self.assertDictEqual(expected_result, result)
self._assert_lineages_equal(expected_result, result)


def test_filter_summary_lineage_fail(self): def test_filter_summary_lineage_fail(self):
"""Test the function of filter_summary_lineage with exception.""" """Test the function of filter_summary_lineage with exception."""
condition = { condition = {
'xxx': { 'xxx': {
'lt': 1.0000006,
'gt': 1.0000004
'lt': 1.6000006,
'gt': 1.4000004
} }
} }
self.assertRaises( self.assertRaises(


+ 19
- 14
tests/ut/lineagemgr/querier/test_query_model.py View File

@@ -21,6 +21,7 @@ from mindinsight.lineagemgr.querier.query_model import LineageObj


from . import event_data from . import event_data
from .test_querier import create_filtration_result, create_lineage_info from .test_querier import create_filtration_result, create_lineage_info
from ....utils.tools import deal_float_for_dict




class TestLineageObj(TestCase): class TestLineageObj(TestCase):
@@ -50,27 +51,31 @@ class TestLineageObj(TestCase):
evaluation_lineage=lineage_info.eval_lineage evaluation_lineage=lineage_info.eval_lineage
) )


def _assert_dict_equal(self, dict1, dict2):
deal_float_for_dict(dict1, dict2)
self.assertDictEqual(dict1, dict2)

def test_property(self): def test_property(self):
"""Test the function of getting property.""" """Test the function of getting property."""
self.assertEqual(self.summary_dir, self.lineage_obj.summary_dir) self.assertEqual(self.summary_dir, self.lineage_obj.summary_dir)
self.assertDictEqual(
self._assert_dict_equal(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['algorithm'], event_data.EVENT_TRAIN_DICT_0['train_lineage']['algorithm'],
self.lineage_obj.algorithm self.lineage_obj.algorithm
) )
self.assertDictEqual(
self._assert_dict_equal(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['model'], event_data.EVENT_TRAIN_DICT_0['train_lineage']['model'],
self.lineage_obj.model self.lineage_obj.model
) )
self.assertDictEqual(
self._assert_dict_equal(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['train_dataset'], event_data.EVENT_TRAIN_DICT_0['train_lineage']['train_dataset'],
self.lineage_obj.train_dataset self.lineage_obj.train_dataset
) )
self.assertDictEqual(
self._assert_dict_equal(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['hyper_parameters'], event_data.EVENT_TRAIN_DICT_0['train_lineage']['hyper_parameters'],
self.lineage_obj.hyper_parameters self.lineage_obj.hyper_parameters
) )
self.assertDictEqual(event_data.METRIC_0, self.lineage_obj.metric)
self.assertDictEqual(
self._assert_dict_equal(event_data.METRIC_0, self.lineage_obj.metric)
self._assert_dict_equal(
event_data.EVENT_EVAL_DICT_0['evaluation_lineage']['valid_dataset'], event_data.EVENT_EVAL_DICT_0['evaluation_lineage']['valid_dataset'],
self.lineage_obj.valid_dataset self.lineage_obj.valid_dataset
) )
@@ -78,24 +83,24 @@ class TestLineageObj(TestCase):
def test_property_eval_not_exist(self): def test_property_eval_not_exist(self):
"""Test the function of getting property with no evaluation event.""" """Test the function of getting property with no evaluation event."""
self.assertEqual(self.summary_dir, self.lineage_obj.summary_dir) self.assertEqual(self.summary_dir, self.lineage_obj.summary_dir)
self.assertDictEqual(
self._assert_dict_equal(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['algorithm'], event_data.EVENT_TRAIN_DICT_0['train_lineage']['algorithm'],
self.lineage_obj_no_eval.algorithm self.lineage_obj_no_eval.algorithm
) )
self.assertDictEqual(
self._assert_dict_equal(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['model'], event_data.EVENT_TRAIN_DICT_0['train_lineage']['model'],
self.lineage_obj_no_eval.model self.lineage_obj_no_eval.model
) )
self.assertDictEqual(
self._assert_dict_equal(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['train_dataset'], event_data.EVENT_TRAIN_DICT_0['train_lineage']['train_dataset'],
self.lineage_obj_no_eval.train_dataset self.lineage_obj_no_eval.train_dataset
) )
self.assertDictEqual(
self._assert_dict_equal(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['hyper_parameters'], event_data.EVENT_TRAIN_DICT_0['train_lineage']['hyper_parameters'],
self.lineage_obj_no_eval.hyper_parameters self.lineage_obj_no_eval.hyper_parameters
) )
self.assertDictEqual({}, self.lineage_obj_no_eval.metric)
self.assertDictEqual({}, self.lineage_obj_no_eval.valid_dataset)
self._assert_dict_equal({}, self.lineage_obj_no_eval.metric)
self._assert_dict_equal({}, self.lineage_obj_no_eval.valid_dataset)


def test_get_summary_info(self): def test_get_summary_info(self):
"""Test the function of get_summary_info.""" """Test the function of get_summary_info."""
@@ -106,7 +111,7 @@ class TestLineageObj(TestCase):
'model': event_data.EVENT_TRAIN_DICT_0['train_lineage']['model'] 'model': event_data.EVENT_TRAIN_DICT_0['train_lineage']['model']
} }
result = self.lineage_obj.get_summary_info(filter_keys) result = self.lineage_obj.get_summary_info(filter_keys)
self.assertDictEqual(expected_result, result)
self._assert_dict_equal(expected_result, result)


def test_to_model_lineage_dict(self): def test_to_model_lineage_dict(self):
"""Test the function of to_model_lineage_dict.""" """Test the function of to_model_lineage_dict."""
@@ -120,7 +125,7 @@ class TestLineageObj(TestCase):
expected_result['model_lineage']['dataset_mark'] = None expected_result['model_lineage']['dataset_mark'] = None
expected_result.pop('dataset_graph') expected_result.pop('dataset_graph')
result = self.lineage_obj.to_model_lineage_dict() result = self.lineage_obj.to_model_lineage_dict()
self.assertDictEqual(expected_result, result)
self._assert_dict_equal(expected_result, result)


def test_to_dataset_lineage_dict(self): def test_to_dataset_lineage_dict(self):
"""Test the function of to_dataset_lineage_dict.""" """Test the function of to_dataset_lineage_dict."""


+ 56
- 0
tests/utils/tools.py View File

@@ -81,3 +81,59 @@ def compare_result_with_file(result, expected_file_path):
with open(expected_file_path, 'r') as file: with open(expected_file_path, 'r') as file:
expected_results = json.load(file) expected_results = json.load(file)
assert result == expected_results assert result == expected_results


def deal_float_for_dict(res: dict, expected_res: dict):
"""
Deal float rounded to five decimals in dict.

For example:
res:{
"model_lineages": {
"metric": {"acc": 0.1234561}
}
}
expected_res:
{
"model_lineages": {
"metric": {"acc": 0.1234562}
}
}
After:
res:{
"model_lineages": {
"metric": {"acc": 0.12346}
}
}
expected_res:
{
"model_lineages": {
"metric": {"acc": 0.12346}
}
}

Args:
res (dict): e.g.
{
"model_lineages": {
"metric": {"acc": 0.1234561}
}
}
expected_res (dict):
{
"model_lineages": {
"metric": {"acc": 0.1234562}
}
}


"""
decimal_num = 5
for key in res:
value = res[key]
expected_value = expected_res[key]
if isinstance(value, dict):
deal_float_for_dict(value, expected_value)
elif isinstance(value, float):
res[key] = round(value, decimal_num)
expected_res[key] = round(expected_value, decimal_num)

Loading…
Cancel
Save