Merge pull request !108 from liuluobin/mastertags/v1.0.0
@@ -0,0 +1,220 @@ | |||||
# Copyright 2020 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
""" | |||||
Verify attack config | |||||
""" | |||||
from mindarmour.utils.logger import LogUtil | |||||
LOGGER = LogUtil.get_instance() | |||||
TAG = "check_params" | |||||
def _is_positive_int(item): | |||||
""" | |||||
Verify that the value is a positive integer. | |||||
""" | |||||
if not isinstance(item, int) or item <= 0: | |||||
return False | |||||
return True | |||||
def _is_non_negative_int(item): | |||||
""" | |||||
Verify that the value is a non-negative integer. | |||||
""" | |||||
if not isinstance(item, int) or item < 0: | |||||
return False | |||||
return True | |||||
def _is_positive_float(item): | |||||
""" | |||||
Verify that value is a positive number. | |||||
""" | |||||
if not isinstance(item, (int, float)) or item <= 0: | |||||
return False | |||||
return True | |||||
def _is_non_negative_float(item): | |||||
""" | |||||
Verify that value is a non-negative number. | |||||
""" | |||||
if not isinstance(item, (int, float)) or item < 0: | |||||
return False | |||||
return True | |||||
def _is_positive_int_tuple(item): | |||||
""" | |||||
Verify that the input parameter is a positive integer tuple. | |||||
""" | |||||
if not isinstance(item, tuple): | |||||
return False | |||||
for i in item: | |||||
if not _is_positive_int(i): | |||||
return False | |||||
return True | |||||
def _is_dict(item): | |||||
""" | |||||
Check whether the type is dict. | |||||
""" | |||||
return isinstance(item, dict) | |||||
VALID_PARAMS_DICT = { | |||||
"knn": { | |||||
"n_neighbors": [_is_positive_int], | |||||
"weights": [{"uniform", "distance"}], | |||||
"algorithm": [{"auto", "ball_tree", "kd_tree", "brute"}], | |||||
"leaf_size": [_is_positive_int], | |||||
"p": [_is_positive_int], | |||||
"metric": None, | |||||
"metric_params": None, | |||||
}, | |||||
"lr": { | |||||
"penalty": [{"l1", "l2", "elasticnet", "none"}], | |||||
"dual": [{True, False}], | |||||
"tol": [_is_positive_float], | |||||
"C": [_is_positive_float], | |||||
"fit_intercept": [{True, False}], | |||||
"intercept_scaling": [_is_positive_float], | |||||
"class_weight": [{"balanced", None}, _is_dict], | |||||
"random_state": None, | |||||
"solver": [{"newton-cg", "lbfgs", "liblinear", "sag", "saga"}] | |||||
}, | |||||
"mlp": { | |||||
"hidden_layer_sizes": [_is_positive_int_tuple], | |||||
"activation": [{"identity", "logistic", "tanh", "relu"}], | |||||
"solver": {"lbfgs", "sgd", "adam"}, | |||||
"alpha": [_is_positive_float], | |||||
"batch_size": [{"auto"}, _is_positive_int], | |||||
"learning_rate": [{"constant", "invscaling", "adaptive"}], | |||||
"learning_rate_init": [_is_positive_float], | |||||
"power_t": [_is_positive_float], | |||||
"max_iter": [_is_positive_int], | |||||
"shuffle": [{True, False}], | |||||
"random_state": None, | |||||
"tol": [_is_positive_float], | |||||
"verbose": [{True, False}], | |||||
"warm_start": [{True, False}], | |||||
"momentum": [_is_positive_float], | |||||
"nesterovs_momentum": [{True, False}], | |||||
"early_stopping": [{True, False}], | |||||
"validation_fraction": [_is_positive_float], | |||||
"beta_1": [_is_positive_float], | |||||
"beta_2": [_is_positive_float], | |||||
"epsilon": [_is_positive_float], | |||||
"n_iter_no_change": [_is_positive_int], | |||||
"max_fun": [_is_positive_int] | |||||
}, | |||||
"rf": { | |||||
"n_estimators": [_is_positive_int], | |||||
"criterion": [{"gini", "entropy"}], | |||||
"max_depth": [_is_positive_int], | |||||
"min_samples_split": [_is_positive_float], | |||||
"min_samples_leaf": [_is_positive_float], | |||||
"min_weight_fraction_leaf": [_is_non_negative_float], | |||||
"max_features": [{"auto", "sqrt", "log2", None}, _is_positive_float], | |||||
"max_leaf_nodes": [_is_positive_int, {None}], | |||||
"min_impurity_decrease": {_is_non_negative_float}, | |||||
"min_impurity_split": [{None}, _is_positive_float], | |||||
"bootstrap": [{True, False}], | |||||
"oob_scroe": [{True, False}], | |||||
"n_jobs": [_is_positive_int, {None}], | |||||
"random_state": None, | |||||
"verbose": [_is_non_negative_int], | |||||
"warm_start": [{True, False}], | |||||
"class_weight": None, | |||||
"ccp_alpha": [_is_non_negative_float], | |||||
"max_samples": [_is_positive_float] | |||||
} | |||||
} | |||||
def _check_config(config_list, check_params): | |||||
""" | |||||
Verify that config_list is valid. | |||||
Check_params is the valid value range of the parameter. | |||||
""" | |||||
if not isinstance(config_list, (list, tuple)): | |||||
msg = "Type of parameter 'config_list' must be list, but got {}.".format(type(config_list)) | |||||
LOGGER.error(TAG, msg) | |||||
raise TypeError(msg) | |||||
for config in config_list: | |||||
if not isinstance(config, dict): | |||||
msg = "Type of each config in config_list must be dict, but got {}.".format(type(config)) | |||||
LOGGER.error(TAG, msg) | |||||
raise TypeError(msg) | |||||
if set(config.keys()) != {"params", "method"}: | |||||
msg = "Keys of each config in config_list must be {}," \ | |||||
"but got {}.".format({'method', 'params'}, set(config.keys())) | |||||
LOGGER.error(TAG, msg) | |||||
raise KeyError(msg) | |||||
method = str.lower(config["method"]) | |||||
params = config["params"] | |||||
if method not in check_params.keys(): | |||||
msg = "Method {} is not supported.".format(method) | |||||
LOGGER.error(TAG, msg) | |||||
raise ValueError(msg) | |||||
if not params.keys() <= check_params[method].keys(): | |||||
msg = "Params in method {} is not accepted, the parameters " \ | |||||
"that can be set are {}.".format(method, set(check_params[method].keys())) | |||||
LOGGER.error(TAG, msg) | |||||
raise KeyError(msg) | |||||
for param_key in params.keys(): | |||||
param_value = params[param_key] | |||||
candidate_values = check_params[method][param_key] | |||||
if not isinstance(param_value, list): | |||||
msg = "The parameter '{}' in method '{}' setting must within the range of " \ | |||||
"changeable parameters.".format(param_key, method) | |||||
LOGGER.error(TAG, msg) | |||||
raise ValueError(msg) | |||||
if candidate_values is None: | |||||
continue | |||||
for item_value in param_value: | |||||
flag = False | |||||
for candidate_value in candidate_values: | |||||
if isinstance(candidate_value, set) and item_value in candidate_value: | |||||
flag = True | |||||
break | |||||
elif candidate_value(item_value): | |||||
flag = True | |||||
break | |||||
if not flag: | |||||
msg = "Setting of parmeter {} in method {} is invalid".format(param_key, method) | |||||
raise ValueError(msg) | |||||
def check_config_params(config_list): | |||||
""" | |||||
External interfaces to verify attack config. | |||||
""" | |||||
_check_config(config_list, VALID_PARAMS_DICT) |
@@ -27,7 +27,7 @@ LOGGER = LogUtil.get_instance() | |||||
TAG = "Attacker" | TAG = "Attacker" | ||||
def _attack_knn(features, labels, param_grid): | |||||
def _attack_knn(features, labels, param_grid, n_jobs): | |||||
""" | """ | ||||
Train and return a KNN model. | Train and return a KNN model. | ||||
@@ -35,20 +35,21 @@ def _attack_knn(features, labels, param_grid): | |||||
features (numpy.ndarray): Loss and logits characteristics of each sample. | features (numpy.ndarray): Loss and logits characteristics of each sample. | ||||
labels (numpy.ndarray): Labels of each sample whether belongs to training set. | labels (numpy.ndarray): Labels of each sample whether belongs to training set. | ||||
param_grid (dict): Setting of GridSearchCV. | param_grid (dict): Setting of GridSearchCV. | ||||
n_jobs (int): Number of jobs run in parallel. -1 means using all processors, | |||||
otherwise the value of n_jobs must be a positive integer. | |||||
Returns: | Returns: | ||||
sklearn.model_selection.GridSearchCV, trained model. | sklearn.model_selection.GridSearchCV, trained model. | ||||
""" | """ | ||||
knn_model = KNeighborsClassifier() | knn_model = KNeighborsClassifier() | ||||
knn_model = GridSearchCV( | knn_model = GridSearchCV( | ||||
knn_model, param_grid=param_grid, cv=3, n_jobs=1, iid=False, | |||||
verbose=0, | |||||
knn_model, param_grid=param_grid, cv=3, n_jobs=n_jobs, verbose=0, | |||||
) | ) | ||||
knn_model.fit(X=features, y=labels) | knn_model.fit(X=features, y=labels) | ||||
return knn_model | return knn_model | ||||
def _attack_lr(features, labels, param_grid): | |||||
def _attack_lr(features, labels, param_grid, n_jobs): | |||||
""" | """ | ||||
Train and return a LR model. | Train and return a LR model. | ||||
@@ -56,20 +57,21 @@ def _attack_lr(features, labels, param_grid): | |||||
features (numpy.ndarray): Loss and logits characteristics of each sample. | features (numpy.ndarray): Loss and logits characteristics of each sample. | ||||
labels (numpy.ndarray): Labels of each sample whether belongs to training set. | labels (numpy.ndarray): Labels of each sample whether belongs to training set. | ||||
param_grid (dict): Setting of GridSearchCV. | param_grid (dict): Setting of GridSearchCV. | ||||
n_jobs (int): Number of jobs run in parallel. -1 means using all processors, | |||||
otherwise the value of n_jobs must be a positive integer. | |||||
Returns: | Returns: | ||||
sklearn.model_selection.GridSearchCV, trained model. | sklearn.model_selection.GridSearchCV, trained model. | ||||
""" | """ | ||||
lr_model = LogisticRegression(C=1.0, penalty="l2", max_iter=1000) | |||||
lr_model = LogisticRegression(C=1.0, penalty="l2", max_iter=300) | |||||
lr_model = GridSearchCV( | lr_model = GridSearchCV( | ||||
lr_model, param_grid=param_grid, cv=3, n_jobs=1, iid=False, | |||||
verbose=0, | |||||
lr_model, param_grid=param_grid, cv=3, n_jobs=n_jobs, verbose=0, | |||||
) | ) | ||||
lr_model.fit(X=features, y=labels) | lr_model.fit(X=features, y=labels) | ||||
return lr_model | return lr_model | ||||
def _attack_mlpc(features, labels, param_grid): | |||||
def _attack_mlpc(features, labels, param_grid, n_jobs): | |||||
""" | """ | ||||
Train and return a MLPC model. | Train and return a MLPC model. | ||||
@@ -77,20 +79,21 @@ def _attack_mlpc(features, labels, param_grid): | |||||
features (numpy.ndarray): Loss and logits characteristics of each sample. | features (numpy.ndarray): Loss and logits characteristics of each sample. | ||||
labels (numpy.ndarray): Labels of each sample whether belongs to training set. | labels (numpy.ndarray): Labels of each sample whether belongs to training set. | ||||
param_grid (dict): Setting of GridSearchCV. | param_grid (dict): Setting of GridSearchCV. | ||||
n_jobs (int): Number of jobs run in parallel. -1 means using all processors, | |||||
otherwise the value of n_jobs must be a positive integer. | |||||
Returns: | Returns: | ||||
sklearn.model_selection.GridSearchCV, trained model. | sklearn.model_selection.GridSearchCV, trained model. | ||||
""" | """ | ||||
mlpc_model = MLPClassifier(random_state=1, max_iter=300) | mlpc_model = MLPClassifier(random_state=1, max_iter=300) | ||||
mlpc_model = GridSearchCV( | mlpc_model = GridSearchCV( | ||||
mlpc_model, param_grid=param_grid, cv=3, n_jobs=1, iid=False, | |||||
verbose=0, | |||||
mlpc_model, param_grid=param_grid, cv=3, n_jobs=n_jobs, verbose=0, | |||||
) | ) | ||||
mlpc_model.fit(features, labels) | mlpc_model.fit(features, labels) | ||||
return mlpc_model | return mlpc_model | ||||
def _attack_rf(features, labels, random_grid): | |||||
def _attack_rf(features, labels, random_grid, n_jobs): | |||||
""" | """ | ||||
Train and return a RF model. | Train and return a RF model. | ||||
@@ -98,20 +101,22 @@ def _attack_rf(features, labels, random_grid): | |||||
features (numpy.ndarray): Loss and logits characteristics of each sample. | features (numpy.ndarray): Loss and logits characteristics of each sample. | ||||
labels (numpy.ndarray): Labels of each sample whether belongs to training set. | labels (numpy.ndarray): Labels of each sample whether belongs to training set. | ||||
random_grid (dict): Setting of RandomizedSearchCV. | random_grid (dict): Setting of RandomizedSearchCV. | ||||
n_jobs (int): Number of jobs run in parallel. -1 means using all processors, | |||||
otherwise the value of n_jobs must be a positive integer. | |||||
Returns: | Returns: | ||||
sklearn.model_selection.RandomizedSearchCV, trained model. | sklearn.model_selection.RandomizedSearchCV, trained model. | ||||
""" | """ | ||||
rf_model = RandomForestClassifier(max_depth=2, random_state=0) | rf_model = RandomForestClassifier(max_depth=2, random_state=0) | ||||
rf_model = RandomizedSearchCV( | rf_model = RandomizedSearchCV( | ||||
rf_model, param_distributions=random_grid, n_iter=7, cv=3, n_jobs=1, | |||||
iid=False, verbose=0, | |||||
rf_model, param_distributions=random_grid, n_iter=7, cv=3, n_jobs=n_jobs, | |||||
verbose=0, | |||||
) | ) | ||||
rf_model.fit(features, labels) | rf_model.fit(features, labels) | ||||
return rf_model | return rf_model | ||||
def get_attack_model(features, labels, config): | |||||
def get_attack_model(features, labels, config, n_jobs=-1): | |||||
""" | """ | ||||
Get trained attack model specify by config. | Get trained attack model specify by config. | ||||
@@ -123,6 +128,8 @@ def get_attack_model(features, labels, config): | |||||
params of each method must within the range of changeable parameters. | params of each method must within the range of changeable parameters. | ||||
Tips of params implement can be found in | Tips of params implement can be found in | ||||
"https://scikit-learn.org/0.16/modules/generated/sklearn.grid_search.GridSearchCV.html". | "https://scikit-learn.org/0.16/modules/generated/sklearn.grid_search.GridSearchCV.html". | ||||
n_jobs (int): Number of jobs run in parallel. -1 means using all processors, | |||||
otherwise the value of n_jobs must be a positive integer. | |||||
Returns: | Returns: | ||||
sklearn.BaseEstimator, trained model specify by config["method"]. | sklearn.BaseEstimator, trained model specify by config["method"]. | ||||
@@ -136,13 +143,13 @@ def get_attack_model(features, labels, config): | |||||
method = str.lower(config["method"]) | method = str.lower(config["method"]) | ||||
if method == "knn": | if method == "knn": | ||||
return _attack_knn(features, labels, config["params"]) | |||||
return _attack_knn(features, labels, config["params"], n_jobs) | |||||
if method == "lr": | if method == "lr": | ||||
return _attack_lr(features, labels, config["params"]) | |||||
return _attack_lr(features, labels, config["params"], n_jobs) | |||||
if method == "mlp": | if method == "mlp": | ||||
return _attack_mlpc(features, labels, config["params"]) | |||||
return _attack_mlpc(features, labels, config["params"], n_jobs) | |||||
if method == "rf": | if method == "rf": | ||||
return _attack_rf(features, labels, config["params"]) | |||||
return _attack_rf(features, labels, config["params"], n_jobs) | |||||
msg = "Method {} is not supported.".format(config["method"]) | msg = "Method {} is not supported.".format(config["method"]) | ||||
LOGGER.error(TAG, msg) | LOGGER.error(TAG, msg) | ||||
@@ -15,14 +15,16 @@ | |||||
Membership Inference | Membership Inference | ||||
""" | """ | ||||
from multiprocessing import cpu_count | |||||
import numpy as np | import numpy as np | ||||
import mindspore as ms | import mindspore as ms | ||||
from mindspore.train import Model | from mindspore.train import Model | ||||
from mindspore.dataset.engine import Dataset | from mindspore.dataset.engine import Dataset | ||||
from mindspore import Tensor | from mindspore import Tensor | ||||
from mindarmour.diff_privacy.evaluation.attacker import get_attack_model | |||||
from mindarmour.utils.logger import LogUtil | from mindarmour.utils.logger import LogUtil | ||||
from .attacker import get_attack_model | |||||
from ._check_config import check_config_params | |||||
LOGGER = LogUtil.get_instance() | LOGGER = LogUtil.get_instance() | ||||
TAG = "MembershipInference" | TAG = "MembershipInference" | ||||
@@ -101,13 +103,15 @@ class MembershipInference: | |||||
Args: | Args: | ||||
model (Model): Target model. | model (Model): Target model. | ||||
n_jobs (int): Number of jobs run in parallel. -1 means using all processors, | |||||
otherwise the value of n_jobs must be a positive integer. | |||||
Examples: | Examples: | ||||
>>> train_1, train_2 are non-overlapping datasets from training dataset of target model. | >>> train_1, train_2 are non-overlapping datasets from training dataset of target model. | ||||
>>> test_1, test_2 are non-overlapping datasets from test dataset of target model. | >>> test_1, test_2 are non-overlapping datasets from test dataset of target model. | ||||
>>> We use train_1, test_1 to train attack model, and use train_2, test_2 to evaluate attack model. | >>> We use train_1, test_1 to train attack model, and use train_2, test_2 to evaluate attack model. | ||||
>>> model = Model(network=net, loss_fn=loss, optimizer=opt, metrics={'acc', 'loss'}) | >>> model = Model(network=net, loss_fn=loss, optimizer=opt, metrics={'acc', 'loss'}) | ||||
>>> inference_model = MembershipInference(model) | |||||
>>> inference_model = MembershipInference(model, n_jobs=-1) | |||||
>>> config = [{"method": "KNN", "params": {"n_neighbors": [3, 5, 7]}}] | >>> config = [{"method": "KNN", "params": {"n_neighbors": [3, 5, 7]}}] | ||||
>>> inference_model.train(train_1, test_1, config) | >>> inference_model.train(train_1, test_1, config) | ||||
>>> metrics = ["precision", "recall", "accuracy"] | >>> metrics = ["precision", "recall", "accuracy"] | ||||
@@ -115,15 +119,26 @@ class MembershipInference: | |||||
Raises: | Raises: | ||||
TypeError: If type of model is not mindspore.train.Model. | TypeError: If type of model is not mindspore.train.Model. | ||||
TypeError: If type of n_jobs is not int. | |||||
ValueError: The value of n_jobs is neither -1 nor a positive integer. | |||||
""" | """ | ||||
def __init__(self, model): | |||||
def __init__(self, model, n_jobs=-1): | |||||
if not isinstance(model, Model): | if not isinstance(model, Model): | ||||
msg = "Type of parameter 'model' must be Model, but got {}.".format(type(model)) | msg = "Type of parameter 'model' must be Model, but got {}.".format(type(model)) | ||||
LOGGER.error(TAG, msg) | LOGGER.error(TAG, msg) | ||||
raise TypeError(msg) | raise TypeError(msg) | ||||
if not isinstance(n_jobs, int): | |||||
msg = "Type of parameter 'n_jobs' must be int, but got {}".format(type(n_jobs)) | |||||
LOGGER.error(TAG, msg) | |||||
raise TypeError(msg) | |||||
if not (n_jobs == -1 or n_jobs > 0): | |||||
msg = "Value of n_jobs must be either -1 or positive integer, but got {}.".format(n_jobs) | |||||
LOGGER.error(TAG, msg) | |||||
raise ValueError(msg) | |||||
self.model = model | self.model = model | ||||
self.n_jobs = min(n_jobs, cpu_count()) | |||||
self.method_list = ["knn", "lr", "mlp", "rf"] | self.method_list = ["knn", "lr", "mlp", "rf"] | ||||
self.attack_list = [] | self.attack_list = [] | ||||
@@ -162,24 +177,13 @@ class MembershipInference: | |||||
LOGGER.error(TAG, msg) | LOGGER.error(TAG, msg) | ||||
raise TypeError(msg) | raise TypeError(msg) | ||||
for config in attack_config: | |||||
if not isinstance(config, dict): | |||||
msg = "Type of each config in 'attack_config' must be dict, but got {}.".format(type(config)) | |||||
LOGGER.error(TAG, msg) | |||||
raise TypeError(msg) | |||||
if {"params", "method"} != set(config.keys()): | |||||
msg = "Each config in attack_config must have keys 'method' and 'params'," \ | |||||
"but your key value is {}.".format(set(config.keys())) | |||||
LOGGER.error(TAG, msg) | |||||
raise KeyError(msg) | |||||
if str.lower(config["method"]) not in self.method_list: | |||||
msg = "Method {} is not support.".format(config["method"]) | |||||
LOGGER.error(TAG, msg) | |||||
raise ValueError(msg) | |||||
check_config_params(attack_config) # Verify attack config. | |||||
features, labels = self._transform(dataset_train, dataset_test) | features, labels = self._transform(dataset_train, dataset_test) | ||||
for config in attack_config: | for config in attack_config: | ||||
self.attack_list.append(get_attack_model(features, labels, config)) | |||||
self.attack_list.append(get_attack_model(features, labels, config, n_jobs=self.n_jobs)) | |||||
def eval(self, dataset_train, dataset_test, metrics): | def eval(self, dataset_train, dataset_test, metrics): | ||||
""" | """ | ||||
@@ -35,7 +35,7 @@ def test_get_knn_model(): | |||||
"n_neighbors": [3], | "n_neighbors": [3], | ||||
} | } | ||||
} | } | ||||
knn_attacker = get_attack_model(features, labels, config_knn) | |||||
knn_attacker = get_attack_model(features, labels, config_knn, -1) | |||||
pred = knn_attacker.predict(features) | pred = knn_attacker.predict(features) | ||||
assert pred is not None | assert pred is not None | ||||
@@ -54,7 +54,7 @@ def test_get_lr_model(): | |||||
"C": np.logspace(-4, 2, 10), | "C": np.logspace(-4, 2, 10), | ||||
} | } | ||||
} | } | ||||
lr_attacker = get_attack_model(features, labels, config_lr) | |||||
lr_attacker = get_attack_model(features, labels, config_lr, -1) | |||||
pred = lr_attacker.predict(features) | pred = lr_attacker.predict(features) | ||||
assert pred is not None | assert pred is not None | ||||
@@ -75,7 +75,7 @@ def test_get_mlp_model(): | |||||
"alpha": [0.0001, 0.001, 0.01], | "alpha": [0.0001, 0.001, 0.01], | ||||
} | } | ||||
} | } | ||||
mlpc_attacker = get_attack_model(features, labels, config_mlpc) | |||||
mlpc_attacker = get_attack_model(features, labels, config_mlpc, -1) | |||||
pred = mlpc_attacker.predict(features) | pred = mlpc_attacker.predict(features) | ||||
assert pred is not None | assert pred is not None | ||||
@@ -98,6 +98,6 @@ def test_get_rf_model(): | |||||
"min_samples_leaf": [1, 2, 4], | "min_samples_leaf": [1, 2, 4], | ||||
} | } | ||||
} | } | ||||
rf_attacker = get_attack_model(features, labels, config_rf) | |||||
rf_attacker = get_attack_model(features, labels, config_rf, -1) | |||||
pred = rf_attacker.predict(features) | pred = rf_attacker.predict(features) | ||||
assert pred is not None | assert pred is not None |
@@ -24,6 +24,7 @@ import numpy as np | |||||
import mindspore.dataset as ds | import mindspore.dataset as ds | ||||
from mindspore import nn | from mindspore import nn | ||||
from mindspore.train import Model | from mindspore.train import Model | ||||
import mindspore.context as context | |||||
from mindarmour.diff_privacy.evaluation.membership_inference import MembershipInference | from mindarmour.diff_privacy.evaluation.membership_inference import MembershipInference | ||||
@@ -31,6 +32,8 @@ sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../")) | |||||
from defenses.mock_net import Net | from defenses.mock_net import Net | ||||
context.set_context(mode=context.GRAPH_MODE) | |||||
def dataset_generator(batch_size, batches): | def dataset_generator(batch_size, batches): | ||||
"""mock training data.""" | """mock training data.""" | ||||
data = np.random.randn(batches*batch_size, 1, 32, 32).astype( | data = np.random.randn(batches*batch_size, 1, 32, 32).astype( | ||||
@@ -51,7 +54,7 @@ def test_get_membership_inference_object(): | |||||
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True) | loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True) | ||||
opt = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) | opt = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) | ||||
model = Model(network=net, loss_fn=loss, optimizer=opt) | model = Model(network=net, loss_fn=loss, optimizer=opt) | ||||
inference_model = MembershipInference(model) | |||||
inference_model = MembershipInference(model, -1) | |||||
assert isinstance(inference_model, MembershipInference) | assert isinstance(inference_model, MembershipInference) | ||||
@@ -65,7 +68,7 @@ def test_membership_inference_object_train(): | |||||
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True) | loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True) | ||||
opt = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) | opt = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) | ||||
model = Model(network=net, loss_fn=loss, optimizer=opt) | model = Model(network=net, loss_fn=loss, optimizer=opt) | ||||
inference_model = MembershipInference(model) | |||||
inference_model = MembershipInference(model, -1) | |||||
assert isinstance(inference_model, MembershipInference) | assert isinstance(inference_model, MembershipInference) | ||||
config = [{ | config = [{ | ||||
@@ -95,7 +98,7 @@ def test_membership_inference_eval(): | |||||
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True) | loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True) | ||||
opt = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) | opt = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) | ||||
model = Model(network=net, loss_fn=loss, optimizer=opt) | model = Model(network=net, loss_fn=loss, optimizer=opt) | ||||
inference_model = MembershipInference(model) | |||||
inference_model = MembershipInference(model, -1) | |||||
assert isinstance(inference_model, MembershipInference) | assert isinstance(inference_model, MembershipInference) | ||||
batch_size = 16 | batch_size = 16 | ||||