Browse Source

!761 enhance float cmp in tests.lineagemgr in r0.3, update securec repository link

Merge pull request !761 from luopengting/fix_r0.3
r0.3
mindspore-ci-bot Gitee 5 years ago
parent
commit
488cfaefa2
8 changed files with 113 additions and 118 deletions
  1. +1
    -1
      .gitmodules
  2. +22
    -20
      tests/st/func/lineagemgr/api/test_model_api.py
  3. +4
    -5
      tests/st/func/lineagemgr/cache/test_lineage_cache.py
  4. +9
    -29
      tests/st/func/lineagemgr/collection/model/test_model_lineage.py
  5. +2
    -1
      tests/ut/lineagemgr/querier/event_data.py
  6. +16
    -30
      tests/ut/lineagemgr/querier/test_querier.py
  7. +37
    -28
      tests/ut/lineagemgr/querier/test_query_model.py
  8. +22
    -4
      tests/utils/tools.py

+ 1
- 1
.gitmodules View File

@@ -1,3 +1,3 @@
[submodule "third_party/securec"] [submodule "third_party/securec"]
path = third_party/securec path = third_party/securec
url = https://gitee.com/openeuler/bounds_checking_function.git
url = https://gitee.com/openeuler/libboundscheck.git

+ 22
- 20
tests/st/func/lineagemgr/api/test_model_api.py View File

@@ -31,6 +31,7 @@ from mindinsight.lineagemgr.common.exceptions.exceptions import (LineageFileNotF
LineageSearchConditionParamError) LineageSearchConditionParamError)
from ..conftest import BASE_SUMMARY_DIR, DATASET_GRAPH, SUMMARY_DIR, SUMMARY_DIR_2 from ..conftest import BASE_SUMMARY_DIR, DATASET_GRAPH, SUMMARY_DIR, SUMMARY_DIR_2
from .....ut.lineagemgr.querier import event_data from .....ut.lineagemgr.querier import event_data
from .....utils.tools import assert_equal_lineages


LINEAGE_INFO_RUN1 = { LINEAGE_INFO_RUN1 = {
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'), 'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'),
@@ -39,7 +40,7 @@ LINEAGE_INFO_RUN1 = {
}, },
'hyper_parameters': { 'hyper_parameters': {
'optimizer': 'Momentum', 'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
'learning_rate': 0.12,
'loss_function': 'SoftmaxCrossEntropyWithLogits', 'loss_function': 'SoftmaxCrossEntropyWithLogits',
'epoch': 14, 'epoch': 14,
'parallel_mode': 'stand_alone', 'parallel_mode': 'stand_alone',
@@ -73,11 +74,11 @@ LINEAGE_FILTRATION_EXCEPT_RUN = {
'user_defined': {}, 'user_defined': {},
'network': 'ResNet', 'network': 'ResNet',
'optimizer': 'Momentum', 'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
'learning_rate': 0.12,
'epoch': 10, 'epoch': 10,
'batch_size': 32, 'batch_size': 32,
'device_num': 2, 'device_num': 2,
'loss': 0.029999999329447746,
'loss': 0.03,
'model_size': 64, 'model_size': 64,
'metric': {}, 'metric': {},
'dataset_mark': 2 'dataset_mark': 2
@@ -92,10 +93,14 @@ LINEAGE_FILTRATION_RUN1 = {
'train_dataset_count': 1024, 'train_dataset_count': 1024,
'test_dataset_path': None, 'test_dataset_path': None,
'test_dataset_count': 1024, 'test_dataset_count': 1024,
'user_defined': {'info': 'info1', 'version': 'v1'},
'user_defined': {
'info': 'info1',
'version': 'v1',
'eval_version': 'version2'
},
'network': 'ResNet', 'network': 'ResNet',
'optimizer': 'Momentum', 'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
'learning_rate': 0.12,
'epoch': 14, 'epoch': 14,
'batch_size': 32, 'batch_size': 32,
'device_num': 2, 'device_num': 2,
@@ -119,14 +124,14 @@ LINEAGE_FILTRATION_RUN2 = {
'user_defined': {}, 'user_defined': {},
'network': "ResNet", 'network': "ResNet",
'optimizer': "Momentum", 'optimizer': "Momentum",
'learning_rate': 0.11999999731779099,
'learning_rate': 0.12,
'epoch': 10, 'epoch': 10,
'batch_size': 32, 'batch_size': 32,
'device_num': 2, 'device_num': 2,
'loss': 0.029999999329447746,
'loss': 0.03,
'model_size': 10, 'model_size': 10,
'metric': { 'metric': {
'accuracy': 2.7800000000000002
'accuracy': 2.78
}, },
'dataset_mark': 3 'dataset_mark': 3
}, },
@@ -173,7 +178,7 @@ class TestModelApi(TestCase):
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'), 'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'),
'hyper_parameters': { 'hyper_parameters': {
'optimizer': 'Momentum', 'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
'learning_rate': 0.12,
'loss_function': 'SoftmaxCrossEntropyWithLogits', 'loss_function': 'SoftmaxCrossEntropyWithLogits',
'epoch': 14, 'epoch': 14,
'parallel_mode': 'stand_alone', 'parallel_mode': 'stand_alone',
@@ -190,9 +195,9 @@ class TestModelApi(TestCase):
'network': 'ResNet' 'network': 'ResNet'
} }
} }
assert expect_total_res == total_res
assert expect_partial_res1 == partial_res1
assert expect_partial_res2 == partial_res2
assert_equal_lineages(expect_total_res, total_res, self.assertDictEqual)
assert_equal_lineages(expect_partial_res1, partial_res1, self.assertDictEqual)
assert_equal_lineages(expect_partial_res2, partial_res2, self.assertDictEqual)


# the lineage summary file is empty # the lineage summary file is empty
result = get_summary_lineage(self.dir_with_empty_lineage) result = get_summary_lineage(self.dir_with_empty_lineage)
@@ -345,7 +350,7 @@ class TestModelApi(TestCase):
expect_objects = expect_result.get('object') expect_objects = expect_result.get('object')
for idx, res_object in enumerate(res.get('object')): for idx, res_object in enumerate(res.get('object')):
expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark') expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark')
assert expect_result == res
assert_equal_lineages(expect_result, res, self.assertDictEqual)


expect_result = { expect_result = {
'customized': {}, 'customized': {},
@@ -356,7 +361,7 @@ class TestModelApi(TestCase):
expect_objects = expect_result.get('object') expect_objects = expect_result.get('object')
for idx, res_object in enumerate(res.get('object')): for idx, res_object in enumerate(res.get('object')):
expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark') expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark')
assert expect_result == res
assert_equal_lineages(expect_result, res, self.assertDictEqual)


@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_arm_ascend_training @pytest.mark.platform_arm_ascend_training
@@ -394,7 +399,7 @@ class TestModelApi(TestCase):
expect_objects = expect_result.get('object') expect_objects = expect_result.get('object')
for idx, res_object in enumerate(partial_res.get('object')): for idx, res_object in enumerate(partial_res.get('object')):
expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark') expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark')
assert expect_result == partial_res
assert_equal_lineages(expect_result, partial_res, self.assertDictEqual)


@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_arm_ascend_training @pytest.mark.platform_arm_ascend_training
@@ -432,7 +437,7 @@ class TestModelApi(TestCase):
expect_objects = expect_result.get('object') expect_objects = expect_result.get('object')
for idx, res_object in enumerate(partial_res.get('object')): for idx, res_object in enumerate(partial_res.get('object')):
expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark') expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark')
assert expect_result == partial_res
assert_equal_lineages(expect_result, partial_res, self.assertDictEqual)


@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_arm_ascend_training @pytest.mark.platform_arm_ascend_training
@@ -461,7 +466,7 @@ class TestModelApi(TestCase):
expect_objects = expect_result.get('object') expect_objects = expect_result.get('object')
for idx, res_object in enumerate(partial_res1.get('object')): for idx, res_object in enumerate(partial_res1.get('object')):
expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark') expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark')
assert expect_result == partial_res1
assert_equal_lineages(expect_result, partial_res1, self.assertDictEqual)


search_condition2 = { search_condition2 = {
'batch_size': { 'batch_size': {
@@ -477,9 +482,6 @@ class TestModelApi(TestCase):
'count': 0 'count': 0
} }
partial_res2 = filter_summary_lineage(BASE_SUMMARY_DIR, search_condition2) partial_res2 = filter_summary_lineage(BASE_SUMMARY_DIR, search_condition2)
expect_objects = expect_result.get('object')
for idx, res_object in enumerate(partial_res2.get('object')):
expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark')
assert expect_result == partial_res2 assert expect_result == partial_res2


@pytest.mark.level0 @pytest.mark.level0


+ 4
- 5
tests/st/func/lineagemgr/cache/test_lineage_cache.py View File

@@ -33,7 +33,7 @@ from ..api.test_model_api import LINEAGE_INFO_RUN1, LINEAGE_FILTRATION_EXCEPT_RU
LINEAGE_FILTRATION_RUN1, LINEAGE_FILTRATION_RUN2 LINEAGE_FILTRATION_RUN1, LINEAGE_FILTRATION_RUN2
from ..conftest import BASE_SUMMARY_DIR from ..conftest import BASE_SUMMARY_DIR
from .....ut.lineagemgr.querier import event_data from .....ut.lineagemgr.querier import event_data
from .....utils.tools import check_loading_done
from .....utils.tools import check_loading_done, assert_equal_lineages




@pytest.mark.usefixtures("create_summary_dir") @pytest.mark.usefixtures("create_summary_dir")
@@ -58,8 +58,7 @@ class TestModelApi(TestCase):
"""Test the interface of get_summary_lineage.""" """Test the interface of get_summary_lineage."""
total_res = general_get_summary_lineage(data_manager=self._data_manger, summary_dir="./run1") total_res = general_get_summary_lineage(data_manager=self._data_manger, summary_dir="./run1")
expect_total_res = LINEAGE_INFO_RUN1 expect_total_res = LINEAGE_INFO_RUN1

assert expect_total_res == total_res
assert_equal_lineages(expect_total_res, total_res, self.assertDictEqual)


@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_arm_ascend_training @pytest.mark.platform_arm_ascend_training
@@ -86,7 +85,7 @@ class TestModelApi(TestCase):
expect_objects = expect_result.get('object') expect_objects = expect_result.get('object')
for idx, res_object in enumerate(res.get('object')): for idx, res_object in enumerate(res.get('object')):
expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark') expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark')
assert expect_result == res
assert_equal_lineages(expect_result, res, self.assertDictEqual)


expect_result = { expect_result = {
'customized': {}, 'customized': {},
@@ -100,4 +99,4 @@ class TestModelApi(TestCase):
} }
} }
res = general_filter_summary_lineage(data_manager=self._data_manger, search_condition=search_condition) res = general_filter_summary_lineage(data_manager=self._data_manger, search_condition=search_condition)
assert expect_result == res
assert_equal_lineages(expect_result, res, self.assertDictEqual)

+ 9
- 29
tests/st/func/lineagemgr/collection/model/test_model_lineage.py View File

@@ -73,6 +73,10 @@ class TestModelLineage(TestCase):
TrainLineage(cls.summary_record) TrainLineage(cls.summary_record)
] ]
cls.run_context['list_callback'] = _ListCallback(callback) cls.run_context['list_callback'] = _ListCallback(callback)
cls.user_defined_info = {
"info": "info1",
"version": "v1"
}


@pytest.mark.scene_train(2) @pytest.mark.scene_train(2)
@pytest.mark.level0 @pytest.mark.level0
@@ -83,7 +87,7 @@ class TestModelLineage(TestCase):
@pytest.mark.env_single @pytest.mark.env_single
def test_train_begin(self): def test_train_begin(self):
"""Test the begin function in TrainLineage.""" """Test the begin function in TrainLineage."""
train_callback = TrainLineage(self.summary_record, True)
train_callback = TrainLineage(self.summary_record, True, self.user_defined_info)
train_callback.begin(RunContext(self.run_context)) train_callback.begin(RunContext(self.run_context))
assert train_callback.initial_learning_rate == 0.12 assert train_callback.initial_learning_rate == 0.12
lineage_log_path = train_callback.lineage_summary.lineage_log_path lineage_log_path = train_callback.lineage_summary.lineage_log_path
@@ -98,30 +102,6 @@ class TestModelLineage(TestCase):
@pytest.mark.env_single @pytest.mark.env_single
def test_train_begin_with_user_defined_info(self): def test_train_begin_with_user_defined_info(self):
"""Test TrainLineage with nested user defined info.""" """Test TrainLineage with nested user defined info."""
user_defined_info = {"info": {"version": "v1"}}
train_callback = TrainLineage(
self.summary_record,
False,
user_defined_info
)
train_callback.begin(RunContext(self.run_context))
assert train_callback.initial_learning_rate == 0.12
lineage_log_path = train_callback.lineage_summary.lineage_log_path
assert os.path.isfile(lineage_log_path) is True

@pytest.mark.scene_train(2)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
def test_train_begin_with_user_defined_key_in_lineage(self):
"""Test TrainLineage with nested user defined info."""
expected_res = {
"info": "info1",
"version": "v1"
}
user_defined_info = { user_defined_info = {
"info": "info1", "info": "info1",
"version": "v1", "version": "v1",
@@ -137,7 +117,7 @@ class TestModelLineage(TestCase):
lineage_log_path = train_callback.lineage_summary.lineage_log_path lineage_log_path = train_callback.lineage_summary.lineage_log_path
assert os.path.isfile(lineage_log_path) is True assert os.path.isfile(lineage_log_path) is True
res = filter_summary_lineage(os.path.dirname(lineage_log_path)) res = filter_summary_lineage(os.path.dirname(lineage_log_path))
assert expected_res == res['object'][0]['model_lineage']['user_defined']
assert self.user_defined_info == res['object'][0]['model_lineage']['user_defined']


@pytest.mark.scene_train(2) @pytest.mark.scene_train(2)
@pytest.mark.level0 @pytest.mark.level0
@@ -168,7 +148,7 @@ class TestModelLineage(TestCase):
def test_training_end(self, *args): def test_training_end(self, *args):
"""Test the end function in TrainLineage.""" """Test the end function in TrainLineage."""
args[0].return_value = 64 args[0].return_value = 64
train_callback = TrainLineage(self.summary_record, True)
train_callback = TrainLineage(self.summary_record, True, self.user_defined_info)
train_callback.initial_learning_rate = 0.12 train_callback.initial_learning_rate = 0.12
train_callback.end(RunContext(self.run_context)) train_callback.end(RunContext(self.run_context))
res = get_summary_lineage(SUMMARY_DIR) res = get_summary_lineage(SUMMARY_DIR)
@@ -188,7 +168,7 @@ class TestModelLineage(TestCase):
@pytest.mark.env_single @pytest.mark.env_single
def test_eval_end(self): def test_eval_end(self):
"""Test the end function in EvalLineage.""" """Test the end function in EvalLineage."""
eval_callback = EvalLineage(self.summary_record, True)
eval_callback = EvalLineage(self.summary_record, True, {'eval_version': 'version2'})
eval_run_context = self.run_context eval_run_context = self.run_context
eval_run_context['metrics'] = {'accuracy': 0.78} eval_run_context['metrics'] = {'accuracy': 0.78}
eval_run_context['valid_dataset'] = self.run_context['train_dataset'] eval_run_context['valid_dataset'] = self.run_context['train_dataset']
@@ -361,7 +341,7 @@ class TestModelLineage(TestCase):
def test_train_with_customized_network(self, *args): def test_train_with_customized_network(self, *args):
"""Test train with customized network.""" """Test train with customized network."""
args[0].return_value = 64 args[0].return_value = 64
train_callback = TrainLineage(self.summary_record, True)
train_callback = TrainLineage(self.summary_record, True, self.user_defined_info)
run_context_customized = self.run_context run_context_customized = self.run_context
del run_context_customized['optimizer'] del run_context_customized['optimizer']
del run_context_customized['net_outputs'] del run_context_customized['net_outputs']


+ 2
- 1
tests/ut/lineagemgr/querier/event_data.py View File

@@ -195,7 +195,8 @@ CUSTOMIZED__0 = {
CUSTOMIZED__1 = { CUSTOMIZED__1 = {
**CUSTOMIZED__0, **CUSTOMIZED__0,
'user_defined/info': {'label': 'user_defined/info', 'required': False, 'type': 'str'}, 'user_defined/info': {'label': 'user_defined/info', 'required': False, 'type': 'str'},
'user_defined/version': {'label': 'user_defined/version', 'required': False, 'type': 'str'}
'user_defined/version': {'label': 'user_defined/version', 'required': False, 'type': 'str'},
'user_defined/eval_version': {'label': 'user_defined/eval_version', 'required': False, 'type': 'str'}
} }


CUSTOMIZED_0 = { CUSTOMIZED_0 = {


+ 16
- 30
tests/ut/lineagemgr/querier/test_querier.py View File

@@ -27,7 +27,7 @@ from mindinsight.lineagemgr.querier.querier import Querier
from mindinsight.lineagemgr.summary.lineage_summary_analyzer import LineageInfo from mindinsight.lineagemgr.summary.lineage_summary_analyzer import LineageInfo


from . import event_data from . import event_data
from ....utils.tools import deal_float_for_dict
from ....utils.tools import assert_equal_lineages




def create_lineage_info(train_event_dict, eval_event_dict, dataset_event_dict): def create_lineage_info(train_event_dict, eval_event_dict, dataset_event_dict):
@@ -282,31 +282,17 @@ class TestQuerier(TestCase):
lineage_objects = LineageOrganizer(summary_base_dir=summary_base_dir).super_lineage_objs lineage_objects = LineageOrganizer(summary_base_dir=summary_base_dir).super_lineage_objs
self.multi_querier = Querier(lineage_objects) self.multi_querier = Querier(lineage_objects)


def _deal_float_for_list(self, list1, list2):
index = 0
for _ in list1:
deal_float_for_dict(list1[index], list2[index])
index += 1

def _assert_list_equal(self, list1, list2):
self._deal_float_for_list(list1, list2)
self.assertListEqual(list1, list2)

def _assert_lineages_equal(self, lineages1, lineages2):
self._deal_float_for_list(lineages1['object'], lineages2['object'])
self.assertDictEqual(lineages1, lineages2)

def test_get_summary_lineage_success_1(self): def test_get_summary_lineage_success_1(self):
"""Test the success of get_summary_lineage.""" """Test the success of get_summary_lineage."""
expected_result = [LINEAGE_INFO_0] expected_result = [LINEAGE_INFO_0]
result = self.single_querier.get_summary_lineage() result = self.single_querier.get_summary_lineage()
self._assert_list_equal(expected_result, result)
assert_equal_lineages(expected_result, result, self.assertListEqual)


def test_get_summary_lineage_success_2(self): def test_get_summary_lineage_success_2(self):
"""Test the success of get_summary_lineage.""" """Test the success of get_summary_lineage."""
expected_result = [LINEAGE_INFO_0] expected_result = [LINEAGE_INFO_0]
result = self.single_querier.get_summary_lineage() result = self.single_querier.get_summary_lineage()
self._assert_list_equal(expected_result, result)
assert_equal_lineages(expected_result, result, self.assertListEqual)


def test_get_summary_lineage_success_3(self): def test_get_summary_lineage_success_3(self):
"""Test the success of get_summary_lineage.""" """Test the success of get_summary_lineage."""
@@ -320,7 +306,7 @@ class TestQuerier(TestCase):
result = self.single_querier.get_summary_lineage( result = self.single_querier.get_summary_lineage(
filter_keys=['model', 'algorithm'] filter_keys=['model', 'algorithm']
) )
self._assert_list_equal(expected_result, result)
assert_equal_lineages(expected_result, result, self.assertListEqual)


def test_get_summary_lineage_success_4(self): def test_get_summary_lineage_success_4(self):
"""Test the success of get_summary_lineage.""" """Test the success of get_summary_lineage."""
@@ -367,7 +353,7 @@ class TestQuerier(TestCase):
} }
] ]
result = self.multi_querier.get_summary_lineage() result = self.multi_querier.get_summary_lineage()
self._assert_list_equal(expected_result, result)
assert_equal_lineages(expected_result, result, self.assertListEqual)


def test_get_summary_lineage_success_5(self): def test_get_summary_lineage_success_5(self):
"""Test the success of get_summary_lineage.""" """Test the success of get_summary_lineage."""
@@ -375,7 +361,7 @@ class TestQuerier(TestCase):
result = self.multi_querier.get_summary_lineage( result = self.multi_querier.get_summary_lineage(
summary_dir='/path/to/summary1' summary_dir='/path/to/summary1'
) )
self._assert_list_equal(expected_result, result)
assert_equal_lineages(expected_result, result, self.assertListEqual)


def test_get_summary_lineage_success_6(self): def test_get_summary_lineage_success_6(self):
"""Test the success of get_summary_lineage.""" """Test the success of get_summary_lineage."""
@@ -394,7 +380,7 @@ class TestQuerier(TestCase):
result = self.multi_querier.get_summary_lineage( result = self.multi_querier.get_summary_lineage(
summary_dir='/path/to/summary0', filter_keys=filter_keys summary_dir='/path/to/summary0', filter_keys=filter_keys
) )
self._assert_list_equal(expected_result, result)
assert_equal_lineages(expected_result, result, self.assertListEqual)


def test_get_summary_lineage_fail(self): def test_get_summary_lineage_fail(self):
"""Test the function of get_summary_lineage with exception.""" """Test the function of get_summary_lineage with exception."""
@@ -437,7 +423,7 @@ class TestQuerier(TestCase):
'count': 2, 'count': 2,
} }
result = self.multi_querier.filter_summary_lineage(condition=condition) result = self.multi_querier.filter_summary_lineage(condition=condition)
self._assert_lineages_equal(expected_result, result)
assert_equal_lineages(expected_result, result, self.assertDictEqual)


def test_filter_summary_lineage_success_2(self): def test_filter_summary_lineage_success_2(self):
"""Test the success of filter_summary_lineage.""" """Test the success of filter_summary_lineage."""
@@ -462,7 +448,7 @@ class TestQuerier(TestCase):
'count': 2, 'count': 2,
} }
result = self.multi_querier.filter_summary_lineage(condition=condition) result = self.multi_querier.filter_summary_lineage(condition=condition)
self._assert_lineages_equal(expected_result, result)
assert_equal_lineages(expected_result, result, self.assertDictEqual)


def test_filter_summary_lineage_success_3(self): def test_filter_summary_lineage_success_3(self):
"""Test the success of filter_summary_lineage.""" """Test the success of filter_summary_lineage."""
@@ -479,7 +465,7 @@ class TestQuerier(TestCase):
'count': 7, 'count': 7,
} }
result = self.multi_querier.filter_summary_lineage(condition=condition) result = self.multi_querier.filter_summary_lineage(condition=condition)
self._assert_lineages_equal(expected_result, result)
assert_equal_lineages(expected_result, result, self.assertDictEqual)


def test_filter_summary_lineage_success_4(self): def test_filter_summary_lineage_success_4(self):
"""Test the success of filter_summary_lineage.""" """Test the success of filter_summary_lineage."""
@@ -497,7 +483,7 @@ class TestQuerier(TestCase):
'count': 7, 'count': 7,
} }
result = self.multi_querier.filter_summary_lineage() result = self.multi_querier.filter_summary_lineage()
self._assert_lineages_equal(expected_result, result)
assert_equal_lineages(expected_result, result, self.assertDictEqual)


def test_filter_summary_lineage_success_5(self): def test_filter_summary_lineage_success_5(self):
"""Test the success of filter_summary_lineage.""" """Test the success of filter_summary_lineage."""
@@ -512,7 +498,7 @@ class TestQuerier(TestCase):
'count': 1, 'count': 1,
} }
result = self.multi_querier.filter_summary_lineage(condition=condition) result = self.multi_querier.filter_summary_lineage(condition=condition)
self._assert_lineages_equal(expected_result, result)
assert_equal_lineages(expected_result, result, self.assertDictEqual)


def test_filter_summary_lineage_success_6(self): def test_filter_summary_lineage_success_6(self):
"""Test the success of filter_summary_lineage.""" """Test the success of filter_summary_lineage."""
@@ -534,7 +520,7 @@ class TestQuerier(TestCase):
'count': 7, 'count': 7,
} }
result = self.multi_querier.filter_summary_lineage(condition=condition) result = self.multi_querier.filter_summary_lineage(condition=condition)
self._assert_lineages_equal(expected_result, result)
assert_equal_lineages(expected_result, result, self.assertDictEqual)


def test_filter_summary_lineage_success_7(self): def test_filter_summary_lineage_success_7(self):
"""Test the success of filter_summary_lineage.""" """Test the success of filter_summary_lineage."""
@@ -556,7 +542,7 @@ class TestQuerier(TestCase):
'count': 7, 'count': 7,
} }
result = self.multi_querier.filter_summary_lineage(condition=condition) result = self.multi_querier.filter_summary_lineage(condition=condition)
self._assert_lineages_equal(expected_result, result)
assert_equal_lineages(expected_result, result, self.assertDictEqual)


def test_filter_summary_lineage_success_8(self): def test_filter_summary_lineage_success_8(self):
"""Test the success of filter_summary_lineage.""" """Test the success of filter_summary_lineage."""
@@ -572,7 +558,7 @@ class TestQuerier(TestCase):
'count': 1, 'count': 1,
} }
result = self.multi_querier.filter_summary_lineage(condition=condition) result = self.multi_querier.filter_summary_lineage(condition=condition)
self._assert_lineages_equal(expected_result, result)
assert_equal_lineages(expected_result, result, self.assertDictEqual)


def test_filter_summary_lineage_success_9(self): def test_filter_summary_lineage_success_9(self):
"""Test the success of filter_summary_lineage.""" """Test the success of filter_summary_lineage."""
@@ -586,7 +572,7 @@ class TestQuerier(TestCase):
'count': 7, 'count': 7,
} }
result = self.multi_querier.filter_summary_lineage(condition=condition) result = self.multi_querier.filter_summary_lineage(condition=condition)
self._assert_lineages_equal(expected_result, result)
assert_equal_lineages(expected_result, result, self.assertDictEqual)


def test_filter_summary_lineage_fail(self): def test_filter_summary_lineage_fail(self):
"""Test the function of filter_summary_lineage with exception.""" """Test the function of filter_summary_lineage with exception."""


+ 37
- 28
tests/ut/lineagemgr/querier/test_query_model.py View File

@@ -21,7 +21,7 @@ from mindinsight.lineagemgr.querier.query_model import LineageObj


from . import event_data from . import event_data
from .test_querier import create_filtration_result, create_lineage_info from .test_querier import create_filtration_result, create_lineage_info
from ....utils.tools import deal_float_for_dict
from ....utils.tools import assert_equal_lineages




class TestLineageObj(TestCase): class TestLineageObj(TestCase):
@@ -51,56 +51,65 @@ class TestLineageObj(TestCase):
evaluation_lineage=lineage_info.eval_lineage evaluation_lineage=lineage_info.eval_lineage
) )


def _assert_dict_equal(self, dict1, dict2):
deal_float_for_dict(dict1, dict2)
self.assertDictEqual(dict1, dict2)

def test_property(self): def test_property(self):
"""Test the function of getting property.""" """Test the function of getting property."""
self.assertEqual(self.summary_dir, self.lineage_obj.summary_dir) self.assertEqual(self.summary_dir, self.lineage_obj.summary_dir)
self._assert_dict_equal(
assert_equal_lineages(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['algorithm'], event_data.EVENT_TRAIN_DICT_0['train_lineage']['algorithm'],
self.lineage_obj.algorithm
self.lineage_obj.algorithm,
self.assertDictEqual
) )
self._assert_dict_equal(
assert_equal_lineages(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['model'], event_data.EVENT_TRAIN_DICT_0['train_lineage']['model'],
self.lineage_obj.model
self.lineage_obj.model,
self.assertDictEqual
) )
self._assert_dict_equal(
assert_equal_lineages(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['train_dataset'], event_data.EVENT_TRAIN_DICT_0['train_lineage']['train_dataset'],
self.lineage_obj.train_dataset
self.lineage_obj.train_dataset,
self.assertDictEqual
) )
self._assert_dict_equal(
assert_equal_lineages(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['hyper_parameters'], event_data.EVENT_TRAIN_DICT_0['train_lineage']['hyper_parameters'],
self.lineage_obj.hyper_parameters
self.lineage_obj.hyper_parameters,
self.assertDictEqual
)
assert_equal_lineages(
event_data.METRIC_0,
self.lineage_obj.metric,
self.assertDictEqual
) )
self._assert_dict_equal(event_data.METRIC_0, self.lineage_obj.metric)
self._assert_dict_equal(
assert_equal_lineages(
event_data.EVENT_EVAL_DICT_0['evaluation_lineage']['valid_dataset'], event_data.EVENT_EVAL_DICT_0['evaluation_lineage']['valid_dataset'],
self.lineage_obj.valid_dataset
self.lineage_obj.valid_dataset,
self.assertDictEqual
) )


def test_property_eval_not_exist(self): def test_property_eval_not_exist(self):
"""Test the function of getting property with no evaluation event.""" """Test the function of getting property with no evaluation event."""
self.assertEqual(self.summary_dir, self.lineage_obj.summary_dir) self.assertEqual(self.summary_dir, self.lineage_obj.summary_dir)
self._assert_dict_equal(
assert_equal_lineages(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['algorithm'], event_data.EVENT_TRAIN_DICT_0['train_lineage']['algorithm'],
self.lineage_obj_no_eval.algorithm
self.lineage_obj_no_eval.algorithm,
self.assertDictEqual
) )
self._assert_dict_equal(
assert_equal_lineages(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['model'], event_data.EVENT_TRAIN_DICT_0['train_lineage']['model'],
self.lineage_obj_no_eval.model
self.lineage_obj_no_eval.model,
self.assertDictEqual
) )
self._assert_dict_equal(
assert_equal_lineages(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['train_dataset'], event_data.EVENT_TRAIN_DICT_0['train_lineage']['train_dataset'],
self.lineage_obj_no_eval.train_dataset
self.lineage_obj_no_eval.train_dataset,
self.assertDictEqual
) )
self._assert_dict_equal(
assert_equal_lineages(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['hyper_parameters'], event_data.EVENT_TRAIN_DICT_0['train_lineage']['hyper_parameters'],
self.lineage_obj_no_eval.hyper_parameters
self.lineage_obj_no_eval.hyper_parameters,
self.assertDictEqual
) )
self._assert_dict_equal({}, self.lineage_obj_no_eval.metric)
self._assert_dict_equal({}, self.lineage_obj_no_eval.valid_dataset)
assert_equal_lineages({}, self.lineage_obj_no_eval.metric, self.assertDictEqual)
assert_equal_lineages({}, self.lineage_obj_no_eval.valid_dataset, self.assertDictEqual)


def test_get_summary_info(self): def test_get_summary_info(self):
"""Test the function of get_summary_info.""" """Test the function of get_summary_info."""
@@ -111,7 +120,7 @@ class TestLineageObj(TestCase):
'model': event_data.EVENT_TRAIN_DICT_0['train_lineage']['model'] 'model': event_data.EVENT_TRAIN_DICT_0['train_lineage']['model']
} }
result = self.lineage_obj.get_summary_info(filter_keys) result = self.lineage_obj.get_summary_info(filter_keys)
self._assert_dict_equal(expected_result, result)
assert_equal_lineages(expected_result, result, self.assertDictEqual)


def test_to_model_lineage_dict(self): def test_to_model_lineage_dict(self):
"""Test the function of to_model_lineage_dict.""" """Test the function of to_model_lineage_dict."""
@@ -125,7 +134,7 @@ class TestLineageObj(TestCase):
expected_result['model_lineage']['dataset_mark'] = None expected_result['model_lineage']['dataset_mark'] = None
expected_result.pop('dataset_graph') expected_result.pop('dataset_graph')
result = self.lineage_obj.to_model_lineage_dict() result = self.lineage_obj.to_model_lineage_dict()
self._assert_dict_equal(expected_result, result)
assert_equal_lineages(expected_result, result, self.assertDictEqual)


def test_to_dataset_lineage_dict(self): def test_to_dataset_lineage_dict(self):
"""Test the function of to_dataset_lineage_dict.""" """Test the function of to_dataset_lineage_dict."""


+ 22
- 4
tests/utils/tools.py View File

@@ -83,9 +83,9 @@ def compare_result_with_file(result, expected_file_path):
assert result == expected_results assert result == expected_results




def deal_float_for_dict(res: dict, expected_res: dict):
def deal_float_for_dict(res: dict, expected_res: dict, decimal_num=2):
""" """
Deal float rounded to five decimals in dict.
Deal float rounded to specified decimals in dict.


For example: For example:
res:{ res:{
@@ -125,10 +125,9 @@ def deal_float_for_dict(res: dict, expected_res: dict):
"metric": {"acc": 0.1234562} "metric": {"acc": 0.1234562}
} }
} }
decimal_num (int): decimal rounded digits.


""" """
decimal_num = 5
for key in res: for key in res:
value = res[key] value = res[key]
expected_value = expected_res[key] expected_value = expected_res[key]
@@ -137,3 +136,22 @@ def deal_float_for_dict(res: dict, expected_res: dict):
elif isinstance(value, float): elif isinstance(value, float):
res[key] = round(value, decimal_num) res[key] = round(value, decimal_num)
expected_res[key] = round(expected_value, decimal_num) expected_res[key] = round(expected_value, decimal_num)


def _deal_float_for_list(list1, list2, decimal_num):
"""Deal float for list1 and list2."""
index = 0
for _ in list1:
deal_float_for_dict(list1[index], list2[index], decimal_num)
index += 1


def assert_equal_lineages(lineages1, lineages2, assert_func, decimal_num=2):
"""Assert float almost equal for lineage data."""
if isinstance(lineages1, list) and isinstance(lineages2, list):
_deal_float_for_list(lineages1, lineages2, decimal_num)
elif lineages1.get('object') is not None and lineages2.get('object') is not None:
_deal_float_for_list(lineages1['object'], lineages2['object'], decimal_num)
else:
deal_float_for_dict(lineages1, lineages2, decimal_num)
assert_func(lineages1, lineages2)

Loading…
Cancel
Save