@@ -23,8 +23,10 @@ from mindspore.train import Model
 from mindspore.dataset.engine import Dataset
 from mindspore import Tensor
 from mindarmour.utils.logger import LogUtil
+from mindarmour.utils._check_param import check_param_type, check_param_multi_types, \
+    check_model, check_numpy_param
 from .attacker import get_attack_model
-from ._check_config import check_config_params
+from ._check_config import verify_config_params
 
 LOGGER = LogUtil.get_instance()
 TAG = "MembershipInference"
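
The swapped imports replace this file's hand-rolled isinstance checks with the shared validators from mindarmour.utils._check_param. For orientation only, here is a minimal sketch of the pattern such a validator follows, inferred from the removed code in the hunks below; the real implementation in _check_param.py may differ:

    # Hypothetical sketch of a check_param_type-style validator (not the actual
    # mindarmour source): log, raise on type mismatch, return the value.
    def check_param_type(arg_name, arg_value, valid_type):
        if not isinstance(arg_value, valid_type):
            msg = "Type of parameter '{}' must be {}, but got {}.".format(
                arg_name, valid_type, type(arg_value))
            LOGGER.error(TAG, msg)
            raise TypeError(msg)
        return arg_value

Returning the value matters: the new __init__ below assigns self._model = check_model("model", model, Model) directly.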
|
|
@@ -47,23 +49,21 @@ def _eval_info(pred, truth, option):
         ValueError, size of parameter pred or truth is 0.
         ValueError, value of parameter option must be in ["precision", "accuracy", "recall"].
     """
-    if pred.size == 0 or truth.size == 0:
-        msg = "Size of pred or truth is 0."
-        LOGGER.error(TAG, msg)
-        raise ValueError(msg)
+    check_numpy_param("pred", pred)
+    check_numpy_param("truth", truth)
 
     if option == "accuracy":
         count = np.sum(pred == truth)
         return count / len(pred)
     if option == "precision":
-        count = np.sum(pred & truth)
         if np.sum(pred) == 0:
             return -1
+        count = np.sum(pred & truth)
         return count / np.sum(pred)
     if option == "recall":
-        count = np.sum(pred & truth)
         if np.sum(truth) == 0:
             return -1
+        count = np.sum(pred & truth)
         return count / np.sum(truth)
 
     msg = "The metric value {} is undefined.".format(option)
@@ -107,9 +107,9 @@ class MembershipInference:
             otherwise the value of n_jobs must be a positive integer.
 
     Examples:
-        >>> train_1, train_2 are non-overlapping datasets from training dataset of target model.
-        >>> test_1, test_2 are non-overlapping datasets from test dataset of target model.
-        >>> We use train_1, test_1 to train attack model, and use train_2, test_2 to evaluate attack model.
+        >>> # train_1, train_2 are non-overlapping datasets from training dataset of target model.
+        >>> # test_1, test_2 are non-overlapping datasets from test dataset of target model.
+        >>> # We use train_1, test_1 to train attack model, and use train_2, test_2 to evaluate attack model.
         >>> model = Model(network=net, loss_fn=loss, optimizer=opt, metrics={'acc', 'loss'})
         >>> inference_model = MembershipInference(model, n_jobs=-1)
         >>> config = [{"method": "KNN", "params": {"n_neighbors": [3, 5, 7]}}]
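
This docstring change turns the free-text lines into comments so the example is valid doctest syntax. The example stops at the config definition in this hunk; a hedged continuation using the train and eval methods defined below (same dataset names as in the comments) could read:

    >>> inference_model.train(train_1, test_1, config)
    >>> metrics = ["precision", "accuracy", "recall"]
    >>> result = inference_model.eval(train_2, test_2, metrics)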
@@ -124,65 +124,44 @@ class MembershipInference:
     """
 
     def __init__(self, model, n_jobs=-1):
-        if not isinstance(model, Model):
-            msg = "Type of parameter 'model' must be Model, but got {}.".format(type(model))
-            LOGGER.error(TAG, msg)
-            raise TypeError(msg)
-        if not isinstance(n_jobs, int):
-            msg = "Type of parameter 'n_jobs' must be int, but got {}".format(type(n_jobs))
-            LOGGER.error(TAG, msg)
-            raise TypeError(msg)
+        check_param_type("n_jobs", n_jobs, int)
         if not (n_jobs == -1 or n_jobs > 0):
             msg = "Value of n_jobs must be either -1 or positive integer, but got {}.".format(n_jobs)
             LOGGER.error(TAG, msg)
             raise ValueError(msg)
 
-        self.model = model
-        self.n_jobs = min(n_jobs, cpu_count())
-        self.method_list = ["knn", "lr", "mlp", "rf"]
-        self.attack_list = []
+        self._model = check_model("model", model, Model)
+        self._n_jobs = min(n_jobs, cpu_count())
+        self._attack_list = []
 
     def train(self, dataset_train, dataset_test, attack_config):
         """
-        Depending on the configuration, use the incoming data set to train the attack model.
-        Save the attack model to self.attack_list.
+        Depending on the configuration, use the input data set to train the attack model.
+        Save the attack model to self._attack_list.
 
         Args:
             dataset_train (mindspore.dataset): The training dataset for the target model.
             dataset_test (mindspore.dataset): The test set for the target model.
-            attack_config (list): Parameter setting for the attack model. The format is
+            attack_config (Union[list, tuple]): Parameter setting for the attack model. The format is
                 [{"method": "knn", "params": {"n_neighbors": [3, 5, 7]}},
                  {"method": "lr", "params": {"C": np.logspace(-4, 2, 10)}}].
-                The support methods list is in self.method_list, and the params of each method
+                The support methods are knn, lr, mlp and rf, and the params of each method
                 must within the range of changeable parameters. Tips of params implement
                 can be found in
                 "https://scikit-learn.org/0.16/modules/generated/sklearn.grid_search.GridSearchCV.html".
 
         Raises:
-            KeyError: If each config in attack_config doesn't have keys {"method", "params"}
-            ValueError: If the method(case insensitive) in attack_config is not in ["lr", "knn", "rf", "mlp"].
+            KeyError: If any config in attack_config doesn't have keys {"method", "params"}
+            NameError: If the method(case insensitive) in attack_config is not in ["lr", "knn", "rf", "mlp"].
         """
-        if not isinstance(dataset_train, Dataset):
-            msg = "Type of parameter 'dataset_train' must be Dataset, but got {}".format(type(dataset_train))
-            LOGGER.error(TAG, msg)
-            raise TypeError(msg)
-
-        if not isinstance(dataset_test, Dataset):
-            msg = "Type of parameter 'test_train' must be Dataset, but got {}".format(type(dataset_train))
-            LOGGER.error(TAG, msg)
-            raise TypeError(msg)
-
-        if not isinstance(attack_config, list):
-            msg = "Type of parameter 'attack_config' must be list, but got {}.".format(type(attack_config))
-            LOGGER.error(TAG, msg)
-            raise TypeError(msg)
-
-        check_config_params(attack_config)  # Verify attack config.
+        check_param_type("dataset_train", dataset_train, Dataset)
+        check_param_type("dataset_test", dataset_test, Dataset)
+        check_param_multi_types("attack_config", attack_config, (list, tuple))
+        verify_config_params(attack_config)
 
         features, labels = self._transform(dataset_train, dataset_test)
 
         for config in attack_config:
-            self.attack_list.append(get_attack_model(features, labels, config, n_jobs=self.n_jobs))
+            self._attack_list.append(get_attack_model(features, labels, config, n_jobs=self._n_jobs))
-
 
     def eval(self, dataset_train, dataset_test, metrics):
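
The train docstring names four supported methods (knn, lr, mlp, rf) but only shows grids for the first two. A fuller attack_config might look like the sketch below; the knn and lr entries come from the docstring, while the mlp and rf grids are illustrative assumptions, not values prescribed by this code:

    import numpy as np

    attack_config = [
        {"method": "knn", "params": {"n_neighbors": [3, 5, 7]}},
        {"method": "lr", "params": {"C": np.logspace(-4, 2, 10)}},
        # Hypothetical grids; tune per the GridSearchCV docs linked above.
        {"method": "mlp", "params": {"hidden_layer_sizes": [(64,), (32, 32)]}},
        {"method": "rf", "params": {"n_estimators": [100, 300]}},
    ]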
|
|
@@ -199,20 +178,9 @@ class MembershipInference:
         Returns:
             list, Each element contains an evaluation indicator for the attack model.
         """
-        if not isinstance(dataset_train, Dataset):
-            msg = "Type of parameter 'dataset_train' must be Dataset, but got {}".format(type(dataset_train))
-            LOGGER.error(TAG, msg)
-            raise TypeError(msg)
-
-        if not isinstance(dataset_test, Dataset):
-            msg = "Type of parameter 'test_train' must be Dataset, but got {}".format(type(dataset_train))
-            LOGGER.error(TAG, msg)
-            raise TypeError(msg)
-
-        if not isinstance(metrics, (list, tuple)):
-            msg = "Type of parameter 'config' must be Union[list, tuple], but got {}.".format(type(metrics))
-            LOGGER.error(TAG, msg)
-            raise TypeError(msg)
+        check_param_type("dataset_train", dataset_train, Dataset)
+        check_param_type("dataset_test", dataset_test, Dataset)
+        check_param_multi_types("metrics", metrics, (list, tuple))
 
         metrics = set(metrics)
         metrics_list = {"precision", "accuracy", "recall"}
|
|
@@ -223,7 +191,7 @@ class MembershipInference:
 
         result = []
         features, labels = self._transform(dataset_train, dataset_test)
-        for attacker in self.attack_list:
+        for attacker in self._attack_list:
             pred = attacker.predict(features)
             item = {}
             for option in metrics:
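
After this loop, eval returns result, a list with one dict per trained attack model, keyed by the requested metric names. A hedged sketch of consuming the return value (inference_model as in the class docstring example):

    result = inference_model.eval(train_2, test_2, metrics=["precision", "recall"])
    for i, item in enumerate(result):
        # item is e.g. {"precision": 0.66, "recall": 0.71} for the i-th attack model
        print("attack model {}: {}".format(i, item))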
|
|
@@ -233,7 +201,7 @@ class MembershipInference:
 
     def _transform(self, dataset_train, dataset_test):
         """
-        Generate corresponding loss_logits feature and new label, and return after shuffle.
+        Generate corresponding loss_logits features and new label, and return after shuffle.
 
         Args:
             dataset_train: The training set for the target model.
|
|
@@ -255,13 +223,13 @@ class MembershipInference:
 
         return features, labels
 
-    def _generate(self, dataset_x, label):
+    def _generate(self, input_dataset, label):
         """
         Return a loss_logits features and labels for training attack model.
 
         Args:
-            dataset_x (mindspore.dataset): The dataset to be generate.
-            label (int32): Whether dataset_x belongs to the target model.
+            input_dataset (mindspore.dataset): The dataset to be generate.
+            label (int32): Whether input_dataset belongs to the target model.
 
         Returns:
             - numpy.ndarray, Loss_logits features for each sample. Shape is (N, C).
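
Together, the _transform and _generate docstrings describe how the attack's training data is built: samples from the target model's training set are labeled 1 (members), test samples are labeled 0 (non-members), and the combined features are shuffled. A hypothetical standalone sketch of that construction, assuming the shapes given in the docstrings (the actual _transform body is outside this diff):

    import numpy as np

    def transform_sketch(generate, dataset_train, dataset_test):
        # Members of the training set get label 1, non-members label 0.
        features_train, labels_train = generate(dataset_train, 1)
        features_test, labels_test = generate(dataset_test, 0)
        features = np.vstack((features_train, features_test))  # shape (N, C)
        labels = np.hstack((labels_train, labels_test))        # shape (N,)
        index = np.random.permutation(labels.size)             # "return after shuffle"
        return features[index], labels[index]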
|
|
@@ -269,10 +237,10 @@ class MembershipInference:
             - numpy.ndarray, Labels for each sample, Shape is (N,).
         """
         loss_logits = np.array([])
-        for batch in dataset_x.create_dict_iterator():
+        for batch in input_dataset.create_dict_iterator():
             batch_data = Tensor(batch['image'], ms.float32)
             batch_labels = batch['label'].astype(np.int32)
-            batch_logits = self.model.predict(batch_data).asnumpy()
+            batch_logits = self._model.predict(batch_data).asnumpy()
             batch_loss = _softmax_cross_entropy(batch_logits, batch_labels)
 
             batch_feature = np.hstack((batch_loss.reshape(-1, 1), batch_logits))
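
_softmax_cross_entropy is called here but defined outside this hunk. A numerically stable numpy sketch consistent with its usage above (logits of shape (N, C), integer labels of shape (N,), per-sample loss of shape (N,)); this is an assumption about the helper, not its actual source:

    import numpy as np

    def _softmax_cross_entropy(logits, labels):
        # Subtract the row max for numerical stability, then take log-softmax.
        shifted = logits - np.max(logits, axis=1, keepdims=True)
        log_probs = shifted - np.log(np.sum(np.exp(shifted), axis=1, keepdims=True))
        # Negative log-probability of each sample's true class.
        return -log_probs[np.arange(labels.size), labels]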
|
|
|