Browse Source

!206 Add parameter validation for the get_attacker function

From: @liu_luobin
Reviewed-by: @pkuliuliu,@zhidanliu
Signed-off-by: @pkuliuliu,@zhidanliu
tags/v1.3.0
mindspore-ci-bot Gitee 4 years ago
parent
commit
5192ae469b
1 changed files with 8 additions and 0 deletions
  1. +8
    -0
      mindarmour/privacy/evaluation/attacker.py

+ 8
- 0
mindarmour/privacy/evaluation/attacker.py View File

@@ -25,6 +25,7 @@ from sklearn.model_selection import RandomizedSearchCV
from sklearn.exceptions import ConvergenceWarning from sklearn.exceptions import ConvergenceWarning


from mindarmour.utils.logger import LogUtil from mindarmour.utils.logger import LogUtil
from mindarmour.utils._check_param import check_pair_numpy_param, check_param_type


LOGGER = LogUtil.get_instance() LOGGER = LogUtil.get_instance()
TAG = "Attacker" TAG = "Attacker"
@@ -143,6 +144,13 @@ def get_attack_model(features, labels, config, n_jobs=-1):
>>> config = {"method": "knn", "params": {"n_neighbors": [3, 5, 7]}} >>> config = {"method": "knn", "params": {"n_neighbors": [3, 5, 7]}}
>>> attack_model = get_attack_model(features, labels, config) >>> attack_model = get_attack_model(features, labels, config)
""" """
features, labels = check_pair_numpy_param("features", features, "labels", labels)
config = check_param_type("config", config, dict)
n_jobs = check_param_type("n_jobs", n_jobs, int)
if not (n_jobs == -1 or n_jobs > 0):
msg = "Value of n_jobs must be -1 or positive integer."
raise ValueError(msg)

method = str.lower(config["method"]) method = str.lower(config["method"])
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=ConvergenceWarning) warnings.filterwarnings('ignore', category=ConvergenceWarning)


Loading…
Cancel
Save