@@ -67,36 +67,55 @@ def mnist_inversion_attack(net): | |||||
load_dict = load_checkpoint(ckpt_path) | load_dict = load_checkpoint(ckpt_path) | ||||
load_param_into_net(net, load_dict) | load_param_into_net(net, load_dict) | ||||
# get test data | |||||
data_list = "../../common/dataset/MNIST/test" | |||||
# get original data and their inferred fearures | |||||
data_list = "../../common/dataset/MNIST/train" | |||||
batch_size = 32 | batch_size = 32 | ||||
ds = generate_mnist_dataset(data_list, batch_size) | ds = generate_mnist_dataset(data_list, batch_size) | ||||
inversion_attack = ImageInversionAttack(net, input_shape=(1, 32, 32), input_bound=(0, 1), loss_weights=[1, 0.2, 5]) | |||||
i = 0 | i = 0 | ||||
batch_num = 1 | batch_num = 1 | ||||
sample_num = 10 | |||||
sample_num = 30 | |||||
for data in ds.create_tuple_iterator(output_numpy=True): | for data in ds.create_tuple_iterator(output_numpy=True): | ||||
i += 1 | i += 1 | ||||
images = data[0].astype(np.float32) | images = data[0].astype(np.float32) | ||||
target_features = net(Tensor(images)).asnumpy() | |||||
true_labels = data[1][: sample_num] | |||||
target_features = net(Tensor(images)).asnumpy()[:sample_num] | |||||
original_images = images[: sample_num] | original_images = images[: sample_num] | ||||
inversion_images = inversion_attack.generate(target_features[:sample_num], iters=100) | |||||
for n in range(1, sample_num+1): | |||||
plt.subplot(2, sample_num, n) | |||||
plt.gray() | |||||
plt.imshow(images[n - 1].reshape(32, 32)) | |||||
plt.subplot(2, sample_num, n + sample_num) | |||||
plt.gray() | |||||
plt.imshow(inversion_images[n - 1].reshape(32, 32)) | |||||
plt.show() | |||||
if i >= batch_num: | if i >= batch_num: | ||||
break | break | ||||
# evaluate the similarity between inversion images and original images | |||||
avg_l2_dis, avg_ssim = inversion_attack.evaluate(original_images, inversion_images) | |||||
LOGGER.info(TAG, 'The average L2 distance between original images and inversion images is: {}'.format(avg_l2_dis)) | |||||
LOGGER.info(TAG, 'The average ssim value between original images and inversion images is: {}'.format(avg_ssim)) | |||||
# run attacking | |||||
inversion_attack = ImageInversionAttack(net, input_shape=(1, 32, 32), input_bound=(0, 1), loss_weights=[1, 0.1, 5]) | |||||
inversion_images = inversion_attack.generate(target_features, iters=100) | |||||
# get the predict results of inversion images on a new trained model | |||||
net2 = LeNet5() | |||||
new_ckpt_path = '../../common/networks/lenet5/new_trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' | |||||
new_load_dict = load_checkpoint(new_ckpt_path) | |||||
load_param_into_net(net2, new_load_dict) | |||||
pred_labels = np.argmax(net2(Tensor(inversion_images).astype(np.float32)).asnumpy(), axis=1) | |||||
# evaluate the quality of inversion images | |||||
avg_l2_dis, avg_ssim, avg_confi = inversion_attack.evaluate(original_images, inversion_images, true_labels, net2) | |||||
LOGGER.info(TAG, 'The average L2 distance between original images and inverted images is: {}'.format(avg_l2_dis)) | |||||
LOGGER.info(TAG, 'The average ssim value between original images and inverted images is: {}'.format(avg_ssim)) | |||||
LOGGER.info(TAG, 'The average prediction confidence on true labels of inverted images is: {}'.format(avg_confi)) | |||||
LOGGER.info(TAG, 'True labels of original images are: %s' % true_labels) | |||||
LOGGER.info(TAG, 'Predicted labels of inverted images are: %s' % pred_labels) | |||||
# plot 10 images | |||||
plot_num = min(sample_num, 10) | |||||
for n in range(1, plot_num+1): | |||||
plt.subplot(2, plot_num, n) | |||||
if n == 1: | |||||
plt.title('Original images', fontsize=16, loc='left') | |||||
plt.gray() | |||||
plt.imshow(images[n - 1].reshape(32, 32)) | |||||
plt.subplot(2, plot_num, n + plot_num) | |||||
if n == 1: | |||||
plt.title('Inverted images', fontsize=16, loc='left') | |||||
plt.gray() | |||||
plt.imshow(inversion_images[n - 1].reshape(32, 32)) | |||||
plt.show() | |||||
if __name__ == '__main__': | if __name__ == '__main__': | ||||
@@ -15,6 +15,7 @@ | |||||
Inversion Attack | Inversion Attack | ||||
""" | """ | ||||
import numpy as np | import numpy as np | ||||
from scipy.special import softmax | |||||
from mindspore.nn import Cell, MSELoss | from mindspore.nn import Cell, MSELoss | ||||
from mindspore import Tensor | from mindspore import Tensor | ||||
@@ -166,18 +167,24 @@ class ImageInversionAttack: | |||||
inversion_images.append(inversion_image_n) | inversion_images.append(inversion_image_n) | ||||
return np.concatenate(np.array(inversion_images)) | return np.concatenate(np.array(inversion_images)) | ||||
def evaluate(self, original_images, inversion_images): | |||||
def evaluate(self, original_images, inversion_images, labels=None, new_network=None): | |||||
""" | """ | ||||
Compute the average L2 distance and SSIM value between original images and inversion images. | |||||
Evaluate the quality of inverted images by three index: the average L2 distance and SSIM value between | |||||
original images and inversion images, and the average of inverted images' confidence on true labels of inverted | |||||
inferred by a new trained network. | |||||
Args: | Args: | ||||
original_images (numpy.ndarray): Original images, whose shape should be (img_num, channels, img_width, | original_images (numpy.ndarray): Original images, whose shape should be (img_num, channels, img_width, | ||||
img_height). | img_height). | ||||
inversion_images (numpy.ndarray): Inversion images, whose shape should be (img_num, channels, img_width, | inversion_images (numpy.ndarray): Inversion images, whose shape should be (img_num, channels, img_width, | ||||
img_height). | img_height). | ||||
labels (numpy.ndarray): Ground truth labels of original images. Default: None. | |||||
new_network (Cell): A network whose structure contains all parts of self._network, but loaded with different | |||||
checkpoint file. Default: None. | |||||
Returns: | Returns: | ||||
tuple, the average l2 distance and average ssim value between original images and inversion images. | |||||
tuple, average l2 distance, average ssim value and average confidence (if labels or new_network is None, | |||||
then average confidence would be None). | |||||
Examples: | Examples: | ||||
>>> net = LeNet5() | >>> net = LeNet5() | ||||
@@ -188,15 +195,31 @@ class ImageInversionAttack: | |||||
>>> ori_images = np.random.random((2, 1, 32, 32)) | >>> ori_images = np.random.random((2, 1, 32, 32)) | ||||
>>> result = inversion_attack.evaluate(ori_images, inver_images) | >>> result = inversion_attack.evaluate(ori_images, inver_images) | ||||
>>> print(len(result)) | >>> print(len(result)) | ||||
2 | |||||
3 | |||||
""" | """ | ||||
check_numpy_param('original_images', original_images) | check_numpy_param('original_images', original_images) | ||||
check_numpy_param('inversion_images', inversion_images) | check_numpy_param('inversion_images', inversion_images) | ||||
if labels is not None: | |||||
check_numpy_param('labels', labels) | |||||
true_labels = np.squeeze(labels) | |||||
if len(true_labels.shape) > 1: | |||||
msg = 'Shape of true_labels should be (1, n) or (n,), but got {}'.format(true_labels.shape) | |||||
raise ValueError(msg) | |||||
if true_labels.size != original_images.shape[0]: | |||||
msg = 'The size of true_labels should equal the number of images, but got {} and {}'.format( | |||||
true_labels.size, original_images.shape[0]) | |||||
raise ValueError(msg) | |||||
if new_network is not None: | |||||
check_param_type('new_network', new_network, Cell) | |||||
LOGGER.info(TAG, 'Please make sure that the network you pass is loaded with different checkpoint files ' | |||||
'compared with that of self._network.') | |||||
img_1, img_2 = check_equal_shape('original_images', original_images, 'inversion_images', inversion_images) | img_1, img_2 = check_equal_shape('original_images', original_images, 'inversion_images', inversion_images) | ||||
if (len(img_1.shape) != 4) or (img_1.shape[1] != 1 and img_1.shape[1] != 3): | if (len(img_1.shape) != 4) or (img_1.shape[1] != 1 and img_1.shape[1] != 3): | ||||
msg = 'The shape format of img_1 and img_2 should be (img_num, channels, img_width, img_height),' \ | msg = 'The shape format of img_1 and img_2 should be (img_num, channels, img_width, img_height),' \ | ||||
' but got {} and {}'.format(img_1.shape, img_2.shape) | ' but got {} and {}'.format(img_1.shape, img_2.shape) | ||||
raise ValueError(msg) | raise ValueError(msg) | ||||
total_l2_distance = 0 | total_l2_distance = 0 | ||||
total_ssim = 0 | total_ssim = 0 | ||||
img_1 = img_1.transpose(0, 2, 3, 1) | img_1 = img_1.transpose(0, 2, 3, 1) | ||||
@@ -207,4 +230,9 @@ class ImageInversionAttack: | |||||
total_ssim += compute_ssim(img_1[i], img_2[i]) | total_ssim += compute_ssim(img_1[i], img_2[i]) | ||||
avg_l2_dis = total_l2_distance / img_1.shape[0] | avg_l2_dis = total_l2_distance / img_1.shape[0] | ||||
avg_ssim = total_ssim / img_1.shape[0] | avg_ssim = total_ssim / img_1.shape[0] | ||||
return avg_l2_dis, avg_ssim | |||||
avg_confi = None | |||||
if (new_network is not None) and (labels is not None): | |||||
pred_logits = new_network(Tensor(inversion_images.astype(np.float32))).asnumpy() | |||||
logits_softmax = softmax(pred_logits, axis=1) | |||||
avg_confi = np.mean(logits_softmax[np.arange(img_1.shape[0]), true_labels]) | |||||
return avg_l2_dis, avg_ssim, avg_confi |
@@ -42,3 +42,20 @@ def test_inversion_attack(): | |||||
avg_ssim = inversion_attack.evaluate(original_images, inversion_images) | avg_ssim = inversion_attack.evaluate(original_images, inversion_images) | ||||
assert 0 < avg_ssim[1] < 1 | assert 0 < avg_ssim[1] < 1 | ||||
assert target_features.shape[0] == inversion_images.shape[0] | assert target_features.shape[0] == inversion_images.shape[0] | ||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_x86_ascend_training | |||||
@pytest.mark.platform_arm_ascend_training | |||||
@pytest.mark.env_onecard | |||||
@pytest.mark.component_mindarmour | |||||
def test_inversion_attack2(): | |||||
net = Net() | |||||
original_images = np.random.random((2, 1, 32, 32)).astype(np.float32) | |||||
target_features = np.random.random((2, 10)).astype(np.float32) | |||||
inversion_attack = ImageInversionAttack(net, input_shape=(1, 32, 32), input_bound=(0, 1), loss_weights=[1, 0.2, 5]) | |||||
inversion_images = inversion_attack.generate(target_features, iters=10) | |||||
true_labels = np.array([1, 2]) | |||||
new_net = Net() | |||||
indexes = inversion_attack.evaluate(original_images, inversion_images, true_labels, new_net) | |||||
assert len(indexes) == 3 |