Browse Source

remove invalid item in user_defined info, record valid items

pull/196/head
luopengting 5 years ago
parent
commit
d15c81c699
6 changed files with 52 additions and 13 deletions
  1. +2
    -2
      mindinsight/lineagemgr/collection/model/model_lineage.py
  2. +7
    -4
      mindinsight/lineagemgr/common/validator/validate.py
  3. +5
    -5
      tests/st/func/lineagemgr/api/test_model_api.py
  4. +1
    -1
      tests/st/func/lineagemgr/cache/test_lineage_cache.py
  5. +31
    -1
      tests/st/func/lineagemgr/collection/model/test_model_lineage.py
  6. +6
    -0
      tests/ut/lineagemgr/querier/event_data.py

+ 2
- 2
mindinsight/lineagemgr/collection/model/model_lineage.py View File

@@ -284,8 +284,8 @@ class EvalLineage(Callback):
self.lineage_summary = LineageSummary(self.lineage_log_dir)

self.user_defined_info = user_defined_info
if user_defined_info:
validate_user_defined_info(user_defined_info)
if self.user_defined_info:
validate_user_defined_info(self.user_defined_info)

except MindInsightException as err:
log.error(err)


+ 7
- 4
mindinsight/lineagemgr/common/validator/validate.py View File

@@ -410,7 +410,7 @@ def validate_path(summary_path):

def validate_user_defined_info(user_defined_info):
"""
Validate user defined info.
Validate user defined info, delete the item if its key is in lineage.

Args:
user_defined_info (dict): The user defined info.
@@ -437,10 +437,13 @@ def validate_user_defined_info(user_defined_info):

field_map = set(FIELD_MAPPING.keys())
user_defined_keys = set(user_defined_info.keys())
all_keys = field_map | user_defined_keys
insertion = list(field_map & user_defined_keys)

if len(field_map) + len(user_defined_keys) != len(all_keys):
raise LineageParamValueError("There are some keys have defined in lineage.")
if insertion:
for key in insertion:
user_defined_info.pop(key)
raise LineageParamValueError("There are some keys have defined in lineage. "
"Duplicated key(s): %s. " % insertion)


def validate_train_id(relative_path):


+ 5
- 5
tests/st/func/lineagemgr/api/test_model_api.py View File

@@ -92,7 +92,7 @@ LINEAGE_FILTRATION_RUN1 = {
'train_dataset_count': 1024,
'test_dataset_path': None,
'test_dataset_count': 1024,
'user_defined': {},
'user_defined': {'info': 'info1', 'version': 'v1'},
'network': 'ResNet',
'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
@@ -329,7 +329,7 @@ class TestModelApi(TestCase):
def test_filter_summary_lineage(self):
"""Test the interface of filter_summary_lineage."""
expect_result = {
'customized': event_data.CUSTOMIZED__0,
'customized': event_data.CUSTOMIZED__1,
'object': [
LINEAGE_FILTRATION_EXCEPT_RUN,
LINEAGE_FILTRATION_RUN1,
@@ -383,7 +383,7 @@ class TestModelApi(TestCase):
'offset': 0
}
expect_result = {
'customized': event_data.CUSTOMIZED__0,
'customized': event_data.CUSTOMIZED__1,
'object': [
LINEAGE_FILTRATION_RUN2,
LINEAGE_FILTRATION_RUN1
@@ -421,7 +421,7 @@ class TestModelApi(TestCase):
'offset': 0
}
expect_result = {
'customized': event_data.CUSTOMIZED__0,
'customized': event_data.CUSTOMIZED__1,
'object': [
LINEAGE_FILTRATION_RUN2,
LINEAGE_FILTRATION_RUN1
@@ -449,7 +449,7 @@ class TestModelApi(TestCase):
'sorted_name': 'metric/accuracy',
}
expect_result = {
'customized': event_data.CUSTOMIZED__0,
'customized': event_data.CUSTOMIZED__1,
'object': [
LINEAGE_FILTRATION_EXCEPT_RUN,
LINEAGE_FILTRATION_RUN1,


+ 1
- 1
tests/st/func/lineagemgr/cache/test_lineage_cache.py View File

@@ -70,7 +70,7 @@ class TestModelApi(TestCase):
def test_filter_summary_lineage(self):
"""Test the interface of filter_summary_lineage."""
expect_result = {
'customized': event_data.CUSTOMIZED__0,
'customized': event_data.CUSTOMIZED__1,
'object': [
LINEAGE_FILTRATION_EXCEPT_RUN,
LINEAGE_FILTRATION_RUN1,


+ 31
- 1
tests/st/func/lineagemgr/collection/model/test_model_lineage.py View File

@@ -28,7 +28,7 @@ from unittest import mock, TestCase
import numpy as np
import pytest

from mindinsight.lineagemgr import get_summary_lineage
from mindinsight.lineagemgr import get_summary_lineage, filter_summary_lineage
from mindinsight.lineagemgr.collection.model.model_lineage import TrainLineage, EvalLineage, \
AnalyzeObject
from mindinsight.lineagemgr.common.utils import make_directory
@@ -109,6 +109,36 @@ class TestModelLineage(TestCase):
lineage_log_path = train_callback.lineage_summary.lineage_log_path
assert os.path.isfile(lineage_log_path) is True

@pytest.mark.scene_train(2)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
def test_train_begin_with_user_defined_key_in_lineage(self):
"""Test TrainLineage with nested user defined info."""
expected_res = {
"info": "info1",
"version": "v1"
}
user_defined_info = {
"info": "info1",
"version": "v1",
"network": "LeNet"
}
train_callback = TrainLineage(
self.summary_record,
False,
user_defined_info
)
train_callback.begin(RunContext(self.run_context))
assert train_callback.initial_learning_rate == 0.12
lineage_log_path = train_callback.lineage_summary.lineage_log_path
assert os.path.isfile(lineage_log_path) is True
res = filter_summary_lineage(os.path.dirname(lineage_log_path))
assert expected_res == res['object'][0]['model_lineage']['user_defined']

@pytest.mark.scene_train(2)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training


+ 6
- 0
tests/ut/lineagemgr/querier/event_data.py View File

@@ -192,6 +192,12 @@ CUSTOMIZED__0 = {
'metric/accuracy': {'label': 'metric/accuracy', 'required': True, 'type': 'float'},
}

CUSTOMIZED__1 = {
**CUSTOMIZED__0,
'user_defined/info': {'label': 'user_defined/info', 'required': False, 'type': 'str'},
'user_defined/version': {'label': 'user_defined/version', 'required': False, 'type': 'str'}
}

CUSTOMIZED_0 = {
**CUSTOMIZED__0,
'metric/mae': {'label': 'metric/mae', 'required': True, 'type': 'float'},


Loading…
Cancel
Save