Browse Source

!60 lineagemgr: user defined error not affect other lineage information record when not raise exception in lineage callback

Merge pull request !60 from kouzhenzhong/user_defined_bug_fix
tags/v0.2.0-alpha
mindspore-ci-bot Gitee 5 years ago
parent
commit
e63cc72135
2 changed files with 27 additions and 7 deletions
  1. +8
    -7
      mindinsight/lineagemgr/summary/_summary_adapter.py
  2. +19
    -0
      tests/st/func/lineagemgr/collection/model/test_model_lineage.py

+ 8
- 7
mindinsight/lineagemgr/summary/_summary_adapter.py View File

@@ -327,7 +327,9 @@ def _package_user_defined_info(user_defined_dict, user_defined_message):
"""
for key, value in user_defined_dict.items():
if not isinstance(key, str):
raise LineageParamTypeError("The key must be str.")
error_msg = f"Invalid key type in user defined info. The {key}'s type" \
f"'{type(key).__name__}' is not supported. It should be str."
log.error(error_msg)

if isinstance(value, int):
attr_name = "map_int32"
@@ -336,13 +338,12 @@ def _package_user_defined_info(user_defined_dict, user_defined_message):
elif isinstance(value, str):
attr_name = "map_str"
else:
error_msg = "Value type {} is not supported in user defined event package." \
"Only str, int and float are permitted now.".format(type(value))
log.error(error_msg)
raise LineageParamTypeError(error_msg)
attr_name = "attr_name"

add_user_defined_info = user_defined_message.user_info.add()
try:
getattr(add_user_defined_info, attr_name)[key] = value
except ValueError:
raise LineageParamValueError("Value is out of range or not be supported yet.")
except AttributeError:
error_msg = f"Invalid value type in user defined info. The {value}'s type" \
f"'{type(value).__name__}' is not supported. It should be float, int or str."
log.error(error_msg)

+ 19
- 0
tests/st/func/lineagemgr/collection/model/test_model_lineage.py View File

@@ -88,6 +88,25 @@ class TestModelLineage(TestCase):
lineage_log_path = self.summary_record.full_file_name + '_lineage'
assert os.path.isfile(lineage_log_path) is True

@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
def test_train_begin_with_user_defined_info(self):
"""Test TrainLineage with nested user defined info."""
user_defined_info = {"info": {"version": "v1"}}
train_callback = TrainLineage(
self.summary_record,
False,
user_defined_info
)
train_callback.begin(RunContext(self.run_context))
assert train_callback.initial_learning_rate == 0.12
lineage_log_path = self.summary_record.full_file_name + '_lineage'
assert os.path.isfile(lineage_log_path) is True

@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training


Loading…
Cancel
Save