Browse Source

!100 fixed exception detection

Merge pull request !100 from liuluobin/master
tags/v0.7.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
12d4873180
1 changed files with 21 additions and 14 deletions
  1. +21
    -14
      mindarmour/diff_privacy/evaluation/membership_inference.py

+ 21
- 14
mindarmour/diff_privacy/evaluation/membership_inference.py View File

@@ -20,8 +20,6 @@ import numpy as np
import mindspore as ms import mindspore as ms
from mindspore.train import Model from mindspore.train import Model
from mindspore.dataset.engine import Dataset from mindspore.dataset.engine import Dataset
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor from mindspore import Tensor
from mindarmour.diff_privacy.evaluation.attacker import get_attack_model from mindarmour.diff_privacy.evaluation.attacker import get_attack_model
from mindarmour.utils.logger import LogUtil from mindarmour.utils.logger import LogUtil
@@ -71,6 +69,22 @@ def _eval_info(pred, truth, option):
raise ValueError(msg) raise ValueError(msg)




def _softmax_cross_entropy(logits, labels):
"""
Calculate the SoftmaxCrossEntropy result between logits and labels.

Args:
logits (numpy.ndarray): Numpy array of shape(N, C).
labels (numpy.ndarray): Numpy array of shape(N, )

Returns:
numpy.ndarray: Numpy array of shape(N, ), containing loss value for each vector in logits.
"""
labels = np.eye(logits.shape[1])[labels].astype(np.int32)
logits = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)
return -1*np.sum(labels*np.log(logits), axis=1)


class MembershipInference: class MembershipInference:
""" """
Evaluation proposed by Shokri, Stronati, Song and Shmatikov is a grey-box attack. Evaluation proposed by Shokri, Stronati, Song and Shmatikov is a grey-box attack.
@@ -192,8 +206,8 @@ class MembershipInference:
raise TypeError(msg) raise TypeError(msg)


metrics = set(metrics) metrics = set(metrics)
metrics_list = {"precision", "accruacy", "recall"}
if metrics > metrics_list:
metrics_list = {"precision", "accuracy", "recall"}
if not metrics <= metrics_list:
msg = "Element in 'metrics' must be in {}, but got {}.".format(metrics_list, metrics) msg = "Element in 'metrics' must be in {}, but got {}.".format(metrics_list, metrics)
LOGGER.error(TAG, msg) LOGGER.error(TAG, msg)
raise ValueError(msg) raise ValueError(msg)
@@ -244,19 +258,12 @@ class MembershipInference:
N is the number of sample. C = 1 + dim(logits). N is the number of sample. C = 1 + dim(logits).
- numpy.ndarray, Labels for each sample, Shape is (N,). - numpy.ndarray, Labels for each sample, Shape is (N,).
""" """
if context.get_context("device_target") != "Ascend":
msg = "The target device must be Ascend, " \
"but current is {}.".format(context.get_context("device_target"))
LOGGER.error(TAG, msg)
raise RuntimeError(msg)
loss_logits = np.array([]) loss_logits = np.array([])
for batch in dataset_x.create_dict_iterator(): for batch in dataset_x.create_dict_iterator():
batch_data = Tensor(batch['image'], ms.float32) batch_data = Tensor(batch['image'], ms.float32)
batch_labels = Tensor(batch['label'], ms.int32)
batch_logits = self.model.predict(batch_data)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, is_grad=False, reduction=None)
batch_loss = loss(batch_logits, batch_labels).asnumpy()
batch_logits = batch_logits.asnumpy()
batch_labels = batch['label'].astype(np.int32)
batch_logits = self.model.predict(batch_data).asnumpy()
batch_loss = _softmax_cross_entropy(batch_logits, batch_labels)


batch_feature = np.hstack((batch_loss.reshape(-1, 1), batch_logits)) batch_feature = np.hstack((batch_loss.reshape(-1, 1), batch_logits))
if loss_logits.size == 0: if loss_logits.size == 0:


Loading…
Cancel
Save