|
@@ -25,6 +25,7 @@ from sklearn.model_selection import RandomizedSearchCV |
|
|
from sklearn.exceptions import ConvergenceWarning |
|
|
from sklearn.exceptions import ConvergenceWarning |
|
|
|
|
|
|
|
|
from mindarmour.utils.logger import LogUtil |
|
|
from mindarmour.utils.logger import LogUtil |
|
|
|
|
|
from mindarmour.utils._check_param import check_pair_numpy_param, check_param_type |
|
|
|
|
|
|
|
|
LOGGER = LogUtil.get_instance() |
|
|
LOGGER = LogUtil.get_instance() |
|
|
TAG = "Attacker" |
|
|
TAG = "Attacker" |
|
@@ -143,6 +144,13 @@ def get_attack_model(features, labels, config, n_jobs=-1): |
|
|
>>> config = {"method": "knn", "params": {"n_neighbors": [3, 5, 7]}} |
|
|
>>> config = {"method": "knn", "params": {"n_neighbors": [3, 5, 7]}} |
|
|
>>> attack_model = get_attack_model(features, labels, config) |
|
|
>>> attack_model = get_attack_model(features, labels, config) |
|
|
""" |
|
|
""" |
|
|
|
|
|
features, labels = check_pair_numpy_param("features", features, "labels", labels) |
|
|
|
|
|
config = check_param_type("config", config, dict) |
|
|
|
|
|
n_jobs = check_param_type("n_jobs", n_jobs, int) |
|
|
|
|
|
if not (n_jobs == -1 or n_jobs > 0): |
|
|
|
|
|
msg = "Value of n_jobs must be -1 or positive integer." |
|
|
|
|
|
raise ValueError(msg) |
|
|
|
|
|
|
|
|
method = str.lower(config["method"]) |
|
|
method = str.lower(config["method"]) |
|
|
with warnings.catch_warnings(): |
|
|
with warnings.catch_warnings(): |
|
|
warnings.filterwarnings('ignore', category=ConvergenceWarning) |
|
|
warnings.filterwarnings('ignore', category=ConvergenceWarning) |
|
|