Merge pull request !108 from liuluobin/mastertags/v1.0.0
| @@ -0,0 +1,220 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """ | |||||
| Verify attack config | |||||
| """ | |||||
| from mindarmour.utils.logger import LogUtil | |||||
| LOGGER = LogUtil.get_instance() | |||||
| TAG = "check_params" | |||||
| def _is_positive_int(item): | |||||
| """ | |||||
| Verify that the value is a positive integer. | |||||
| """ | |||||
| if not isinstance(item, int) or item <= 0: | |||||
| return False | |||||
| return True | |||||
| def _is_non_negative_int(item): | |||||
| """ | |||||
| Verify that the value is a non-negative integer. | |||||
| """ | |||||
| if not isinstance(item, int) or item < 0: | |||||
| return False | |||||
| return True | |||||
| def _is_positive_float(item): | |||||
| """ | |||||
| Verify that value is a positive number. | |||||
| """ | |||||
| if not isinstance(item, (int, float)) or item <= 0: | |||||
| return False | |||||
| return True | |||||
| def _is_non_negative_float(item): | |||||
| """ | |||||
| Verify that value is a non-negative number. | |||||
| """ | |||||
| if not isinstance(item, (int, float)) or item < 0: | |||||
| return False | |||||
| return True | |||||
| def _is_positive_int_tuple(item): | |||||
| """ | |||||
| Verify that the input parameter is a positive integer tuple. | |||||
| """ | |||||
| if not isinstance(item, tuple): | |||||
| return False | |||||
| for i in item: | |||||
| if not _is_positive_int(i): | |||||
| return False | |||||
| return True | |||||
| def _is_dict(item): | |||||
| """ | |||||
| Check whether the type is dict. | |||||
| """ | |||||
| return isinstance(item, dict) | |||||
| VALID_PARAMS_DICT = { | |||||
| "knn": { | |||||
| "n_neighbors": [_is_positive_int], | |||||
| "weights": [{"uniform", "distance"}], | |||||
| "algorithm": [{"auto", "ball_tree", "kd_tree", "brute"}], | |||||
| "leaf_size": [_is_positive_int], | |||||
| "p": [_is_positive_int], | |||||
| "metric": None, | |||||
| "metric_params": None, | |||||
| }, | |||||
| "lr": { | |||||
| "penalty": [{"l1", "l2", "elasticnet", "none"}], | |||||
| "dual": [{True, False}], | |||||
| "tol": [_is_positive_float], | |||||
| "C": [_is_positive_float], | |||||
| "fit_intercept": [{True, False}], | |||||
| "intercept_scaling": [_is_positive_float], | |||||
| "class_weight": [{"balanced", None}, _is_dict], | |||||
| "random_state": None, | |||||
| "solver": [{"newton-cg", "lbfgs", "liblinear", "sag", "saga"}] | |||||
| }, | |||||
| "mlp": { | |||||
| "hidden_layer_sizes": [_is_positive_int_tuple], | |||||
| "activation": [{"identity", "logistic", "tanh", "relu"}], | |||||
| "solver": {"lbfgs", "sgd", "adam"}, | |||||
| "alpha": [_is_positive_float], | |||||
| "batch_size": [{"auto"}, _is_positive_int], | |||||
| "learning_rate": [{"constant", "invscaling", "adaptive"}], | |||||
| "learning_rate_init": [_is_positive_float], | |||||
| "power_t": [_is_positive_float], | |||||
| "max_iter": [_is_positive_int], | |||||
| "shuffle": [{True, False}], | |||||
| "random_state": None, | |||||
| "tol": [_is_positive_float], | |||||
| "verbose": [{True, False}], | |||||
| "warm_start": [{True, False}], | |||||
| "momentum": [_is_positive_float], | |||||
| "nesterovs_momentum": [{True, False}], | |||||
| "early_stopping": [{True, False}], | |||||
| "validation_fraction": [_is_positive_float], | |||||
| "beta_1": [_is_positive_float], | |||||
| "beta_2": [_is_positive_float], | |||||
| "epsilon": [_is_positive_float], | |||||
| "n_iter_no_change": [_is_positive_int], | |||||
| "max_fun": [_is_positive_int] | |||||
| }, | |||||
| "rf": { | |||||
| "n_estimators": [_is_positive_int], | |||||
| "criterion": [{"gini", "entropy"}], | |||||
| "max_depth": [_is_positive_int], | |||||
| "min_samples_split": [_is_positive_float], | |||||
| "min_samples_leaf": [_is_positive_float], | |||||
| "min_weight_fraction_leaf": [_is_non_negative_float], | |||||
| "max_features": [{"auto", "sqrt", "log2", None}, _is_positive_float], | |||||
| "max_leaf_nodes": [_is_positive_int, {None}], | |||||
| "min_impurity_decrease": {_is_non_negative_float}, | |||||
| "min_impurity_split": [{None}, _is_positive_float], | |||||
| "bootstrap": [{True, False}], | |||||
| "oob_scroe": [{True, False}], | |||||
| "n_jobs": [_is_positive_int, {None}], | |||||
| "random_state": None, | |||||
| "verbose": [_is_non_negative_int], | |||||
| "warm_start": [{True, False}], | |||||
| "class_weight": None, | |||||
| "ccp_alpha": [_is_non_negative_float], | |||||
| "max_samples": [_is_positive_float] | |||||
| } | |||||
| } | |||||
| def _check_config(config_list, check_params): | |||||
| """ | |||||
| Verify that config_list is valid. | |||||
| Check_params is the valid value range of the parameter. | |||||
| """ | |||||
| if not isinstance(config_list, (list, tuple)): | |||||
| msg = "Type of parameter 'config_list' must be list, but got {}.".format(type(config_list)) | |||||
| LOGGER.error(TAG, msg) | |||||
| raise TypeError(msg) | |||||
| for config in config_list: | |||||
| if not isinstance(config, dict): | |||||
| msg = "Type of each config in config_list must be dict, but got {}.".format(type(config)) | |||||
| LOGGER.error(TAG, msg) | |||||
| raise TypeError(msg) | |||||
| if set(config.keys()) != {"params", "method"}: | |||||
| msg = "Keys of each config in config_list must be {}," \ | |||||
| "but got {}.".format({'method', 'params'}, set(config.keys())) | |||||
| LOGGER.error(TAG, msg) | |||||
| raise KeyError(msg) | |||||
| method = str.lower(config["method"]) | |||||
| params = config["params"] | |||||
| if method not in check_params.keys(): | |||||
| msg = "Method {} is not supported.".format(method) | |||||
| LOGGER.error(TAG, msg) | |||||
| raise ValueError(msg) | |||||
| if not params.keys() <= check_params[method].keys(): | |||||
| msg = "Params in method {} is not accepted, the parameters " \ | |||||
| "that can be set are {}.".format(method, set(check_params[method].keys())) | |||||
| LOGGER.error(TAG, msg) | |||||
| raise KeyError(msg) | |||||
| for param_key in params.keys(): | |||||
| param_value = params[param_key] | |||||
| candidate_values = check_params[method][param_key] | |||||
| if not isinstance(param_value, list): | |||||
| msg = "The parameter '{}' in method '{}' setting must within the range of " \ | |||||
| "changeable parameters.".format(param_key, method) | |||||
| LOGGER.error(TAG, msg) | |||||
| raise ValueError(msg) | |||||
| if candidate_values is None: | |||||
| continue | |||||
| for item_value in param_value: | |||||
| flag = False | |||||
| for candidate_value in candidate_values: | |||||
| if isinstance(candidate_value, set) and item_value in candidate_value: | |||||
| flag = True | |||||
| break | |||||
| elif candidate_value(item_value): | |||||
| flag = True | |||||
| break | |||||
| if not flag: | |||||
| msg = "Setting of parmeter {} in method {} is invalid".format(param_key, method) | |||||
| raise ValueError(msg) | |||||
| def check_config_params(config_list): | |||||
| """ | |||||
| External interfaces to verify attack config. | |||||
| """ | |||||
| _check_config(config_list, VALID_PARAMS_DICT) | |||||
| @@ -27,7 +27,7 @@ LOGGER = LogUtil.get_instance() | |||||
| TAG = "Attacker" | TAG = "Attacker" | ||||
| def _attack_knn(features, labels, param_grid): | |||||
| def _attack_knn(features, labels, param_grid, n_jobs): | |||||
| """ | """ | ||||
| Train and return a KNN model. | Train and return a KNN model. | ||||
| @@ -35,20 +35,21 @@ def _attack_knn(features, labels, param_grid): | |||||
| features (numpy.ndarray): Loss and logits characteristics of each sample. | features (numpy.ndarray): Loss and logits characteristics of each sample. | ||||
| labels (numpy.ndarray): Labels of each sample whether belongs to training set. | labels (numpy.ndarray): Labels of each sample whether belongs to training set. | ||||
| param_grid (dict): Setting of GridSearchCV. | param_grid (dict): Setting of GridSearchCV. | ||||
| n_jobs (int): Number of jobs run in parallel. -1 means using all processors, | |||||
| otherwise the value of n_jobs must be a positive integer. | |||||
| Returns: | Returns: | ||||
| sklearn.model_selection.GridSearchCV, trained model. | sklearn.model_selection.GridSearchCV, trained model. | ||||
| """ | """ | ||||
| knn_model = KNeighborsClassifier() | knn_model = KNeighborsClassifier() | ||||
| knn_model = GridSearchCV( | knn_model = GridSearchCV( | ||||
| knn_model, param_grid=param_grid, cv=3, n_jobs=1, iid=False, | |||||
| verbose=0, | |||||
| knn_model, param_grid=param_grid, cv=3, n_jobs=n_jobs, verbose=0, | |||||
| ) | ) | ||||
| knn_model.fit(X=features, y=labels) | knn_model.fit(X=features, y=labels) | ||||
| return knn_model | return knn_model | ||||
| def _attack_lr(features, labels, param_grid): | |||||
| def _attack_lr(features, labels, param_grid, n_jobs): | |||||
| """ | """ | ||||
| Train and return a LR model. | Train and return a LR model. | ||||
| @@ -56,20 +57,21 @@ def _attack_lr(features, labels, param_grid): | |||||
| features (numpy.ndarray): Loss and logits characteristics of each sample. | features (numpy.ndarray): Loss and logits characteristics of each sample. | ||||
| labels (numpy.ndarray): Labels of each sample whether belongs to training set. | labels (numpy.ndarray): Labels of each sample whether belongs to training set. | ||||
| param_grid (dict): Setting of GridSearchCV. | param_grid (dict): Setting of GridSearchCV. | ||||
| n_jobs (int): Number of jobs run in parallel. -1 means using all processors, | |||||
| otherwise the value of n_jobs must be a positive integer. | |||||
| Returns: | Returns: | ||||
| sklearn.model_selection.GridSearchCV, trained model. | sklearn.model_selection.GridSearchCV, trained model. | ||||
| """ | """ | ||||
| lr_model = LogisticRegression(C=1.0, penalty="l2", max_iter=1000) | |||||
| lr_model = LogisticRegression(C=1.0, penalty="l2", max_iter=300) | |||||
| lr_model = GridSearchCV( | lr_model = GridSearchCV( | ||||
| lr_model, param_grid=param_grid, cv=3, n_jobs=1, iid=False, | |||||
| verbose=0, | |||||
| lr_model, param_grid=param_grid, cv=3, n_jobs=n_jobs, verbose=0, | |||||
| ) | ) | ||||
| lr_model.fit(X=features, y=labels) | lr_model.fit(X=features, y=labels) | ||||
| return lr_model | return lr_model | ||||
| def _attack_mlpc(features, labels, param_grid): | |||||
| def _attack_mlpc(features, labels, param_grid, n_jobs): | |||||
| """ | """ | ||||
| Train and return a MLPC model. | Train and return a MLPC model. | ||||
| @@ -77,20 +79,21 @@ def _attack_mlpc(features, labels, param_grid): | |||||
| features (numpy.ndarray): Loss and logits characteristics of each sample. | features (numpy.ndarray): Loss and logits characteristics of each sample. | ||||
| labels (numpy.ndarray): Labels of each sample whether belongs to training set. | labels (numpy.ndarray): Labels of each sample whether belongs to training set. | ||||
| param_grid (dict): Setting of GridSearchCV. | param_grid (dict): Setting of GridSearchCV. | ||||
| n_jobs (int): Number of jobs run in parallel. -1 means using all processors, | |||||
| otherwise the value of n_jobs must be a positive integer. | |||||
| Returns: | Returns: | ||||
| sklearn.model_selection.GridSearchCV, trained model. | sklearn.model_selection.GridSearchCV, trained model. | ||||
| """ | """ | ||||
| mlpc_model = MLPClassifier(random_state=1, max_iter=300) | mlpc_model = MLPClassifier(random_state=1, max_iter=300) | ||||
| mlpc_model = GridSearchCV( | mlpc_model = GridSearchCV( | ||||
| mlpc_model, param_grid=param_grid, cv=3, n_jobs=1, iid=False, | |||||
| verbose=0, | |||||
| mlpc_model, param_grid=param_grid, cv=3, n_jobs=n_jobs, verbose=0, | |||||
| ) | ) | ||||
| mlpc_model.fit(features, labels) | mlpc_model.fit(features, labels) | ||||
| return mlpc_model | return mlpc_model | ||||
| def _attack_rf(features, labels, random_grid): | |||||
| def _attack_rf(features, labels, random_grid, n_jobs): | |||||
| """ | """ | ||||
| Train and return a RF model. | Train and return a RF model. | ||||
| @@ -98,20 +101,22 @@ def _attack_rf(features, labels, random_grid): | |||||
| features (numpy.ndarray): Loss and logits characteristics of each sample. | features (numpy.ndarray): Loss and logits characteristics of each sample. | ||||
| labels (numpy.ndarray): Labels of each sample whether belongs to training set. | labels (numpy.ndarray): Labels of each sample whether belongs to training set. | ||||
| random_grid (dict): Setting of RandomizedSearchCV. | random_grid (dict): Setting of RandomizedSearchCV. | ||||
| n_jobs (int): Number of jobs run in parallel. -1 means using all processors, | |||||
| otherwise the value of n_jobs must be a positive integer. | |||||
| Returns: | Returns: | ||||
| sklearn.model_selection.RandomizedSearchCV, trained model. | sklearn.model_selection.RandomizedSearchCV, trained model. | ||||
| """ | """ | ||||
| rf_model = RandomForestClassifier(max_depth=2, random_state=0) | rf_model = RandomForestClassifier(max_depth=2, random_state=0) | ||||
| rf_model = RandomizedSearchCV( | rf_model = RandomizedSearchCV( | ||||
| rf_model, param_distributions=random_grid, n_iter=7, cv=3, n_jobs=1, | |||||
| iid=False, verbose=0, | |||||
| rf_model, param_distributions=random_grid, n_iter=7, cv=3, n_jobs=n_jobs, | |||||
| verbose=0, | |||||
| ) | ) | ||||
| rf_model.fit(features, labels) | rf_model.fit(features, labels) | ||||
| return rf_model | return rf_model | ||||
| def get_attack_model(features, labels, config): | |||||
| def get_attack_model(features, labels, config, n_jobs=-1): | |||||
| """ | """ | ||||
| Get trained attack model specify by config. | Get trained attack model specify by config. | ||||
| @@ -123,6 +128,8 @@ def get_attack_model(features, labels, config): | |||||
| params of each method must within the range of changeable parameters. | params of each method must within the range of changeable parameters. | ||||
| Tips of params implement can be found in | Tips of params implement can be found in | ||||
| "https://scikit-learn.org/0.16/modules/generated/sklearn.grid_search.GridSearchCV.html". | "https://scikit-learn.org/0.16/modules/generated/sklearn.grid_search.GridSearchCV.html". | ||||
| n_jobs (int): Number of jobs run in parallel. -1 means using all processors, | |||||
| otherwise the value of n_jobs must be a positive integer. | |||||
| Returns: | Returns: | ||||
| sklearn.BaseEstimator, trained model specify by config["method"]. | sklearn.BaseEstimator, trained model specify by config["method"]. | ||||
| @@ -136,13 +143,13 @@ def get_attack_model(features, labels, config): | |||||
| method = str.lower(config["method"]) | method = str.lower(config["method"]) | ||||
| if method == "knn": | if method == "knn": | ||||
| return _attack_knn(features, labels, config["params"]) | |||||
| return _attack_knn(features, labels, config["params"], n_jobs) | |||||
| if method == "lr": | if method == "lr": | ||||
| return _attack_lr(features, labels, config["params"]) | |||||
| return _attack_lr(features, labels, config["params"], n_jobs) | |||||
| if method == "mlp": | if method == "mlp": | ||||
| return _attack_mlpc(features, labels, config["params"]) | |||||
| return _attack_mlpc(features, labels, config["params"], n_jobs) | |||||
| if method == "rf": | if method == "rf": | ||||
| return _attack_rf(features, labels, config["params"]) | |||||
| return _attack_rf(features, labels, config["params"], n_jobs) | |||||
| msg = "Method {} is not supported.".format(config["method"]) | msg = "Method {} is not supported.".format(config["method"]) | ||||
| LOGGER.error(TAG, msg) | LOGGER.error(TAG, msg) | ||||
| @@ -15,14 +15,16 @@ | |||||
| Membership Inference | Membership Inference | ||||
| """ | """ | ||||
| from multiprocessing import cpu_count | |||||
| import numpy as np | import numpy as np | ||||
| import mindspore as ms | import mindspore as ms | ||||
| from mindspore.train import Model | from mindspore.train import Model | ||||
| from mindspore.dataset.engine import Dataset | from mindspore.dataset.engine import Dataset | ||||
| from mindspore import Tensor | from mindspore import Tensor | ||||
| from mindarmour.diff_privacy.evaluation.attacker import get_attack_model | |||||
| from mindarmour.utils.logger import LogUtil | from mindarmour.utils.logger import LogUtil | ||||
| from .attacker import get_attack_model | |||||
| from ._check_config import check_config_params | |||||
| LOGGER = LogUtil.get_instance() | LOGGER = LogUtil.get_instance() | ||||
| TAG = "MembershipInference" | TAG = "MembershipInference" | ||||
| @@ -101,13 +103,15 @@ class MembershipInference: | |||||
| Args: | Args: | ||||
| model (Model): Target model. | model (Model): Target model. | ||||
| n_jobs (int): Number of jobs run in parallel. -1 means using all processors, | |||||
| otherwise the value of n_jobs must be a positive integer. | |||||
| Examples: | Examples: | ||||
| >>> train_1, train_2 are non-overlapping datasets from training dataset of target model. | >>> train_1, train_2 are non-overlapping datasets from training dataset of target model. | ||||
| >>> test_1, test_2 are non-overlapping datasets from test dataset of target model. | >>> test_1, test_2 are non-overlapping datasets from test dataset of target model. | ||||
| >>> We use train_1, test_1 to train attack model, and use train_2, test_2 to evaluate attack model. | >>> We use train_1, test_1 to train attack model, and use train_2, test_2 to evaluate attack model. | ||||
| >>> model = Model(network=net, loss_fn=loss, optimizer=opt, metrics={'acc', 'loss'}) | >>> model = Model(network=net, loss_fn=loss, optimizer=opt, metrics={'acc', 'loss'}) | ||||
| >>> inference_model = MembershipInference(model) | |||||
| >>> inference_model = MembershipInference(model, n_jobs=-1) | |||||
| >>> config = [{"method": "KNN", "params": {"n_neighbors": [3, 5, 7]}}] | >>> config = [{"method": "KNN", "params": {"n_neighbors": [3, 5, 7]}}] | ||||
| >>> inference_model.train(train_1, test_1, config) | >>> inference_model.train(train_1, test_1, config) | ||||
| >>> metrics = ["precision", "recall", "accuracy"] | >>> metrics = ["precision", "recall", "accuracy"] | ||||
| @@ -115,15 +119,26 @@ class MembershipInference: | |||||
| Raises: | Raises: | ||||
| TypeError: If type of model is not mindspore.train.Model. | TypeError: If type of model is not mindspore.train.Model. | ||||
| TypeError: If type of n_jobs is not int. | |||||
| ValueError: The value of n_jobs is neither -1 nor a positive integer. | |||||
| """ | """ | ||||
| def __init__(self, model): | |||||
| def __init__(self, model, n_jobs=-1): | |||||
| if not isinstance(model, Model): | if not isinstance(model, Model): | ||||
| msg = "Type of parameter 'model' must be Model, but got {}.".format(type(model)) | msg = "Type of parameter 'model' must be Model, but got {}.".format(type(model)) | ||||
| LOGGER.error(TAG, msg) | LOGGER.error(TAG, msg) | ||||
| raise TypeError(msg) | raise TypeError(msg) | ||||
| if not isinstance(n_jobs, int): | |||||
| msg = "Type of parameter 'n_jobs' must be int, but got {}".format(type(n_jobs)) | |||||
| LOGGER.error(TAG, msg) | |||||
| raise TypeError(msg) | |||||
| if not (n_jobs == -1 or n_jobs > 0): | |||||
| msg = "Value of n_jobs must be either -1 or positive integer, but got {}.".format(n_jobs) | |||||
| LOGGER.error(TAG, msg) | |||||
| raise ValueError(msg) | |||||
| self.model = model | self.model = model | ||||
| self.n_jobs = min(n_jobs, cpu_count()) | |||||
| self.method_list = ["knn", "lr", "mlp", "rf"] | self.method_list = ["knn", "lr", "mlp", "rf"] | ||||
| self.attack_list = [] | self.attack_list = [] | ||||
| @@ -162,24 +177,13 @@ class MembershipInference: | |||||
| LOGGER.error(TAG, msg) | LOGGER.error(TAG, msg) | ||||
| raise TypeError(msg) | raise TypeError(msg) | ||||
| for config in attack_config: | |||||
| if not isinstance(config, dict): | |||||
| msg = "Type of each config in 'attack_config' must be dict, but got {}.".format(type(config)) | |||||
| LOGGER.error(TAG, msg) | |||||
| raise TypeError(msg) | |||||
| if {"params", "method"} != set(config.keys()): | |||||
| msg = "Each config in attack_config must have keys 'method' and 'params'," \ | |||||
| "but your key value is {}.".format(set(config.keys())) | |||||
| LOGGER.error(TAG, msg) | |||||
| raise KeyError(msg) | |||||
| if str.lower(config["method"]) not in self.method_list: | |||||
| msg = "Method {} is not support.".format(config["method"]) | |||||
| LOGGER.error(TAG, msg) | |||||
| raise ValueError(msg) | |||||
| check_config_params(attack_config) # Verify attack config. | |||||
| features, labels = self._transform(dataset_train, dataset_test) | features, labels = self._transform(dataset_train, dataset_test) | ||||
| for config in attack_config: | for config in attack_config: | ||||
| self.attack_list.append(get_attack_model(features, labels, config)) | |||||
| self.attack_list.append(get_attack_model(features, labels, config, n_jobs=self.n_jobs)) | |||||
| def eval(self, dataset_train, dataset_test, metrics): | def eval(self, dataset_train, dataset_test, metrics): | ||||
| """ | """ | ||||
| @@ -35,7 +35,7 @@ def test_get_knn_model(): | |||||
| "n_neighbors": [3], | "n_neighbors": [3], | ||||
| } | } | ||||
| } | } | ||||
| knn_attacker = get_attack_model(features, labels, config_knn) | |||||
| knn_attacker = get_attack_model(features, labels, config_knn, -1) | |||||
| pred = knn_attacker.predict(features) | pred = knn_attacker.predict(features) | ||||
| assert pred is not None | assert pred is not None | ||||
| @@ -54,7 +54,7 @@ def test_get_lr_model(): | |||||
| "C": np.logspace(-4, 2, 10), | "C": np.logspace(-4, 2, 10), | ||||
| } | } | ||||
| } | } | ||||
| lr_attacker = get_attack_model(features, labels, config_lr) | |||||
| lr_attacker = get_attack_model(features, labels, config_lr, -1) | |||||
| pred = lr_attacker.predict(features) | pred = lr_attacker.predict(features) | ||||
| assert pred is not None | assert pred is not None | ||||
| @@ -75,7 +75,7 @@ def test_get_mlp_model(): | |||||
| "alpha": [0.0001, 0.001, 0.01], | "alpha": [0.0001, 0.001, 0.01], | ||||
| } | } | ||||
| } | } | ||||
| mlpc_attacker = get_attack_model(features, labels, config_mlpc) | |||||
| mlpc_attacker = get_attack_model(features, labels, config_mlpc, -1) | |||||
| pred = mlpc_attacker.predict(features) | pred = mlpc_attacker.predict(features) | ||||
| assert pred is not None | assert pred is not None | ||||
| @@ -98,6 +98,6 @@ def test_get_rf_model(): | |||||
| "min_samples_leaf": [1, 2, 4], | "min_samples_leaf": [1, 2, 4], | ||||
| } | } | ||||
| } | } | ||||
| rf_attacker = get_attack_model(features, labels, config_rf) | |||||
| rf_attacker = get_attack_model(features, labels, config_rf, -1) | |||||
| pred = rf_attacker.predict(features) | pred = rf_attacker.predict(features) | ||||
| assert pred is not None | assert pred is not None | ||||
| @@ -24,6 +24,7 @@ import numpy as np | |||||
| import mindspore.dataset as ds | import mindspore.dataset as ds | ||||
| from mindspore import nn | from mindspore import nn | ||||
| from mindspore.train import Model | from mindspore.train import Model | ||||
| import mindspore.context as context | |||||
| from mindarmour.diff_privacy.evaluation.membership_inference import MembershipInference | from mindarmour.diff_privacy.evaluation.membership_inference import MembershipInference | ||||
| @@ -31,6 +32,8 @@ sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../")) | |||||
| from defenses.mock_net import Net | from defenses.mock_net import Net | ||||
| context.set_context(mode=context.GRAPH_MODE) | |||||
| def dataset_generator(batch_size, batches): | def dataset_generator(batch_size, batches): | ||||
| """mock training data.""" | """mock training data.""" | ||||
| data = np.random.randn(batches*batch_size, 1, 32, 32).astype( | data = np.random.randn(batches*batch_size, 1, 32, 32).astype( | ||||
| @@ -51,7 +54,7 @@ def test_get_membership_inference_object(): | |||||
| loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True) | loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True) | ||||
| opt = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) | opt = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) | ||||
| model = Model(network=net, loss_fn=loss, optimizer=opt) | model = Model(network=net, loss_fn=loss, optimizer=opt) | ||||
| inference_model = MembershipInference(model) | |||||
| inference_model = MembershipInference(model, -1) | |||||
| assert isinstance(inference_model, MembershipInference) | assert isinstance(inference_model, MembershipInference) | ||||
| @@ -65,7 +68,7 @@ def test_membership_inference_object_train(): | |||||
| loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True) | loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True) | ||||
| opt = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) | opt = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) | ||||
| model = Model(network=net, loss_fn=loss, optimizer=opt) | model = Model(network=net, loss_fn=loss, optimizer=opt) | ||||
| inference_model = MembershipInference(model) | |||||
| inference_model = MembershipInference(model, -1) | |||||
| assert isinstance(inference_model, MembershipInference) | assert isinstance(inference_model, MembershipInference) | ||||
| config = [{ | config = [{ | ||||
| @@ -95,7 +98,7 @@ def test_membership_inference_eval(): | |||||
| loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True) | loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True) | ||||
| opt = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) | opt = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) | ||||
| model = Model(network=net, loss_fn=loss, optimizer=opt) | model = Model(network=net, loss_fn=loss, optimizer=opt) | ||||
| inference_model = MembershipInference(model) | |||||
| inference_model = MembershipInference(model, -1) | |||||
| assert isinstance(inference_model, MembershipInference) | assert isinstance(inference_model, MembershipInference) | ||||
| batch_size = 16 | batch_size = 16 | ||||