Browse Source

fix the float compare because ci env update

pull/176/head
luopengting 5 years ago
parent
commit
618f8c8ccf
4 changed files with 132 additions and 57 deletions
  1. +23
    -23
      tests/ut/lineagemgr/querier/event_data.py
  2. +34
    -20
      tests/ut/lineagemgr/querier/test_querier.py
  3. +19
    -14
      tests/ut/lineagemgr/querier/test_query_model.py
  4. +56
    -0
      tests/utils/tools.py

+ 23
- 23
tests/ut/lineagemgr/querier/event_data.py View File

@@ -22,7 +22,7 @@ EVENT_TRAIN_DICT_0 = {
'train_lineage': {
'hyper_parameters': {
'optimizer': 'ApplyMomentum0',
'learning_rate': 0.10000000149011612,
'learning_rate': 0.11,
'loss_function': '',
'epoch': 1,
'parallel_mode': 'stand_alone0',
@@ -31,7 +31,7 @@ EVENT_TRAIN_DICT_0 = {
},
'algorithm': {
'network': 'TrainOneStepCell0',
'loss': 2.3025848865509033
'loss': 2.3025841
},
'train_dataset': {
'train_dataset_path': '',
@@ -49,7 +49,7 @@ EVENT_TRAIN_DICT_1 = {
'train_lineage': {
'hyper_parameters': {
'optimizer': 'ApplyMomentum1',
'learning_rate': 0.20000000298023224,
'learning_rate': 0.2100001,
'loss_function': 'loss_function1',
'epoch': 1,
'parallel_mode': 'stand_alone1',
@@ -58,7 +58,7 @@ EVENT_TRAIN_DICT_1 = {
},
'algorithm': {
'network': 'TrainOneStepCell1',
'loss': 2.4025847911834717
'loss': 2.4025841
},
'train_dataset': {
'train_dataset_path': '/path/to/train_dataset1',
@@ -76,7 +76,7 @@ EVENT_TRAIN_DICT_2 = {
'train_lineage': {
'hyper_parameters': {
'optimizer': 'ApplyMomentum2',
'learning_rate': 0.30000001192092896,
'learning_rate': 0.3100001,
'loss_function': 'loss_function2',
'epoch': 2,
'parallel_mode': 'stand_alone2',
@@ -85,7 +85,7 @@ EVENT_TRAIN_DICT_2 = {
},
'algorithm': {
'network': 'TrainOneStepCell2',
'loss': 2.502584934234619
'loss': 2.5025841
},
'train_dataset': {
'train_dataset_path': '/path/to/train_dataset2',
@@ -103,7 +103,7 @@ EVENT_TRAIN_DICT_3 = {
'train_lineage': {
'hyper_parameters': {
'optimizer': 'ApplyMomentum3',
'learning_rate': 0.4000000059604645,
'learning_rate': 0.4,
'loss_function': 'loss_function3',
'epoch': 2,
'parallel_mode': 'stand_alone3',
@@ -112,7 +112,7 @@ EVENT_TRAIN_DICT_3 = {
},
'algorithm': {
'network': 'TrainOneStepCell3',
'loss': 2.6025848388671875
'loss': 2.6025841
},
'train_dataset': {
'train_dataset_path': '/path/to/train_dataset3',
@@ -139,7 +139,7 @@ EVENT_TRAIN_DICT_4 = {
},
'algorithm': {
'network': 'TrainOneStepCell4',
'loss': 2.702584981918335
'loss': 2.7025841
},
'train_dataset': {
'train_dataset_path': '/path/to/train_dataset4',
@@ -166,7 +166,7 @@ EVENT_TRAIN_DICT_5 = {
},
'algorithm': {
'network': 'TrainOneStepCell5',
'loss': 2.702584981918335
'loss': 2.7025841
},
'train_dataset': {
'train_dataset_size': 35
@@ -211,33 +211,33 @@ CUSTOMIZED_2 = {
}

METRIC_1 = {
'accuracy': 1.0000002,
'accuracy': 1.2000002,
'mae': 2.00000002,
'mse': 3.00000002
}

METRIC_2 = {
'accuracy': 1.0000003,
'mae': 2.00000003,
'mse': 3.00000003
'accuracy': 1.3000003,
'mae': 2.30000003,
'mse': 3.30000003
}

METRIC_3 = {
'accuracy': 1.0000004,
'mae': 2.00000004,
'mse': 3.00000004
'accuracy': 1.4000004,
'mae': 2.40000004,
'mse': 3.40000004
}

METRIC_4 = {
'accuracy': 1.0000005,
'mae': 2.00000005,
'mse': 3.00000005
'accuracy': 1.5000005,
'mae': 2.50000005,
'mse': 3.50000005
}

METRIC_5 = {
'accuracy': 1.0000006,
'mae': 2.00000006,
'mse': 3.00000006
'accuracy': 1.7000006,
'mae': 2.60000006,
'mse': 3.60000006
}

EVENT_EVAL_DICT_0 = {


+ 34
- 20
tests/ut/lineagemgr/querier/test_querier.py View File

@@ -27,6 +27,7 @@ from mindinsight.lineagemgr.querier.querier import Querier
from mindinsight.lineagemgr.summary.lineage_summary_analyzer import LineageInfo

from . import event_data
from ....utils.tools import deal_float_for_dict


def create_lineage_info(train_event_dict, eval_event_dict, dataset_event_dict):
@@ -266,7 +267,6 @@ class TestQuerier(TestCase):
mock_file_handler = MagicMock()
mock_file_handler.size = 1


args[2].return_value = [{'relative_path': './', 'update_time': 1}]
single_summary_path = '/path/to/summary0'
lineage_objects = LineageOrganizer(summary_base_dir=single_summary_path).super_lineage_objs
@@ -282,17 +282,31 @@ class TestQuerier(TestCase):
lineage_objects = LineageOrganizer(summary_base_dir=summary_base_dir).super_lineage_objs
self.multi_querier = Querier(lineage_objects)

def _deal_float_for_list(self, list1, list2):
index = 0
for _ in list1:
deal_float_for_dict(list1[index], list2[index])
index += 1

def _assert_list_equal(self, list1, list2):
self._deal_float_for_list(list1, list2)
self.assertListEqual(list1, list2)

def _assert_lineages_equal(self, lineages1, lineages2):
self._deal_float_for_list(lineages1['object'], lineages2['object'])
self.assertDictEqual(lineages1, lineages2)

def test_get_summary_lineage_success_1(self):
"""Test the success of get_summary_lineage."""
expected_result = [LINEAGE_INFO_0]
result = self.single_querier.get_summary_lineage()
self.assertListEqual(expected_result, result)
self._assert_list_equal(expected_result, result)

def test_get_summary_lineage_success_2(self):
"""Test the success of get_summary_lineage."""
expected_result = [LINEAGE_INFO_0]
result = self.single_querier.get_summary_lineage()
self.assertListEqual(expected_result, result)
self._assert_list_equal(expected_result, result)

def test_get_summary_lineage_success_3(self):
"""Test the success of get_summary_lineage."""
@@ -306,7 +320,7 @@ class TestQuerier(TestCase):
result = self.single_querier.get_summary_lineage(
filter_keys=['model', 'algorithm']
)
self.assertListEqual(expected_result, result)
self._assert_list_equal(expected_result, result)

def test_get_summary_lineage_success_4(self):
"""Test the success of get_summary_lineage."""
@@ -353,7 +367,7 @@ class TestQuerier(TestCase):
}
]
result = self.multi_querier.get_summary_lineage()
self.assertListEqual(expected_result, result)
self._assert_list_equal(expected_result, result)

def test_get_summary_lineage_success_5(self):
"""Test the success of get_summary_lineage."""
@@ -361,7 +375,7 @@ class TestQuerier(TestCase):
result = self.multi_querier.get_summary_lineage(
summary_dir='/path/to/summary1'
)
self.assertListEqual(expected_result, result)
self._assert_list_equal(expected_result, result)

def test_get_summary_lineage_success_6(self):
"""Test the success of get_summary_lineage."""
@@ -380,7 +394,7 @@ class TestQuerier(TestCase):
result = self.multi_querier.get_summary_lineage(
summary_dir='/path/to/summary0', filter_keys=filter_keys
)
self.assertListEqual(expected_result, result)
self._assert_list_equal(expected_result, result)

def test_get_summary_lineage_fail(self):
"""Test the function of get_summary_lineage with exception."""
@@ -423,7 +437,7 @@ class TestQuerier(TestCase):
'count': 2,
}
result = self.multi_querier.filter_summary_lineage(condition=condition)
self.assertDictEqual(expected_result, result)
self._assert_lineages_equal(expected_result, result)

def test_filter_summary_lineage_success_2(self):
"""Test the success of filter_summary_lineage."""
@@ -448,7 +462,7 @@ class TestQuerier(TestCase):
'count': 2,
}
result = self.multi_querier.filter_summary_lineage(condition=condition)
self.assertDictEqual(expected_result, result)
self._assert_lineages_equal(expected_result, result)

def test_filter_summary_lineage_success_3(self):
"""Test the success of filter_summary_lineage."""
@@ -465,7 +479,7 @@ class TestQuerier(TestCase):
'count': 7,
}
result = self.multi_querier.filter_summary_lineage(condition=condition)
self.assertDictEqual(expected_result, result)
self._assert_lineages_equal(expected_result, result)

def test_filter_summary_lineage_success_4(self):
"""Test the success of filter_summary_lineage."""
@@ -483,7 +497,7 @@ class TestQuerier(TestCase):
'count': 7,
}
result = self.multi_querier.filter_summary_lineage()
self.assertDictEqual(expected_result, result)
self._assert_lineages_equal(expected_result, result)

def test_filter_summary_lineage_success_5(self):
"""Test the success of filter_summary_lineage."""
@@ -498,7 +512,7 @@ class TestQuerier(TestCase):
'count': 1,
}
result = self.multi_querier.filter_summary_lineage(condition=condition)
self.assertDictEqual(expected_result, result)
self._assert_lineages_equal(expected_result, result)

def test_filter_summary_lineage_success_6(self):
"""Test the success of filter_summary_lineage."""
@@ -520,7 +534,7 @@ class TestQuerier(TestCase):
'count': 7,
}
result = self.multi_querier.filter_summary_lineage(condition=condition)
self.assertDictEqual(expected_result, result)
self._assert_lineages_equal(expected_result, result)

def test_filter_summary_lineage_success_7(self):
"""Test the success of filter_summary_lineage."""
@@ -542,14 +556,14 @@ class TestQuerier(TestCase):
'count': 7,
}
result = self.multi_querier.filter_summary_lineage(condition=condition)
self.assertDictEqual(expected_result, result)
self._assert_lineages_equal(expected_result, result)

def test_filter_summary_lineage_success_8(self):
"""Test the success of filter_summary_lineage."""
condition = {
'metric/accuracy': {
'lt': 1.0000006,
'gt': 1.0000004
'lt': 1.6000006,
'gt': 1.4000004
}
}
expected_result = {
@@ -558,7 +572,7 @@ class TestQuerier(TestCase):
'count': 1,
}
result = self.multi_querier.filter_summary_lineage(condition=condition)
self.assertDictEqual(expected_result, result)
self._assert_lineages_equal(expected_result, result)

def test_filter_summary_lineage_success_9(self):
"""Test the success of filter_summary_lineage."""
@@ -572,14 +586,14 @@ class TestQuerier(TestCase):
'count': 7,
}
result = self.multi_querier.filter_summary_lineage(condition=condition)
self.assertDictEqual(expected_result, result)
self._assert_lineages_equal(expected_result, result)

def test_filter_summary_lineage_fail(self):
"""Test the function of filter_summary_lineage with exception."""
condition = {
'xxx': {
'lt': 1.0000006,
'gt': 1.0000004
'lt': 1.6000006,
'gt': 1.4000004
}
}
self.assertRaises(


+ 19
- 14
tests/ut/lineagemgr/querier/test_query_model.py View File

@@ -21,6 +21,7 @@ from mindinsight.lineagemgr.querier.query_model import LineageObj

from . import event_data
from .test_querier import create_filtration_result, create_lineage_info
from ....utils.tools import deal_float_for_dict


class TestLineageObj(TestCase):
@@ -50,27 +51,31 @@ class TestLineageObj(TestCase):
evaluation_lineage=lineage_info.eval_lineage
)

def _assert_dict_equal(self, dict1, dict2):
deal_float_for_dict(dict1, dict2)
self.assertDictEqual(dict1, dict2)

def test_property(self):
"""Test the function of getting property."""
self.assertEqual(self.summary_dir, self.lineage_obj.summary_dir)
self.assertDictEqual(
self._assert_dict_equal(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['algorithm'],
self.lineage_obj.algorithm
)
self.assertDictEqual(
self._assert_dict_equal(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['model'],
self.lineage_obj.model
)
self.assertDictEqual(
self._assert_dict_equal(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['train_dataset'],
self.lineage_obj.train_dataset
)
self.assertDictEqual(
self._assert_dict_equal(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['hyper_parameters'],
self.lineage_obj.hyper_parameters
)
self.assertDictEqual(event_data.METRIC_0, self.lineage_obj.metric)
self.assertDictEqual(
self._assert_dict_equal(event_data.METRIC_0, self.lineage_obj.metric)
self._assert_dict_equal(
event_data.EVENT_EVAL_DICT_0['evaluation_lineage']['valid_dataset'],
self.lineage_obj.valid_dataset
)
@@ -78,24 +83,24 @@ class TestLineageObj(TestCase):
def test_property_eval_not_exist(self):
"""Test the function of getting property with no evaluation event."""
self.assertEqual(self.summary_dir, self.lineage_obj.summary_dir)
self.assertDictEqual(
self._assert_dict_equal(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['algorithm'],
self.lineage_obj_no_eval.algorithm
)
self.assertDictEqual(
self._assert_dict_equal(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['model'],
self.lineage_obj_no_eval.model
)
self.assertDictEqual(
self._assert_dict_equal(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['train_dataset'],
self.lineage_obj_no_eval.train_dataset
)
self.assertDictEqual(
self._assert_dict_equal(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['hyper_parameters'],
self.lineage_obj_no_eval.hyper_parameters
)
self.assertDictEqual({}, self.lineage_obj_no_eval.metric)
self.assertDictEqual({}, self.lineage_obj_no_eval.valid_dataset)
self._assert_dict_equal({}, self.lineage_obj_no_eval.metric)
self._assert_dict_equal({}, self.lineage_obj_no_eval.valid_dataset)

def test_get_summary_info(self):
"""Test the function of get_summary_info."""
@@ -106,7 +111,7 @@ class TestLineageObj(TestCase):
'model': event_data.EVENT_TRAIN_DICT_0['train_lineage']['model']
}
result = self.lineage_obj.get_summary_info(filter_keys)
self.assertDictEqual(expected_result, result)
self._assert_dict_equal(expected_result, result)

def test_to_model_lineage_dict(self):
"""Test the function of to_model_lineage_dict."""
@@ -120,7 +125,7 @@ class TestLineageObj(TestCase):
expected_result['model_lineage']['dataset_mark'] = None
expected_result.pop('dataset_graph')
result = self.lineage_obj.to_model_lineage_dict()
self.assertDictEqual(expected_result, result)
self._assert_dict_equal(expected_result, result)

def test_to_dataset_lineage_dict(self):
"""Test the function of to_dataset_lineage_dict."""


+ 56
- 0
tests/utils/tools.py View File

@@ -81,3 +81,59 @@ def compare_result_with_file(result, expected_file_path):
with open(expected_file_path, 'r') as file:
expected_results = json.load(file)
assert result == expected_results


def deal_float_for_dict(res: dict, expected_res: dict):
"""
Deal float rounded to five decimals in dict.

For example:
res:{
"model_lineages": {
"metric": {"acc": 0.1234561}
}
}
expected_res:
{
"model_lineages": {
"metric": {"acc": 0.1234562}
}
}
After:
res:{
"model_lineages": {
"metric": {"acc": 0.12346}
}
}
expected_res:
{
"model_lineages": {
"metric": {"acc": 0.12346}
}
}

Args:
res (dict): e.g.
{
"model_lineages": {
"metric": {"acc": 0.1234561}
}
}
expected_res (dict):
{
"model_lineages": {
"metric": {"acc": 0.1234562}
}
}


"""
decimal_num = 5
for key in res:
value = res[key]
expected_value = expected_res[key]
if isinstance(value, dict):
deal_float_for_dict(value, expected_value)
elif isinstance(value, float):
res[key] = round(value, decimal_num)
expected_res[key] = round(expected_value, decimal_num)

Loading…
Cancel
Save