From 16aa19579c7b98abfaac2c54ab6b925f2353f4fb Mon Sep 17 00:00:00 2001 From: luopengting Date: Mon, 21 Dec 2020 20:20:34 +0800 Subject: [PATCH] check for param in choice if it is bool If the values for batch_size is 'choice: [1, 2, True]', it will raise validation exception. Additionally, use 'type(value) is int' will not pass pylint. so use 'isinstance(x, bool) or not isinstance(x, int)' to filter invalid value(s). --- mindinsight/optimizer/common/validator/optimizer_config.py | 3 ++- .../ut/optimizer/common/validator/test_optimizer_config.py | 6 ++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/mindinsight/optimizer/common/validator/optimizer_config.py b/mindinsight/optimizer/common/validator/optimizer_config.py index 9ecf1a2a..ca7341eb 100644 --- a/mindinsight/optimizer/common/validator/optimizer_config.py +++ b/mindinsight/optimizer/common/validator/optimizer_config.py @@ -219,7 +219,8 @@ class OptimizerConfig(Schema): if list(filter(lambda x: not isinstance(x, float), choice)): err_msg = "The value(s) should be float number." raise ValidationError(_generate_err_msg_for_nested_keys(err_msg, name, HyperParamKey.CHOICE.value)) - elif list(filter(lambda x: not isinstance(x, int), choice)): + elif list(filter(lambda x: isinstance(x, bool) or not isinstance(x, int), choice)): + # isinstance(x, int) will return True if x is bool. use 'type(x)' will not pass lint. err_msg = "The value(s) should be integer." raise ValidationError(_generate_err_msg_for_nested_keys(err_msg, name, HyperParamKey.CHOICE.value)) diff --git a/tests/ut/optimizer/common/validator/test_optimizer_config.py b/tests/ut/optimizer/common/validator/test_optimizer_config.py index 11f08fae..240e5c25 100644 --- a/tests/ut/optimizer/common/validator/test_optimizer_config.py +++ b/tests/ut/optimizer/common/validator/test_optimizer_config.py @@ -231,6 +231,12 @@ class TestOptimizerConfig: err = OptimizerConfig().validate(config_dict) assert expected_err == err + # test bool + expected_err['parameters'][param_name]['choice'] = 'The value(s) should be integer.' + config_dict['parameters'][param_name]['choice'] = [1, True] + err = OptimizerConfig().validate(config_dict) + assert expected_err == err + config_dict['parameters'][param_name] = {'choice': [0.1, 0.2]} err = OptimizerConfig().validate(config_dict) assert expected_err == err