From bfd3d2588f466c43528a5d34274b3f8dae11cf29 Mon Sep 17 00:00:00 2001 From: liuluobin Date: Fri, 16 Apr 2021 09:52:00 +0800 Subject: [PATCH] Add parameter validation for the get_attacker function --- mindarmour/privacy/evaluation/attacker.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/mindarmour/privacy/evaluation/attacker.py b/mindarmour/privacy/evaluation/attacker.py index 2d798db..cc71c84 100644 --- a/mindarmour/privacy/evaluation/attacker.py +++ b/mindarmour/privacy/evaluation/attacker.py @@ -25,6 +25,7 @@ from sklearn.model_selection import RandomizedSearchCV from sklearn.exceptions import ConvergenceWarning from mindarmour.utils.logger import LogUtil +from mindarmour.utils._check_param import check_pair_numpy_param, check_param_type LOGGER = LogUtil.get_instance() TAG = "Attacker" @@ -143,6 +144,13 @@ def get_attack_model(features, labels, config, n_jobs=-1): >>> config = {"method": "knn", "params": {"n_neighbors": [3, 5, 7]}} >>> attack_model = get_attack_model(features, labels, config) """ + features, labels = check_pair_numpy_param("features", features, "labels", labels) + config = check_param_type("config", config, dict) + n_jobs = check_param_type("n_jobs", n_jobs, int) + if not (n_jobs == -1 or n_jobs > 0): + msg = "Value of n_jobs must be -1 or positive integer." + raise ValueError(msg) + method = str.lower(config["method"]) with warnings.catch_warnings(): warnings.filterwarnings('ignore', category=ConvergenceWarning)