| @@ -0,0 +1,220 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """ | |||
| Verify attack config | |||
| """ | |||
| from mindarmour.utils.logger import LogUtil | |||
# Logger instance obtained from LogUtil, used by the checks in this module.
LOGGER = LogUtil.get_instance()
# Tag passed as the first argument of every LOGGER.error call in this module.
TAG = "check_params"
| def _is_positive_int(item): | |||
| """ | |||
| Verify that the value is a positive integer. | |||
| """ | |||
| if not isinstance(item, int) or item <= 0: | |||
| return False | |||
| return True | |||
| def _is_non_negative_int(item): | |||
| """ | |||
| Verify that the value is a non-negative integer. | |||
| """ | |||
| if not isinstance(item, int) or item < 0: | |||
| return False | |||
| return True | |||
| def _is_positive_float(item): | |||
| """ | |||
| Verify that value is a positive number. | |||
| """ | |||
| if not isinstance(item, (int, float)) or item <= 0: | |||
| return False | |||
| return True | |||
| def _is_non_negative_float(item): | |||
| """ | |||
| Verify that value is a non-negative number. | |||
| """ | |||
| if not isinstance(item, (int, float)) or item < 0: | |||
| return False | |||
| return True | |||
| def _is_positive_int_tuple(item): | |||
| """ | |||
| Verify that the input parameter is a positive integer tuple. | |||
| """ | |||
| if not isinstance(item, tuple): | |||
| return False | |||
| for i in item: | |||
| if not _is_positive_int(i): | |||
| return False | |||
| return True | |||
| def _is_dict(item): | |||
| """ | |||
| Check whether the type is dict. | |||
| """ | |||
| return isinstance(item, dict) | |||
# Whitelist of the hyper-parameters that may be tuned for each attack method.
# The outer key is the lower-case method name, the inner key the parameter name.
# The inner value is either None (values of that parameter are not checked) or
# a list of rules. A rule is a set of allowed literal values or a predicate
# returning True for an acceptable value; a value is valid if any rule matches.
# Predicates are listed before literal sets so that values which cannot be
# looked up in a set (e.g. an unhashable dict) are matched by the predicate.
VALID_PARAMS_DICT = {
    "knn": {
        "n_neighbors": [_is_positive_int],
        "weights": [{"uniform", "distance"}],
        "algorithm": [{"auto", "ball_tree", "kd_tree", "brute"}],
        "leaf_size": [_is_positive_int],
        "p": [_is_positive_int],
        "metric": None,
        "metric_params": None,
    },
    "lr": {
        "penalty": [{"l1", "l2", "elasticnet", "none"}],
        "dual": [{True, False}],
        "tol": [_is_positive_float],
        "C": [_is_positive_float],
        "fit_intercept": [{True, False}],
        "intercept_scaling": [_is_positive_float],
        "class_weight": [_is_dict, {"balanced", None}],
        "random_state": None,
        "solver": [{"newton-cg", "lbfgs", "liblinear", "sag", "saga"}],
    },
    "mlp": {
        "hidden_layer_sizes": [_is_positive_int_tuple],
        "activation": [{"identity", "logistic", "tanh", "relu"}],
        # Must be a list holding one set, like every other enumerated option.
        "solver": [{"lbfgs", "sgd", "adam"}],
        "alpha": [_is_positive_float],
        "batch_size": [_is_positive_int, {"auto"}],
        "learning_rate": [{"constant", "invscaling", "adaptive"}],
        "learning_rate_init": [_is_positive_float],
        "power_t": [_is_positive_float],
        "max_iter": [_is_positive_int],
        "shuffle": [{True, False}],
        "random_state": None,
        "tol": [_is_positive_float],
        "verbose": [{True, False}],
        "warm_start": [{True, False}],
        "momentum": [_is_positive_float],
        "nesterovs_momentum": [{True, False}],
        "early_stopping": [{True, False}],
        "validation_fraction": [_is_positive_float],
        "beta_1": [_is_positive_float],
        "beta_2": [_is_positive_float],
        "epsilon": [_is_positive_float],
        "n_iter_no_change": [_is_positive_int],
        "max_fun": [_is_positive_int],
    },
    "rf": {
        "n_estimators": [_is_positive_int],
        "criterion": [{"gini", "entropy"}],
        # None is scikit-learn's default and means the depth is not limited.
        "max_depth": [_is_positive_int, {None}],
        "min_samples_split": [_is_positive_float],
        "min_samples_leaf": [_is_positive_float],
        "min_weight_fraction_leaf": [_is_non_negative_float],
        "max_features": [_is_positive_float, {"auto", "sqrt", "log2", None}],
        "max_leaf_nodes": [_is_positive_int, {None}],
        "min_impurity_decrease": [_is_non_negative_float],
        "min_impurity_split": [_is_positive_float, {None}],
        "bootstrap": [{True, False}],
        "oob_score": [{True, False}],
        "n_jobs": [_is_positive_int, {None}],
        "random_state": None,
        "verbose": [_is_non_negative_int],
        "warm_start": [{True, False}],
        "class_weight": None,
        "ccp_alpha": [_is_non_negative_float],
        "max_samples": [_is_positive_float],
    },
}
| def _check_config(config_list, check_params): | |||
| """ | |||
| Verify that config_list is valid. | |||
| Check_params is the valid value range of the parameter. | |||
| """ | |||
| if not isinstance(config_list, (list, tuple)): | |||
| msg = "Type of parameter 'config_list' must be list, but got {}.".format(type(config_list)) | |||
| LOGGER.error(TAG, msg) | |||
| raise TypeError(msg) | |||
| for config in config_list: | |||
| if not isinstance(config, dict): | |||
| msg = "Type of each config in config_list must be dict, but got {}.".format(type(config)) | |||
| LOGGER.error(TAG, msg) | |||
| raise TypeError(msg) | |||
| if set(config.keys()) != {"params", "method"}: | |||
| msg = "Keys of each config in config_list must be {}," \ | |||
| "but got {}.".format({'method', 'params'}, set(config.keys())) | |||
| LOGGER.error(TAG, msg) | |||
| raise KeyError(msg) | |||
| method = str.lower(config["method"]) | |||
| params = config["params"] | |||
| if method not in check_params.keys(): | |||
| msg = "Method {} is not supported.".format(method) | |||
| LOGGER.error(TAG, msg) | |||
| raise ValueError(msg) | |||
| if not params.keys() <= check_params[method].keys(): | |||
| msg = "Params in method {} is not accepted, the parameters " \ | |||
| "that can be set are {}.".format(method, set(check_params[method].keys())) | |||
| LOGGER.error(TAG, msg) | |||
| raise KeyError(msg) | |||
| for param_key in params.keys(): | |||
| param_value = params[param_key] | |||
| candidate_values = check_params[method][param_key] | |||
| if not isinstance(param_value, list): | |||
| msg = "The parameter '{}' in method '{}' setting must within the range of " \ | |||
| "changeable parameters.".format(param_key, method) | |||
| LOGGER.error(TAG, msg) | |||
| raise ValueError(msg) | |||
| if candidate_values is None: | |||
| continue | |||
| for item_value in param_value: | |||
| flag = False | |||
| for candidate_value in candidate_values: | |||
| if isinstance(candidate_value, set) and item_value in candidate_value: | |||
| flag = True | |||
| break | |||
| elif candidate_value(item_value): | |||
| flag = True | |||
| break | |||
| if not flag: | |||
| msg = "Setting of parmeter {} in method {} is invalid".format(param_key, method) | |||
| raise ValueError(msg) | |||
def check_config_params(config_list):
    """
    Public entry point for validating attack configs.

    The configs are checked by _check_config against the value ranges stored
    in VALID_PARAMS_DICT; any exception raised by that check is propagated.

    Args:
        config_list (Union[list, tuple]): Attack configs to validate.
    """
    _check_config(config_list=config_list, check_params=VALID_PARAMS_DICT)
| @@ -27,7 +27,7 @@ LOGGER = LogUtil.get_instance() | |||
| TAG = "Attacker" | |||
| def _attack_knn(features, labels, param_grid): | |||
| def _attack_knn(features, labels, param_grid, n_jobs): | |||
| """ | |||
| Train and return a KNN model. | |||
| @@ -35,20 +35,21 @@ def _attack_knn(features, labels, param_grid): | |||
| features (numpy.ndarray): Loss and logits characteristics of each sample. | |||
| labels (numpy.ndarray): Labels of each sample whether belongs to training set. | |||
| param_grid (dict): Setting of GridSearchCV. | |||
| n_jobs (int): Number of jobs run in parallel. -1 means using all processors, | |||
| otherwise the value of n_jobs must be a positive integer. | |||
| Returns: | |||
| sklearn.model_selection.GridSearchCV, trained model. | |||
| """ | |||
| knn_model = KNeighborsClassifier() | |||
| knn_model = GridSearchCV( | |||
| knn_model, param_grid=param_grid, cv=3, n_jobs=1, iid=False, | |||
| verbose=0, | |||
| knn_model, param_grid=param_grid, cv=3, n_jobs=n_jobs, verbose=0, | |||
| ) | |||
| knn_model.fit(X=features, y=labels) | |||
| return knn_model | |||
| def _attack_lr(features, labels, param_grid): | |||
| def _attack_lr(features, labels, param_grid, n_jobs): | |||
| """ | |||
| Train and return a LR model. | |||
| @@ -56,20 +57,21 @@ def _attack_lr(features, labels, param_grid): | |||
| features (numpy.ndarray): Loss and logits characteristics of each sample. | |||
| labels (numpy.ndarray): Labels of each sample whether belongs to training set. | |||
| param_grid (dict): Setting of GridSearchCV. | |||
| n_jobs (int): Number of jobs run in parallel. -1 means using all processors, | |||
| otherwise the value of n_jobs must be a positive integer. | |||
| Returns: | |||
| sklearn.model_selection.GridSearchCV, trained model. | |||
| """ | |||
| lr_model = LogisticRegression(C=1.0, penalty="l2", max_iter=1000) | |||
| lr_model = LogisticRegression(C=1.0, penalty="l2", max_iter=300) | |||
| lr_model = GridSearchCV( | |||
| lr_model, param_grid=param_grid, cv=3, n_jobs=1, iid=False, | |||
| verbose=0, | |||
| lr_model, param_grid=param_grid, cv=3, n_jobs=n_jobs, verbose=0, | |||
| ) | |||
| lr_model.fit(X=features, y=labels) | |||
| return lr_model | |||
| def _attack_mlpc(features, labels, param_grid): | |||
| def _attack_mlpc(features, labels, param_grid, n_jobs): | |||
| """ | |||
| Train and return a MLPC model. | |||
| @@ -77,20 +79,21 @@ def _attack_mlpc(features, labels, param_grid): | |||
| features (numpy.ndarray): Loss and logits characteristics of each sample. | |||
| labels (numpy.ndarray): Labels of each sample whether belongs to training set. | |||
| param_grid (dict): Setting of GridSearchCV. | |||
| n_jobs (int): Number of jobs run in parallel. -1 means using all processors, | |||
| otherwise the value of n_jobs must be a positive integer. | |||
| Returns: | |||
| sklearn.model_selection.GridSearchCV, trained model. | |||
| """ | |||
| mlpc_model = MLPClassifier(random_state=1, max_iter=300) | |||
| mlpc_model = GridSearchCV( | |||
| mlpc_model, param_grid=param_grid, cv=3, n_jobs=1, iid=False, | |||
| verbose=0, | |||
| mlpc_model, param_grid=param_grid, cv=3, n_jobs=n_jobs, verbose=0, | |||
| ) | |||
| mlpc_model.fit(features, labels) | |||
| return mlpc_model | |||
| def _attack_rf(features, labels, random_grid): | |||
| def _attack_rf(features, labels, random_grid, n_jobs): | |||
| """ | |||
| Train and return a RF model. | |||
| @@ -98,20 +101,22 @@ def _attack_rf(features, labels, random_grid): | |||
| features (numpy.ndarray): Loss and logits characteristics of each sample. | |||
| labels (numpy.ndarray): Labels of each sample whether belongs to training set. | |||
| random_grid (dict): Setting of RandomizedSearchCV. | |||
| n_jobs (int): Number of jobs run in parallel. -1 means using all processors, | |||
| otherwise the value of n_jobs must be a positive integer. | |||
| Returns: | |||
| sklearn.model_selection.RandomizedSearchCV, trained model. | |||
| """ | |||
| rf_model = RandomForestClassifier(max_depth=2, random_state=0) | |||
| rf_model = RandomizedSearchCV( | |||
| rf_model, param_distributions=random_grid, n_iter=7, cv=3, n_jobs=1, | |||
| iid=False, verbose=0, | |||
| rf_model, param_distributions=random_grid, n_iter=7, cv=3, n_jobs=n_jobs, | |||
| verbose=0, | |||
| ) | |||
| rf_model.fit(features, labels) | |||
| return rf_model | |||
| def get_attack_model(features, labels, config): | |||
| def get_attack_model(features, labels, config, n_jobs=-1): | |||
| """ | |||
| Get trained attack model specify by config. | |||
| @@ -123,6 +128,8 @@ def get_attack_model(features, labels, config): | |||
| params of each method must within the range of changeable parameters. | |||
| Tips of params implement can be found in | |||
| "https://scikit-learn.org/0.16/modules/generated/sklearn.grid_search.GridSearchCV.html". | |||
| n_jobs (int): Number of jobs run in parallel. -1 means using all processors, | |||
| otherwise the value of n_jobs must be a positive integer. | |||
| Returns: | |||
| sklearn.BaseEstimator, trained model specify by config["method"]. | |||
| @@ -136,13 +143,13 @@ def get_attack_model(features, labels, config): | |||
| method = str.lower(config["method"]) | |||
| if method == "knn": | |||
| return _attack_knn(features, labels, config["params"]) | |||
| return _attack_knn(features, labels, config["params"], n_jobs) | |||
| if method == "lr": | |||
| return _attack_lr(features, labels, config["params"]) | |||
| return _attack_lr(features, labels, config["params"], n_jobs) | |||
| if method == "mlp": | |||
| return _attack_mlpc(features, labels, config["params"]) | |||
| return _attack_mlpc(features, labels, config["params"], n_jobs) | |||
| if method == "rf": | |||
| return _attack_rf(features, labels, config["params"]) | |||
| return _attack_rf(features, labels, config["params"], n_jobs) | |||
| msg = "Method {} is not supported.".format(config["method"]) | |||
| LOGGER.error(TAG, msg) | |||
| @@ -15,14 +15,16 @@ | |||
| Membership Inference | |||
| """ | |||
| from multiprocessing import cpu_count | |||
| import numpy as np | |||
| import mindspore as ms | |||
| from mindspore.train import Model | |||
| from mindspore.dataset.engine import Dataset | |||
| from mindspore import Tensor | |||
| from mindarmour.diff_privacy.evaluation.attacker import get_attack_model | |||
| from mindarmour.utils.logger import LogUtil | |||
| from .attacker import get_attack_model | |||
| from ._check_config import check_config_params | |||
| LOGGER = LogUtil.get_instance() | |||
| TAG = "MembershipInference" | |||
| @@ -101,13 +103,15 @@ class MembershipInference: | |||
| Args: | |||
| model (Model): Target model. | |||
| n_jobs (int): Number of jobs run in parallel. -1 means using all processors, | |||
| otherwise the value of n_jobs must be a positive integer. | |||
| Examples: | |||
| >>> train_1, train_2 are non-overlapping datasets from training dataset of target model. | |||
| >>> test_1, test_2 are non-overlapping datasets from test dataset of target model. | |||
| >>> We use train_1, test_1 to train attack model, and use train_2, test_2 to evaluate attack model. | |||
| >>> model = Model(network=net, loss_fn=loss, optimizer=opt, metrics={'acc', 'loss'}) | |||
| >>> inference_model = MembershipInference(model) | |||
| >>> inference_model = MembershipInference(model, n_jobs=-1) | |||
| >>> config = [{"method": "KNN", "params": {"n_neighbors": [3, 5, 7]}}] | |||
| >>> inference_model.train(train_1, test_1, config) | |||
| >>> metrics = ["precision", "recall", "accuracy"] | |||
| @@ -115,15 +119,26 @@ class MembershipInference: | |||
| Raises: | |||
| TypeError: If type of model is not mindspore.train.Model. | |||
| TypeError: If type of n_jobs is not int. | |||
| ValueError: The value of n_jobs is neither -1 nor a positive integer. | |||
| """ | |||
| def __init__(self, model): | |||
| def __init__(self, model, n_jobs=-1): | |||
| if not isinstance(model, Model): | |||
| msg = "Type of parameter 'model' must be Model, but got {}.".format(type(model)) | |||
| LOGGER.error(TAG, msg) | |||
| raise TypeError(msg) | |||
| if not isinstance(n_jobs, int): | |||
| msg = "Type of parameter 'n_jobs' must be int, but got {}".format(type(n_jobs)) | |||
| LOGGER.error(TAG, msg) | |||
| raise TypeError(msg) | |||
| if not (n_jobs == -1 or n_jobs > 0): | |||
| msg = "Value of n_jobs must be either -1 or positive integer, but got {}.".format(n_jobs) | |||
| LOGGER.error(TAG, msg) | |||
| raise ValueError(msg) | |||
| self.model = model | |||
| self.n_jobs = min(n_jobs, cpu_count()) | |||
| self.method_list = ["knn", "lr", "mlp", "rf"] | |||
| self.attack_list = [] | |||
| @@ -162,24 +177,13 @@ class MembershipInference: | |||
| LOGGER.error(TAG, msg) | |||
| raise TypeError(msg) | |||
| for config in attack_config: | |||
| if not isinstance(config, dict): | |||
| msg = "Type of each config in 'attack_config' must be dict, but got {}.".format(type(config)) | |||
| LOGGER.error(TAG, msg) | |||
| raise TypeError(msg) | |||
| if {"params", "method"} != set(config.keys()): | |||
| msg = "Each config in attack_config must have keys 'method' and 'params'," \ | |||
| "but your key value is {}.".format(set(config.keys())) | |||
| LOGGER.error(TAG, msg) | |||
| raise KeyError(msg) | |||
| if str.lower(config["method"]) not in self.method_list: | |||
| msg = "Method {} is not support.".format(config["method"]) | |||
| LOGGER.error(TAG, msg) | |||
| raise ValueError(msg) | |||
| check_config_params(attack_config) # Verify attack config. | |||
| features, labels = self._transform(dataset_train, dataset_test) | |||
| for config in attack_config: | |||
| self.attack_list.append(get_attack_model(features, labels, config)) | |||
| self.attack_list.append(get_attack_model(features, labels, config, n_jobs=self.n_jobs)) | |||
| def eval(self, dataset_train, dataset_test, metrics): | |||
| """ | |||
| @@ -35,7 +35,7 @@ def test_get_knn_model(): | |||
| "n_neighbors": [3], | |||
| } | |||
| } | |||
| knn_attacker = get_attack_model(features, labels, config_knn) | |||
| knn_attacker = get_attack_model(features, labels, config_knn, -1) | |||
| pred = knn_attacker.predict(features) | |||
| assert pred is not None | |||
| @@ -54,7 +54,7 @@ def test_get_lr_model(): | |||
| "C": np.logspace(-4, 2, 10), | |||
| } | |||
| } | |||
| lr_attacker = get_attack_model(features, labels, config_lr) | |||
| lr_attacker = get_attack_model(features, labels, config_lr, -1) | |||
| pred = lr_attacker.predict(features) | |||
| assert pred is not None | |||
| @@ -75,7 +75,7 @@ def test_get_mlp_model(): | |||
| "alpha": [0.0001, 0.001, 0.01], | |||
| } | |||
| } | |||
| mlpc_attacker = get_attack_model(features, labels, config_mlpc) | |||
| mlpc_attacker = get_attack_model(features, labels, config_mlpc, -1) | |||
| pred = mlpc_attacker.predict(features) | |||
| assert pred is not None | |||
| @@ -98,6 +98,6 @@ def test_get_rf_model(): | |||
| "min_samples_leaf": [1, 2, 4], | |||
| } | |||
| } | |||
| rf_attacker = get_attack_model(features, labels, config_rf) | |||
| rf_attacker = get_attack_model(features, labels, config_rf, -1) | |||
| pred = rf_attacker.predict(features) | |||
| assert pred is not None | |||
| @@ -24,6 +24,7 @@ import numpy as np | |||
| import mindspore.dataset as ds | |||
| from mindspore import nn | |||
| from mindspore.train import Model | |||
| import mindspore.context as context | |||
| from mindarmour.diff_privacy.evaluation.membership_inference import MembershipInference | |||
| @@ -31,6 +32,8 @@ sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../")) | |||
| from defenses.mock_net import Net | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| def dataset_generator(batch_size, batches): | |||
| """mock training data.""" | |||
| data = np.random.randn(batches*batch_size, 1, 32, 32).astype( | |||
| @@ -51,7 +54,7 @@ def test_get_membership_inference_object(): | |||
| loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) | |||
| opt = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) | |||
| model = Model(network=net, loss_fn=loss, optimizer=opt) | |||
| inference_model = MembershipInference(model) | |||
| inference_model = MembershipInference(model, -1) | |||
| assert isinstance(inference_model, MembershipInference) | |||
| @@ -65,7 +68,7 @@ def test_membership_inference_object_train(): | |||
| loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) | |||
| opt = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) | |||
| model = Model(network=net, loss_fn=loss, optimizer=opt) | |||
| inference_model = MembershipInference(model) | |||
| inference_model = MembershipInference(model, -1) | |||
| assert isinstance(inference_model, MembershipInference) | |||
| config = [{ | |||
| @@ -95,7 +98,7 @@ def test_membership_inference_eval(): | |||
| loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) | |||
| opt = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) | |||
| model = Model(network=net, loss_fn=loss, optimizer=opt) | |||
| inference_model = MembershipInference(model) | |||
| inference_model = MembershipInference(model, -1) | |||
| assert isinstance(inference_model, MembershipInference) | |||
| batch_size = 16 | |||