@@ -42,7 +42,7 @@ python train.py --data_path home_path_to_cifar100 --ckpt_path ./ | |||||
python example_vgg_cifar.py --data_path home_path_to_cifar100 --pre_trained 0-100_781.ckpt | python example_vgg_cifar.py --data_path home_path_to_cifar100 --pre_trained 0-100_781.ckpt | ||||
``` | ``` | ||||
## 4. suppress privacy training | |||||
## 4. Suppress privacy training | |||||
With suppress privacy mechanism, the values of some trainable parameters (such as conv layers and fully connected | With suppress privacy mechanism, the values of some trainable parameters (such as conv layers and fully connected | ||||
layers) are set to zero as the training step grows, which can | layers) are set to zero as the training step grows, which can | ||||
@@ -52,3 +52,15 @@ With suppress privacy mechanism, the values of some trainable parameters (such | |||||
cd examples/privacy/sup_privacy | cd examples/privacy/sup_privacy | ||||
python sup_privacy.py | python sup_privacy.py | ||||
``` | ``` | ||||
## 5. Image inversion attack | |||||
Inversion attack means reconstructing an image based on its deep representations. For example, | |||||
reconstruct a MNIST image based on its output through LeNet5. The mechanism behind it is that well-trained | |||||
model can "remember" those training dataset. Therefore, inversion attack can be used to estimate the privacy | |||||
leakage of training tasks. | |||||
```sh | |||||
cd examples/privacy/inversion_attack | |||||
python mnist_inversion_attack.py | |||||
``` |
@@ -0,0 +1,108 @@ | |||||
# Copyright 2021 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
""" | |||||
Examples of image inversion attack | |||||
""" | |||||
import numpy as np | |||||
import matplotlib.pyplot as plt | |||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||||
from mindspore import Tensor, context | |||||
from mindspore import nn | |||||
from mindarmour.privacy.evaluation.inversion_attack import ImageInversionAttack | |||||
from mindarmour.utils.logger import LogUtil | |||||
from examples.common.networks.lenet5.lenet5_net import LeNet5, conv, fc_with_initialize | |||||
from examples.common.dataset.data_processing import generate_mnist_dataset | |||||
LOGGER = LogUtil.get_instance() | |||||
LOGGER.set_level('INFO') | |||||
TAG = 'InversionAttack' | |||||
# pylint: disable=invalid-name | |||||
class LeNet5_part(nn.Cell): | |||||
""" | |||||
Part of LeNet5 network. | |||||
""" | |||||
def __init__(self): | |||||
super(LeNet5_part, self).__init__() | |||||
self.conv1 = conv(1, 6, 5) | |||||
self.conv2 = conv(6, 16, 5) | |||||
self.fc1 = fc_with_initialize(16*5*5, 120) | |||||
self.fc2 = fc_with_initialize(120, 84) | |||||
self.fc3 = fc_with_initialize(84, 10) | |||||
self.relu = nn.ReLU() | |||||
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) | |||||
self.flatten = nn.Flatten() | |||||
def construct(self, x): | |||||
x = self.conv1(x) | |||||
x = self.relu(x) | |||||
x = self.max_pool2d(x) | |||||
x = self.conv2(x) | |||||
x = self.relu(x) | |||||
x = self.max_pool2d(x) | |||||
return x | |||||
def mnist_inversion_attack(net): | |||||
""" | |||||
Image inversion attack based on LeNet5 and MNIST dataset. | |||||
""" | |||||
# upload trained network | |||||
ckpt_path = '../../common/networks/lenet5/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' | |||||
load_dict = load_checkpoint(ckpt_path) | |||||
load_param_into_net(net, load_dict) | |||||
# get test data | |||||
data_list = "../../common/dataset/MNIST/test" | |||||
batch_size = 32 | |||||
ds = generate_mnist_dataset(data_list, batch_size) | |||||
inversion_attack = ImageInversionAttack(net, input_shape=(1, 32, 32), input_bound=(0, 1), loss_weights=[1, 0.2, 5]) | |||||
i = 0 | |||||
batch_num = 1 | |||||
sample_num = 10 | |||||
for data in ds.create_tuple_iterator(output_numpy=True): | |||||
i += 1 | |||||
images = data[0].astype(np.float32) | |||||
target_features = net(Tensor(images)).asnumpy() | |||||
original_images = images[: sample_num] | |||||
inversion_images = inversion_attack.generate(target_features[:sample_num], iters=100) | |||||
for n in range(1, sample_num+1): | |||||
plt.subplot(2, sample_num, n) | |||||
plt.gray() | |||||
plt.imshow(images[n - 1].reshape(32, 32)) | |||||
plt.subplot(2, sample_num, n + sample_num) | |||||
plt.gray() | |||||
plt.imshow(inversion_images[n - 1].reshape(32, 32)) | |||||
plt.show() | |||||
if i >= batch_num: | |||||
break | |||||
# evaluate the similarity between inversion images and original images | |||||
avg_l2_dis, avg_ssim = inversion_attack.evaluate(original_images, inversion_images) | |||||
LOGGER.info(TAG, 'The average L2 distance between original images and inversion images is: {}'.format(avg_l2_dis)) | |||||
LOGGER.info(TAG, 'The average ssim value between original images and inversion images is: {}'.format(avg_ssim)) | |||||
if __name__ == '__main__': | |||||
# device_target can be "CPU", "GPU" or "Ascend" | |||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||||
# attack based on complete LeNet5 | |||||
mnist_inversion_attack(LeNet5()) | |||||
# attack based on part of LeNet5. The network is more shallower and can lead to a better attack result | |||||
mnist_inversion_attack(LeNet5_part()) |
@@ -9,6 +9,10 @@ from .adv_robustness.detectors.detector import Detector | |||||
from .fuzz_testing.fuzzing import Fuzzer | from .fuzz_testing.fuzzing import Fuzzer | ||||
from .privacy.diff_privacy import DPModel | from .privacy.diff_privacy import DPModel | ||||
from .privacy.evaluation.membership_inference import MembershipInference | from .privacy.evaluation.membership_inference import MembershipInference | ||||
from .privacy.sup_privacy.sup_ctrl.conctrl import SuppressCtrl | |||||
from .privacy.sup_privacy.train.model import SuppressModel | |||||
from .privacy.sup_privacy.mask_monitor.masker import SuppressMasker | |||||
from .privacy.evaluation.inversion_attack import ImageInversionAttack | |||||
__all__ = ['Attack', | __all__ = ['Attack', | ||||
'BlackModel', | 'BlackModel', | ||||
@@ -16,4 +20,8 @@ __all__ = ['Attack', | |||||
'Defense', | 'Defense', | ||||
'Fuzzer', | 'Fuzzer', | ||||
'DPModel', | 'DPModel', | ||||
'MembershipInference'] | |||||
'MembershipInference', | |||||
'SuppressModel', | |||||
'SuppressCtrl', | |||||
'SuppressMasker', | |||||
'ImageInversionAttack'] |
@@ -17,84 +17,15 @@ Attack evaluation. | |||||
import numpy as np | import numpy as np | ||||
from scipy.ndimage.filters import convolve | |||||
from mindarmour.utils.logger import LogUtil | from mindarmour.utils.logger import LogUtil | ||||
from mindarmour.utils._check_param import check_pair_numpy_param, \ | from mindarmour.utils._check_param import check_pair_numpy_param, \ | ||||
check_param_type, check_numpy_param, check_equal_shape | |||||
check_param_type, check_numpy_param | |||||
from mindarmour.utils.util import calculate_lp_distance, compute_ssim | |||||
LOGGER = LogUtil.get_instance() | LOGGER = LogUtil.get_instance() | ||||
TAG = 'AttackEvaluate' | TAG = 'AttackEvaluate' | ||||
def _compute_ssim(img_1, img_2, kernel_sigma=1.5, kernel_width=11): | |||||
""" | |||||
compute structural similarity. | |||||
Args: | |||||
img_1 (numpy.ndarray): The first image to be compared. | |||||
img_2 (numpy.ndarray): The second image to be compared. | |||||
kernel_sigma (float): Gassian kernel param. Default: 1.5. | |||||
kernel_width (int): Another Gassian kernel param. Default: 11. | |||||
Returns: | |||||
float, structural similarity. | |||||
""" | |||||
img_1, img_2 = check_equal_shape('images_1', img_1, 'images_2', img_2) | |||||
if len(img_1.shape) > 2: | |||||
total_ssim = 0 | |||||
for i in range(img_1.shape[2]): | |||||
total_ssim += _compute_ssim(img_1[:, :, i], img_2[:, :, i]) | |||||
return total_ssim / 3 | |||||
# Create gaussian kernel | |||||
gaussian_kernel = np.zeros((kernel_width, kernel_width)) | |||||
for i in range(kernel_width): | |||||
for j in range(kernel_width): | |||||
gaussian_kernel[i, j] = (1 / (2*np.pi*(kernel_sigma**2)))*np.exp( | |||||
- (((i - 5)**2) + ((j - 5)**2)) / (2*(kernel_sigma**2))) | |||||
img_1 = img_1.astype(np.float32) | |||||
img_2 = img_2.astype(np.float32) | |||||
img_sq_1 = img_1**2 | |||||
img_sq_2 = img_2**2 | |||||
img_12 = img_1*img_2 | |||||
# Mean | |||||
img_mu_1 = convolve(img_1, gaussian_kernel) | |||||
img_mu_2 = convolve(img_2, gaussian_kernel) | |||||
# Mean square | |||||
img_mu_sq_1 = img_mu_1**2 | |||||
img_mu_sq_2 = img_mu_2**2 | |||||
img_mu_12 = img_mu_1*img_mu_2 | |||||
# Variances | |||||
img_sigma_sq_1 = convolve(img_sq_1, gaussian_kernel) | |||||
img_sigma_sq_2 = convolve(img_sq_2, gaussian_kernel) | |||||
# Covariance | |||||
img_sigma_12 = convolve(img_12, gaussian_kernel) | |||||
# Centered squares of variances | |||||
img_sigma_sq_1 = img_sigma_sq_1 - img_mu_sq_1 | |||||
img_sigma_sq_2 = img_sigma_sq_2 - img_mu_sq_2 | |||||
img_sigma_12 = img_sigma_12 - img_mu_12 | |||||
k_1 = 0.01 | |||||
k_2 = 0.03 | |||||
c_1 = (k_1*255)**2 | |||||
c_2 = (k_2*255)**2 | |||||
# Calculate ssim | |||||
num_ssim = (2*img_mu_12 + c_1)*(2*img_sigma_12 + c_2) | |||||
den_ssim = (img_mu_sq_1 + img_mu_sq_2 + c_1)*(img_sigma_sq_1 | |||||
+ img_sigma_sq_2 + c_2) | |||||
res = np.average(num_ssim / den_ssim) | |||||
return res | |||||
class AttackEvaluate: | class AttackEvaluate: | ||||
""" | """ | ||||
Evaluation metrics of attack methods. | Evaluation metrics of attack methods. | ||||
@@ -217,16 +148,11 @@ class AttackEvaluate: | |||||
l0_dist = 0 | l0_dist = 0 | ||||
l2_dist = 0 | l2_dist = 0 | ||||
linf_dist = 0 | linf_dist = 0 | ||||
avoid_zero_div = 1e-14 | |||||
for i in idxes: | for i in idxes: | ||||
diff = (self._adv_inputs[i] - self._inputs[i]).flatten() | |||||
data = self._inputs[i].flatten() | |||||
l0_dist += np.linalg.norm(diff, ord=0) \ | |||||
/ (np.linalg.norm(data, ord=0) + avoid_zero_div) | |||||
l2_dist += np.linalg.norm(diff, ord=2) \ | |||||
/ (np.linalg.norm(data, ord=2) + avoid_zero_div) | |||||
linf_dist += np.linalg.norm(diff, ord=np.inf) \ | |||||
/ (np.linalg.norm(data, ord=np.inf) + avoid_zero_div) | |||||
l0_dist_i, l2_dist_i, linf_dist_i = calculate_lp_distance(self._inputs[i], self._adv_inputs[i]) | |||||
l0_dist += l0_dist_i | |||||
l2_dist += l2_dist_i | |||||
linf_dist += linf_dist_i | |||||
return l0_dist / success_num, l2_dist / success_num, \ | return l0_dist / success_num, l2_dist / success_num, \ | ||||
linf_dist / success_num | linf_dist / success_num | ||||
@@ -249,7 +175,7 @@ class AttackEvaluate: | |||||
total_ssim = 0.0 | total_ssim = 0.0 | ||||
for _, i in enumerate(self._success_idxes): | for _, i in enumerate(self._success_idxes): | ||||
total_ssim += _compute_ssim(self._adv_inputs[i], self._inputs[i]) | |||||
total_ssim += compute_ssim(self._adv_inputs[i], self._inputs[i]) | |||||
return total_ssim / success_num | return total_ssim / success_num | ||||
@@ -0,0 +1,210 @@ | |||||
# Copyright 2021 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
""" | |||||
Inversion Attack | |||||
""" | |||||
import numpy as np | |||||
from mindspore.nn import Cell, MSELoss | |||||
from mindspore import Tensor | |||||
from mindspore.ops import operations as P | |||||
from mindarmour.utils.util import GradWrapWithLoss | |||||
from mindarmour.utils._check_param import check_param_type, check_param_multi_types, \ | |||||
check_int_positive, check_numpy_param, check_value_positive, check_equal_shape | |||||
from mindarmour.utils.logger import LogUtil | |||||
from mindarmour.utils.util import calculate_lp_distance, compute_ssim | |||||
LOGGER = LogUtil.get_instance() | |||||
LOGGER.set_level('INFO') | |||||
TAG = 'Image inversion attack' | |||||
class InversionLoss(Cell): | |||||
""" | |||||
The loss function for inversion attack. | |||||
Args: | |||||
network (Cell): The network used to infer images' deep representations. | |||||
weights (Union[list, tuple]): Weights of three sub-loss in InversionLoss, which can be adjusted to | |||||
obtain better results. | |||||
""" | |||||
def __init__(self, network, weights): | |||||
super(InversionLoss, self).__init__() | |||||
self._network = check_param_type('network', network, Cell) | |||||
self._mse_loss = MSELoss() | |||||
self._weights = check_param_multi_types('weights', weights, [list, tuple]) | |||||
self._get_shape = P.Shape() | |||||
def construct(self, input_data, target_features): | |||||
""" | |||||
Calculate the inversion attack loss, which consists of three parts. Loss_1 is for evaluating the difference | |||||
between the target deep representations and current representations; Loss_2 is for evaluating the continuity | |||||
between adjacent pixels; Loss_3 is for regularization. | |||||
Args: | |||||
input_data (Tensor): The reconstructed image during inversion attack. | |||||
target_features (Tensor): Deep representations of the original image. | |||||
Returns: | |||||
Tensor, inversion attack loss of the current iteration. | |||||
""" | |||||
output = self._network(input_data) | |||||
loss_1 = self._mse_loss(output, target_features) / self._mse_loss(target_features, 0) | |||||
data_shape = self._get_shape(input_data) | |||||
split_op_1 = P.Split(2, data_shape[2]) | |||||
split_op_2 = P.Split(3, data_shape[3]) | |||||
data_split_1 = split_op_1(input_data) | |||||
data_split_2 = split_op_2(input_data) | |||||
loss_2 = 0 | |||||
for i in range(1, data_shape[2]): | |||||
loss_2 += self._mse_loss(data_split_1[i], data_split_1[i-1]) | |||||
for j in range(1, data_shape[3]): | |||||
loss_2 += self._mse_loss(data_split_2[j], data_split_2[j-1]) | |||||
loss_3 = self._mse_loss(input_data, 0) | |||||
loss = loss_1*self._weights[0] + loss_2*self._weights[1] + loss_3*self._weights[2] | |||||
return loss | |||||
class ImageInversionAttack: | |||||
""" | |||||
An attack method used to reconstruct images by inverting their deep representations. | |||||
References: `Aravindh Mahendran, Andrea Vedaldi. Understanding Deep Image Representations by Inverting Them. | |||||
2014. <https://arxiv.org/pdf/1412.0035.pdf>`_ | |||||
Args: | |||||
network (Cell): The network used to infer images' deep representations. | |||||
input_shape (tuple): Data shape of single network input, which should be in accordance with the given | |||||
network. The format of shape should be (channel, image_width, image_height). | |||||
input_bound (Union[tuple, list]): The pixel range of original images, which should be like [minimum_pixel, | |||||
maximum_pixel] or (minimum_pixel, maximum_pixel). | |||||
loss_weights (Union[list, tuple]): Weights of three sub-loss in InversionLoss, which can be adjusted to | |||||
obtain better results. Default: (1, 0.2, 5). | |||||
Raises: | |||||
TypeError: If the type of network is not Cell. | |||||
ValueError: If any value of input_shape is not positive int. | |||||
ValueError: If any value of loss_weights is not positive value. | |||||
""" | |||||
def __init__(self, network, input_shape, input_bound, loss_weights=(1, 0.2, 5)): | |||||
self._network = check_param_type('network', network, Cell) | |||||
for sub_loss_weight in loss_weights: | |||||
check_value_positive('sub_loss_weight', sub_loss_weight) | |||||
self._loss = InversionLoss(self._network, loss_weights) | |||||
self._input_shape = check_param_multi_types('input_shape', input_shape, [list, tuple]) | |||||
for shape_dim in input_shape: | |||||
check_int_positive('shape_dim', shape_dim) | |||||
self._input_bound = check_param_multi_types('input_bound', input_bound, [list, tuple]) | |||||
def generate(self, target_features, iters=100): | |||||
""" | |||||
Reconstruct images based on target_features. | |||||
Args: | |||||
target_features (numpy.ndarray): Deep representations of original images. The first dimension of | |||||
target_features should be img_num. It should be noted that the shape of target_features should be | |||||
(1, dim2, dim3, ...) if img_num equals 1. | |||||
iters (int): iteration times of inversion attack, which should be positive integers. Default: 100. | |||||
Returns: | |||||
numpy.ndarray, reconstructed images, which are expected to be similar to original images. | |||||
Raises: | |||||
TypeError: If the type of target_features is not numpy.ndarray. | |||||
ValueError: If any value of iters is not positive int.Z | |||||
Examples: | |||||
>>> net = LeNet5() | |||||
>>> inversion_attack = ImageInversionAttack(net, input_shape=(1, 32, 32), input_bound=(0, 1), | |||||
>>> loss_weights=[1, 0.2, 5]) | |||||
>>> features = np.random.random((2, 10)).astype(np.float32) | |||||
>>> images = inversion_attack.generate(features, iters=10) | |||||
>>> print(images.shape) | |||||
(2, 1, 32, 32) | |||||
""" | |||||
target_features = check_numpy_param('target_features', target_features) | |||||
iters = check_int_positive('iters', iters) | |||||
# shape checking | |||||
img_num = target_features.shape[0] | |||||
test_input = np.random.random((img_num,) + self._input_shape).astype(np.float32) | |||||
test_out = self._network(Tensor(test_input)).asnumpy() | |||||
if test_out.shape != target_features.shape: | |||||
msg = "The shape of target_features ({}) is not in accordance with the shape" \ | |||||
" of network output({})".format(target_features.shape, test_out.shape) | |||||
raise ValueError(msg) | |||||
loss_net = self._loss | |||||
loss_grad = GradWrapWithLoss(loss_net) | |||||
inversion_images = [] | |||||
for i in range(img_num): | |||||
target_feature_n = target_features[i] | |||||
inversion_image_n = np.random.random((1,) + self._input_shape).astype(np.float32)*0.05 | |||||
for s in range(iters): | |||||
x_grad = loss_grad(Tensor(inversion_image_n), Tensor(target_feature_n)).asnumpy() | |||||
x_grad_sign = np.sign(x_grad) | |||||
inversion_image_n -= x_grad_sign*0.01 | |||||
inversion_image_n = np.clip(inversion_image_n, self._input_bound[0], self._input_bound[1]) | |||||
current_loss = self._loss(Tensor(inversion_image_n), Tensor(target_feature_n)) | |||||
LOGGER.info(TAG, 'iteration step: {}, loss is {}'.format(s, current_loss)) | |||||
inversion_images.append(inversion_image_n) | |||||
return np.concatenate(np.array(inversion_images)) | |||||
def evaluate(self, original_images, inversion_images): | |||||
""" | |||||
Compute the average L2 distance and SSIM value between original images and inversion images. | |||||
Args: | |||||
original_images (numpy.ndarray): Original images, whose shape should be (img_num, channels, img_width, | |||||
img_height). | |||||
inversion_images (numpy.ndarray): Inversion images, whose shape should be (img_num, channels, img_width, | |||||
img_height). | |||||
Returns: | |||||
tuple, the average l2 distance and average ssim value between original images and inversion images. | |||||
Examples: | |||||
>>> net = LeNet5() | |||||
>>> inversion_attack = ImageInversionAttack(net, input_shape=(1, 32, 32), input_bound=(0, 1), | |||||
>>> loss_weights=[1, 0.2, 5]) | |||||
>>> features = np.random.random((2, 10)).astype(np.float32) | |||||
>>> inver_images = inversion_attack.generate(features, iters=10) | |||||
>>> ori_images = np.random.random((2, 1, 32, 32)) | |||||
>>> result = inversion_attack.evaluate(ori_images, inver_images) | |||||
>>> print(len(result)) | |||||
2 | |||||
""" | |||||
check_numpy_param('original_images', original_images) | |||||
check_numpy_param('inversion_images', inversion_images) | |||||
img_1, img_2 = check_equal_shape('original_images', original_images, 'inversion_images', inversion_images) | |||||
if (len(img_1.shape) != 4) or (img_1.shape[1] != 1 and img_1.shape[1] != 3): | |||||
msg = 'The shape format of img_1 and img_2 should be (img_num, channels, img_width, img_height),' \ | |||||
' but got {} and {}'.format(img_1.shape, img_2.shape) | |||||
raise ValueError(msg) | |||||
total_l2_distance = 0 | |||||
total_ssim = 0 | |||||
img_1 = img_1.transpose(0, 2, 3, 1) | |||||
img_2 = img_2.transpose(0, 2, 3, 1) | |||||
for i in range(img_1.shape[0]): | |||||
_, l2_dis, _ = calculate_lp_distance(img_1[i], img_2[i]) | |||||
total_l2_distance += l2_dis | |||||
total_ssim += compute_ssim(img_1[i], img_2[i]) | |||||
avg_l2_dis = total_l2_distance / img_1.shape[0] | |||||
avg_ssim = total_ssim / img_1.shape[0] | |||||
return avg_l2_dis, avg_ssim |
@@ -61,6 +61,11 @@ def check_param_multi_types(arg_name, arg_value, valid_types): | |||||
def check_int_positive(arg_name, arg_value): | def check_int_positive(arg_name, arg_value): | ||||
"""Check positive integer.""" | """Check positive integer.""" | ||||
# 'True' is treated as int(1) in python, which is a bug. | |||||
if isinstance(arg_value, bool): | |||||
msg = '{} should not be bool value, but got {}'.format(arg_name, arg_value) | |||||
LOGGER.error(TAG, msg) | |||||
raise ValueError(msg) | |||||
arg_value = check_param_type(arg_name, arg_value, int) | arg_value = check_param_type(arg_name, arg_value, int) | ||||
if arg_value <= 0: | if arg_value <= 0: | ||||
msg = '{} must be greater than 0, but got {}'.format(arg_name, | msg = '{} must be greater than 0, but got {}'.format(arg_name, | ||||
@@ -13,11 +13,13 @@ | |||||
# limitations under the License. | # limitations under the License. | ||||
""" Util for MindArmour. """ | """ Util for MindArmour. """ | ||||
import numpy as np | import numpy as np | ||||
from scipy.ndimage.filters import convolve | |||||
from mindspore import Tensor | from mindspore import Tensor | ||||
from mindspore.nn import Cell | from mindspore.nn import Cell | ||||
from mindspore.ops.composite import GradOperation | from mindspore.ops.composite import GradOperation | ||||
from mindarmour.utils._check_param import check_numpy_param, check_param_multi_types | |||||
from mindarmour.utils._check_param import check_numpy_param, check_param_multi_types, check_equal_shape | |||||
from .logger import LogUtil | from .logger import LogUtil | ||||
@@ -61,7 +63,7 @@ def jacobian_matrix_for_detection(grad_wrap_net, inputs, num_boxes, num_classes) | |||||
Args: | Args: | ||||
grad_wrap_net (Cell): A network wrapped by GradWrap. | grad_wrap_net (Cell): A network wrapped by GradWrap. | ||||
inputs (numpy.ndarray): Input samples. | inputs (numpy.ndarray): Input samples. | ||||
num_boxes (int): Number of boxes infered by each image. | |||||
num_boxes (int): Number of boxes inferred by each image. | |||||
num_classes (int): Number of labels of model output. | num_classes (int): Number of labels of model output. | ||||
Returns: | Returns: | ||||
@@ -251,3 +253,109 @@ def to_tensor_tuple(inputs_ori): | |||||
else: | else: | ||||
inputs_tensor = (Tensor(inputs_ori),) | inputs_tensor = (Tensor(inputs_ori),) | ||||
return inputs_tensor | return inputs_tensor | ||||
def calculate_lp_distance(original_image, compared_image): | |||||
""" | |||||
Calculate l0, l2 and linf distance for two images with the same shape. | |||||
Args: | |||||
original_image (np.ndarray): Original image. | |||||
compared_image (np.ndarray): Another image for comparison. | |||||
Returns: | |||||
tuple, (l0, l2 and linf) distances between two images. | |||||
Raises: | |||||
TypeError: If type of original_image or type of compared_image is not numpy.ndarray. | |||||
ValueError: If the shape of original_image and compared_image are not the same. | |||||
""" | |||||
check_numpy_param('original_image', original_image) | |||||
check_numpy_param('compared_image', compared_image) | |||||
check_equal_shape('original_image', original_image, 'compared_image', compared_image) | |||||
avoid_zero_div = 1e-14 | |||||
diff = (original_image - compared_image).flatten() | |||||
data = original_image.flatten() | |||||
l0_dist = np.linalg.norm(diff, ord=0) \ | |||||
/ (np.linalg.norm(data, ord=0) + avoid_zero_div) | |||||
l2_dist = np.linalg.norm(diff, ord=2) \ | |||||
/ (np.linalg.norm(data, ord=2) + avoid_zero_div) | |||||
linf_dist = np.linalg.norm(diff, ord=np.inf) \ | |||||
/ (np.linalg.norm(data, ord=np.inf) + avoid_zero_div) | |||||
return l0_dist, l2_dist, linf_dist | |||||
def compute_ssim(img_1, img_2, kernel_sigma=1.5, kernel_width=11): | |||||
""" | |||||
compute structural similarity between two images. | |||||
Args: | |||||
img_1 (numpy.ndarray): The first image to be compared. The shape of img_1 should be (img_width, img_height, | |||||
channels). | |||||
img_2 (numpy.ndarray): The second image to be compared. The shape of img_2 should be (img_width, img_height, | |||||
channels). | |||||
kernel_sigma (float): Gassian kernel param. Default: 1.5. | |||||
kernel_width (int): Another Gassian kernel param. Default: 11. | |||||
Returns: | |||||
float, structural similarity. | |||||
""" | |||||
img_1, img_2 = check_equal_shape('images_1', img_1, 'images_2', img_2) | |||||
if len(img_1.shape) > 2: | |||||
if (len(img_1.shape) != 3) or (img_1.shape[2] != 1 and img_1.shape[2] != 3): | |||||
msg = 'The shape format of img_1 and img_2 should be (img_width, img_height, channels),' \ | |||||
' but got {} and {}'.format(img_1.shape, img_2.shape) | |||||
raise ValueError(msg) | |||||
if len(img_1.shape) > 2: | |||||
total_ssim = 0 | |||||
for i in range(img_1.shape[2]): | |||||
total_ssim += compute_ssim(img_1[:, :, i], img_2[:, :, i]) | |||||
return total_ssim / 3 | |||||
# Create gaussian kernel | |||||
gaussian_kernel = np.zeros((kernel_width, kernel_width)) | |||||
for i in range(kernel_width): | |||||
for j in range(kernel_width): | |||||
gaussian_kernel[i, j] = (1 / (2*np.pi*(kernel_sigma**2)))*np.exp( | |||||
- (((i - 5)**2) + ((j - 5)**2)) / (2*(kernel_sigma**2))) | |||||
img_1 = img_1.astype(np.float32) | |||||
img_2 = img_2.astype(np.float32) | |||||
img_sq_1 = img_1**2 | |||||
img_sq_2 = img_2**2 | |||||
img_12 = img_1*img_2 | |||||
# Mean | |||||
img_mu_1 = convolve(img_1, gaussian_kernel) | |||||
img_mu_2 = convolve(img_2, gaussian_kernel) | |||||
# Mean square | |||||
img_mu_sq_1 = img_mu_1**2 | |||||
img_mu_sq_2 = img_mu_2**2 | |||||
img_mu_12 = img_mu_1*img_mu_2 | |||||
# Variances | |||||
img_sigma_sq_1 = convolve(img_sq_1, gaussian_kernel) | |||||
img_sigma_sq_2 = convolve(img_sq_2, gaussian_kernel) | |||||
# Covariance | |||||
img_sigma_12 = convolve(img_12, gaussian_kernel) | |||||
# Centered squares of variances | |||||
img_sigma_sq_1 = img_sigma_sq_1 - img_mu_sq_1 | |||||
img_sigma_sq_2 = img_sigma_sq_2 - img_mu_sq_2 | |||||
img_sigma_12 = img_sigma_12 - img_mu_12 | |||||
k_1 = 0.01 | |||||
k_2 = 0.03 | |||||
c_1 = (k_1*255)**2 | |||||
c_2 = (k_2*255)**2 | |||||
# Calculate ssim | |||||
num_ssim = (2*img_mu_12 + c_1)*(2*img_sigma_12 + c_2) | |||||
den_ssim = (img_mu_sq_1 + img_mu_sq_2 + c_1)*(img_sigma_sq_1 | |||||
+ img_sigma_sq_2 + c_2) | |||||
res = np.average(num_ssim / den_ssim) | |||||
return res |
@@ -0,0 +1,41 @@ | |||||
# Copyright 2021 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
""" | |||||
Inversion attack test | |||||
""" | |||||
import pytest | |||||
import numpy as np | |||||
import mindspore.context as context | |||||
from mindarmour.privacy.evaluation.inversion_attack import ImageInversionAttack | |||||
from ut.python.utils.mock_net import Net | |||||
context.set_context(mode=context.GRAPH_MODE) | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_x86_ascend_training | |||||
@pytest.mark.platform_arm_ascend_training | |||||
@pytest.mark.env_onecard | |||||
@pytest.mark.component_mindarmour | |||||
def test_inversion_attack(): | |||||
net = Net() | |||||
target_features = np.random.random((2, 10)).astype(np.float32) | |||||
inversion_attack = ImageInversionAttack(net, input_shape=(1, 32, 32), input_bound=(0, 1), loss_weights=[1, 0.2, 5]) | |||||
inversion_images = inversion_attack.generate(target_features, iters=10) | |||||
assert target_features.shape[0] == inversion_images.shape[0] |