Browse Source

Adjust config check. Modify the exception type to TypeError in function check_model.

tags/v1.0.0
liuluobin 4 years ago
parent
commit
2ffdf100ce
3 changed files with 41 additions and 39 deletions
  1. +39
    -37
      mindarmour/privacy/evaluation/_check_config.py
  2. +1
    -1
      mindarmour/utils/_check_param.py
  3. +1
    -1
      tests/ut/python/detectors/test_region_based_detector.py

+ 39
- 37
mindarmour/privacy/evaluation/_check_config.py View File

@@ -24,45 +24,39 @@ TAG = "check_config"


def _is_positive_int(item):
"""
Verify that the value is a positive integer.
"""
if not isinstance(item, int) or item <= 0:
"""Verify that the value is a positive integer."""
if not isinstance(item, int):
return False
return True

return item > 0

def _is_non_negative_int(item):
"""
Verify that the value is a non-negative integer.
"""
if not isinstance(item, int) or item < 0:
"""Verify that the value is a non-negative integer."""
if not isinstance(item, int):
return False
return True
return item >= 0


def _is_positive_float(item):
"""
Verify that value is a positive number.
"""
if not isinstance(item, (int, float)) or item <= 0:
"""Verify that value is a positive number."""
if not isinstance(item, (int, float)):
return False
return True
return item > 0


def _is_non_negative_float(item):
"""
Verify that value is a non-negative number.
"""
if not isinstance(item, (int, float)) or item < 0:
"""Verify that value is a non-negative number."""
if not isinstance(item, (int, float)):
return False
return True
return item >= 0

def _is_range_0_1_float(item):
if not isinstance(item, (int, float)):
return False
return 0 <= item < 1


def _is_positive_int_tuple(item):
"""
Verify that the input parameter is a positive integer tuple.
"""
"""Verify that the input parameter is a positive integer tuple."""
if not isinstance(item, tuple):
return False
for i in item:
@@ -72,21 +66,29 @@ def _is_positive_int_tuple(item):


def _is_dict(item):
"""
Check whether the type is dict.
"""
"""Check whether the type is dict."""
return isinstance(item, dict)


def _is_list(item):
"""Check whether the type is list"""
return isinstance(item, list)


def _is_str(item):
"""Check whether the type is str."""
return isinstance(item, str)


_VALID_CONFIG_CHECKLIST = {
"knn": {
"n_neighbors": [_is_positive_int],
"weights": [{"uniform", "distance"}],
"weights": [{"uniform", "distance"}, callable],
"algorithm": [{"auto", "ball_tree", "kd_tree", "brute"}],
"leaf_size": [_is_positive_int],
"p": [_is_positive_int],
"metric": None,
"metric_params": None,
"metric": [_is_str, callable],
"metric_params": [_is_dict, {None}]
},
"lr": {
"penalty": [{"l1", "l2", "elasticnet", "none"}],
@@ -102,7 +104,7 @@ _VALID_CONFIG_CHECKLIST = {
"mlp": {
"hidden_layer_sizes": [_is_positive_int_tuple],
"activation": [{"identity", "logistic", "tanh", "relu"}],
"solver": {"lbfgs", "sgd", "adam"},
"solver": [{"lbfgs", "sgd", "adam"}],
"alpha": [_is_positive_float],
"batch_size": [{"auto"}, _is_positive_int],
"learning_rate": [{"constant", "invscaling", "adaptive"}],
@@ -117,9 +119,9 @@ _VALID_CONFIG_CHECKLIST = {
"momentum": [_is_positive_float],
"nesterovs_momentum": [{True, False}],
"early_stopping": [{True, False}],
"validation_fraction": [_is_positive_float],
"beta_1": [_is_positive_float],
"beta_2": [_is_positive_float],
"validation_fraction": [_is_range_0_1_float],
"beta_1": [_is_range_0_1_float],
"beta_2": [_is_range_0_1_float],
"epsilon": [_is_positive_float],
"n_iter_no_change": [_is_positive_int],
"max_fun": [_is_positive_int]
@@ -133,7 +135,7 @@ _VALID_CONFIG_CHECKLIST = {
"min_weight_fraction_leaf": [_is_non_negative_float],
"max_features": [{"auto", "sqrt", "log2", None}, _is_positive_float],
"max_leaf_nodes": [_is_positive_int, {None}],
"min_impurity_decrease": {_is_non_negative_float},
"min_impurity_decrease": [_is_non_negative_float],
"min_impurity_split": [{None}, _is_positive_float],
"bootstrap": [{True, False}],
"oob_scroe": [{True, False}],
@@ -141,9 +143,9 @@ _VALID_CONFIG_CHECKLIST = {
"random_state": None,
"verbose": [_is_non_negative_int],
"warm_start": [{True, False}],
"class_weight": None,
"class_weight": [{"balanced", "balanced_subsample"}, _is_dict, _is_list],
"ccp_alpha": [_is_non_negative_float],
"max_samples": [_is_positive_float]
"max_samples": [{None}, _is_positive_int, _is_range_0_1_float]
}
}



+ 1
- 1
mindarmour/utils/_check_param.py View File

@@ -129,7 +129,7 @@ def check_model(model_name, model, model_type):
model_type,
type(model).__name__)
LOGGER.error(TAG, msg)
raise ValueError(msg)
raise TypeError(msg)


def check_numpy_param(arg_name, arg_value):


+ 1
- 1
tests/ut/python/detectors/test_region_based_detector.py View File

@@ -84,7 +84,7 @@ def test_value_error():
adv = np.random.rand(4, 4).astype(np.float32)
model = Model(Net())
# model should be mindspore model
with pytest.raises(ValueError):
with pytest.raises(TypeError):
assert RegionBasedDetector(Net())

with pytest.raises(ValueError):


Loading…
Cancel
Save