| @@ -42,7 +42,7 @@ python train.py --data_path home_path_to_cifar100 --ckpt_path ./ | |||
| python example_vgg_cifar.py --data_path home_path_to_cifar100 --pre_trained 0-100_781.ckpt | |||
| ``` | |||
| ## 4. suppress privacy training | |||
| ## 4. Suppress privacy training | |||
| With suppress privacy mechanism, the values of some trainable parameters (such as conv layers and fully connected | |||
| layers) are set to zero as the training step grows, which can | |||
| @@ -52,3 +52,15 @@ With suppress privacy mechanism, the values of some trainable parameters (such | |||
| cd examples/privacy/sup_privacy | |||
| python sup_privacy.py | |||
| ``` | |||
| ## 5. Image inversion attack | |||
| Inversion attack means reconstructing an image based on its deep representations. For example, | |||
| reconstruct a MNIST image based on its output through LeNet5. The mechanism behind it is that well-trained | |||
| model can "remember" those training dataset. Therefore, inversion attack can be used to estimate the privacy | |||
| leakage of training tasks. | |||
| ```sh | |||
| cd examples/privacy/inversion_attack | |||
| python mnist_inversion_attack.py | |||
| ``` | |||
| @@ -0,0 +1,108 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| Examples of image inversion attack | |||
| """ | |||
| import numpy as np | |||
| import matplotlib.pyplot as plt | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from mindspore import Tensor, context | |||
| from mindspore import nn | |||
| from mindarmour.privacy.evaluation.inversion_attack import ImageInversionAttack | |||
| from mindarmour.utils.logger import LogUtil | |||
| from examples.common.networks.lenet5.lenet5_net import LeNet5, conv, fc_with_initialize | |||
| from examples.common.dataset.data_processing import generate_mnist_dataset | |||
| LOGGER = LogUtil.get_instance() | |||
| LOGGER.set_level('INFO') | |||
| TAG = 'InversionAttack' | |||
| # pylint: disable=invalid-name | |||
| class LeNet5_part(nn.Cell): | |||
| """ | |||
| Part of LeNet5 network. | |||
| """ | |||
| def __init__(self): | |||
| super(LeNet5_part, self).__init__() | |||
| self.conv1 = conv(1, 6, 5) | |||
| self.conv2 = conv(6, 16, 5) | |||
| self.fc1 = fc_with_initialize(16*5*5, 120) | |||
| self.fc2 = fc_with_initialize(120, 84) | |||
| self.fc3 = fc_with_initialize(84, 10) | |||
| self.relu = nn.ReLU() | |||
| self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) | |||
| self.flatten = nn.Flatten() | |||
| def construct(self, x): | |||
| x = self.conv1(x) | |||
| x = self.relu(x) | |||
| x = self.max_pool2d(x) | |||
| x = self.conv2(x) | |||
| x = self.relu(x) | |||
| x = self.max_pool2d(x) | |||
| return x | |||
| def mnist_inversion_attack(net): | |||
| """ | |||
| Image inversion attack based on LeNet5 and MNIST dataset. | |||
| """ | |||
| # upload trained network | |||
| ckpt_path = '../../common/networks/lenet5/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' | |||
| load_dict = load_checkpoint(ckpt_path) | |||
| load_param_into_net(net, load_dict) | |||
| # get test data | |||
| data_list = "../../common/dataset/MNIST/test" | |||
| batch_size = 32 | |||
| ds = generate_mnist_dataset(data_list, batch_size) | |||
| inversion_attack = ImageInversionAttack(net, input_shape=(1, 32, 32), input_bound=(0, 1), loss_weights=[1, 0.2, 5]) | |||
| i = 0 | |||
| batch_num = 1 | |||
| sample_num = 10 | |||
| for data in ds.create_tuple_iterator(output_numpy=True): | |||
| i += 1 | |||
| images = data[0].astype(np.float32) | |||
| target_features = net(Tensor(images)).asnumpy() | |||
| original_images = images[: sample_num] | |||
| inversion_images = inversion_attack.generate(target_features[:sample_num], iters=100) | |||
| for n in range(1, sample_num+1): | |||
| plt.subplot(2, sample_num, n) | |||
| plt.gray() | |||
| plt.imshow(images[n - 1].reshape(32, 32)) | |||
| plt.subplot(2, sample_num, n + sample_num) | |||
| plt.gray() | |||
| plt.imshow(inversion_images[n - 1].reshape(32, 32)) | |||
| plt.show() | |||
| if i >= batch_num: | |||
| break | |||
| # evaluate the similarity between inversion images and original images | |||
| avg_l2_dis, avg_ssim = inversion_attack.evaluate(original_images, inversion_images) | |||
| LOGGER.info(TAG, 'The average L2 distance between original images and inversion images is: {}'.format(avg_l2_dis)) | |||
| LOGGER.info(TAG, 'The average ssim value between original images and inversion images is: {}'.format(avg_ssim)) | |||
| if __name__ == '__main__': | |||
| # device_target can be "CPU", "GPU" or "Ascend" | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||
| # attack based on complete LeNet5 | |||
| mnist_inversion_attack(LeNet5()) | |||
| # attack based on part of LeNet5. The network is more shallower and can lead to a better attack result | |||
| mnist_inversion_attack(LeNet5_part()) | |||
| @@ -9,6 +9,10 @@ from .adv_robustness.detectors.detector import Detector | |||
| from .fuzz_testing.fuzzing import Fuzzer | |||
| from .privacy.diff_privacy import DPModel | |||
| from .privacy.evaluation.membership_inference import MembershipInference | |||
| from .privacy.sup_privacy.sup_ctrl.conctrl import SuppressCtrl | |||
| from .privacy.sup_privacy.train.model import SuppressModel | |||
| from .privacy.sup_privacy.mask_monitor.masker import SuppressMasker | |||
| from .privacy.evaluation.inversion_attack import ImageInversionAttack | |||
| __all__ = ['Attack', | |||
| 'BlackModel', | |||
| @@ -16,4 +20,8 @@ __all__ = ['Attack', | |||
| 'Defense', | |||
| 'Fuzzer', | |||
| 'DPModel', | |||
| 'MembershipInference'] | |||
| 'MembershipInference', | |||
| 'SuppressModel', | |||
| 'SuppressCtrl', | |||
| 'SuppressMasker', | |||
| 'ImageInversionAttack'] | |||
| @@ -17,84 +17,15 @@ Attack evaluation. | |||
| import numpy as np | |||
| from scipy.ndimage.filters import convolve | |||
| from mindarmour.utils.logger import LogUtil | |||
| from mindarmour.utils._check_param import check_pair_numpy_param, \ | |||
| check_param_type, check_numpy_param, check_equal_shape | |||
| check_param_type, check_numpy_param | |||
| from mindarmour.utils.util import calculate_lp_distance, compute_ssim | |||
| LOGGER = LogUtil.get_instance() | |||
| TAG = 'AttackEvaluate' | |||
| def _compute_ssim(img_1, img_2, kernel_sigma=1.5, kernel_width=11): | |||
| """ | |||
| compute structural similarity. | |||
| Args: | |||
| img_1 (numpy.ndarray): The first image to be compared. | |||
| img_2 (numpy.ndarray): The second image to be compared. | |||
| kernel_sigma (float): Gassian kernel param. Default: 1.5. | |||
| kernel_width (int): Another Gassian kernel param. Default: 11. | |||
| Returns: | |||
| float, structural similarity. | |||
| """ | |||
| img_1, img_2 = check_equal_shape('images_1', img_1, 'images_2', img_2) | |||
| if len(img_1.shape) > 2: | |||
| total_ssim = 0 | |||
| for i in range(img_1.shape[2]): | |||
| total_ssim += _compute_ssim(img_1[:, :, i], img_2[:, :, i]) | |||
| return total_ssim / 3 | |||
| # Create gaussian kernel | |||
| gaussian_kernel = np.zeros((kernel_width, kernel_width)) | |||
| for i in range(kernel_width): | |||
| for j in range(kernel_width): | |||
| gaussian_kernel[i, j] = (1 / (2*np.pi*(kernel_sigma**2)))*np.exp( | |||
| - (((i - 5)**2) + ((j - 5)**2)) / (2*(kernel_sigma**2))) | |||
| img_1 = img_1.astype(np.float32) | |||
| img_2 = img_2.astype(np.float32) | |||
| img_sq_1 = img_1**2 | |||
| img_sq_2 = img_2**2 | |||
| img_12 = img_1*img_2 | |||
| # Mean | |||
| img_mu_1 = convolve(img_1, gaussian_kernel) | |||
| img_mu_2 = convolve(img_2, gaussian_kernel) | |||
| # Mean square | |||
| img_mu_sq_1 = img_mu_1**2 | |||
| img_mu_sq_2 = img_mu_2**2 | |||
| img_mu_12 = img_mu_1*img_mu_2 | |||
| # Variances | |||
| img_sigma_sq_1 = convolve(img_sq_1, gaussian_kernel) | |||
| img_sigma_sq_2 = convolve(img_sq_2, gaussian_kernel) | |||
| # Covariance | |||
| img_sigma_12 = convolve(img_12, gaussian_kernel) | |||
| # Centered squares of variances | |||
| img_sigma_sq_1 = img_sigma_sq_1 - img_mu_sq_1 | |||
| img_sigma_sq_2 = img_sigma_sq_2 - img_mu_sq_2 | |||
| img_sigma_12 = img_sigma_12 - img_mu_12 | |||
| k_1 = 0.01 | |||
| k_2 = 0.03 | |||
| c_1 = (k_1*255)**2 | |||
| c_2 = (k_2*255)**2 | |||
| # Calculate ssim | |||
| num_ssim = (2*img_mu_12 + c_1)*(2*img_sigma_12 + c_2) | |||
| den_ssim = (img_mu_sq_1 + img_mu_sq_2 + c_1)*(img_sigma_sq_1 | |||
| + img_sigma_sq_2 + c_2) | |||
| res = np.average(num_ssim / den_ssim) | |||
| return res | |||
| class AttackEvaluate: | |||
| """ | |||
| Evaluation metrics of attack methods. | |||
| @@ -217,16 +148,11 @@ class AttackEvaluate: | |||
| l0_dist = 0 | |||
| l2_dist = 0 | |||
| linf_dist = 0 | |||
| avoid_zero_div = 1e-14 | |||
| for i in idxes: | |||
| diff = (self._adv_inputs[i] - self._inputs[i]).flatten() | |||
| data = self._inputs[i].flatten() | |||
| l0_dist += np.linalg.norm(diff, ord=0) \ | |||
| / (np.linalg.norm(data, ord=0) + avoid_zero_div) | |||
| l2_dist += np.linalg.norm(diff, ord=2) \ | |||
| / (np.linalg.norm(data, ord=2) + avoid_zero_div) | |||
| linf_dist += np.linalg.norm(diff, ord=np.inf) \ | |||
| / (np.linalg.norm(data, ord=np.inf) + avoid_zero_div) | |||
| l0_dist_i, l2_dist_i, linf_dist_i = calculate_lp_distance(self._inputs[i], self._adv_inputs[i]) | |||
| l0_dist += l0_dist_i | |||
| l2_dist += l2_dist_i | |||
| linf_dist += linf_dist_i | |||
| return l0_dist / success_num, l2_dist / success_num, \ | |||
| linf_dist / success_num | |||
| @@ -249,7 +175,7 @@ class AttackEvaluate: | |||
| total_ssim = 0.0 | |||
| for _, i in enumerate(self._success_idxes): | |||
| total_ssim += _compute_ssim(self._adv_inputs[i], self._inputs[i]) | |||
| total_ssim += compute_ssim(self._adv_inputs[i], self._inputs[i]) | |||
| return total_ssim / success_num | |||
| @@ -0,0 +1,210 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """ | |||
| Inversion Attack | |||
| """ | |||
| import numpy as np | |||
| from mindspore.nn import Cell, MSELoss | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| from mindarmour.utils.util import GradWrapWithLoss | |||
| from mindarmour.utils._check_param import check_param_type, check_param_multi_types, \ | |||
| check_int_positive, check_numpy_param, check_value_positive, check_equal_shape | |||
| from mindarmour.utils.logger import LogUtil | |||
| from mindarmour.utils.util import calculate_lp_distance, compute_ssim | |||
| LOGGER = LogUtil.get_instance() | |||
| LOGGER.set_level('INFO') | |||
| TAG = 'Image inversion attack' | |||
| class InversionLoss(Cell): | |||
| """ | |||
| The loss function for inversion attack. | |||
| Args: | |||
| network (Cell): The network used to infer images' deep representations. | |||
| weights (Union[list, tuple]): Weights of three sub-loss in InversionLoss, which can be adjusted to | |||
| obtain better results. | |||
| """ | |||
| def __init__(self, network, weights): | |||
| super(InversionLoss, self).__init__() | |||
| self._network = check_param_type('network', network, Cell) | |||
| self._mse_loss = MSELoss() | |||
| self._weights = check_param_multi_types('weights', weights, [list, tuple]) | |||
| self._get_shape = P.Shape() | |||
| def construct(self, input_data, target_features): | |||
| """ | |||
| Calculate the inversion attack loss, which consists of three parts. Loss_1 is for evaluating the difference | |||
| between the target deep representations and current representations; Loss_2 is for evaluating the continuity | |||
| between adjacent pixels; Loss_3 is for regularization. | |||
| Args: | |||
| input_data (Tensor): The reconstructed image during inversion attack. | |||
| target_features (Tensor): Deep representations of the original image. | |||
| Returns: | |||
| Tensor, inversion attack loss of the current iteration. | |||
| """ | |||
| output = self._network(input_data) | |||
| loss_1 = self._mse_loss(output, target_features) / self._mse_loss(target_features, 0) | |||
| data_shape = self._get_shape(input_data) | |||
| split_op_1 = P.Split(2, data_shape[2]) | |||
| split_op_2 = P.Split(3, data_shape[3]) | |||
| data_split_1 = split_op_1(input_data) | |||
| data_split_2 = split_op_2(input_data) | |||
| loss_2 = 0 | |||
| for i in range(1, data_shape[2]): | |||
| loss_2 += self._mse_loss(data_split_1[i], data_split_1[i-1]) | |||
| for j in range(1, data_shape[3]): | |||
| loss_2 += self._mse_loss(data_split_2[j], data_split_2[j-1]) | |||
| loss_3 = self._mse_loss(input_data, 0) | |||
| loss = loss_1*self._weights[0] + loss_2*self._weights[1] + loss_3*self._weights[2] | |||
| return loss | |||
| class ImageInversionAttack: | |||
| """ | |||
| An attack method used to reconstruct images by inverting their deep representations. | |||
| References: `Aravindh Mahendran, Andrea Vedaldi. Understanding Deep Image Representations by Inverting Them. | |||
| 2014. <https://arxiv.org/pdf/1412.0035.pdf>`_ | |||
| Args: | |||
| network (Cell): The network used to infer images' deep representations. | |||
| input_shape (tuple): Data shape of single network input, which should be in accordance with the given | |||
| network. The format of shape should be (channel, image_width, image_height). | |||
| input_bound (Union[tuple, list]): The pixel range of original images, which should be like [minimum_pixel, | |||
| maximum_pixel] or (minimum_pixel, maximum_pixel). | |||
| loss_weights (Union[list, tuple]): Weights of three sub-loss in InversionLoss, which can be adjusted to | |||
| obtain better results. Default: (1, 0.2, 5). | |||
| Raises: | |||
| TypeError: If the type of network is not Cell. | |||
| ValueError: If any value of input_shape is not positive int. | |||
| ValueError: If any value of loss_weights is not positive value. | |||
| """ | |||
| def __init__(self, network, input_shape, input_bound, loss_weights=(1, 0.2, 5)): | |||
| self._network = check_param_type('network', network, Cell) | |||
| for sub_loss_weight in loss_weights: | |||
| check_value_positive('sub_loss_weight', sub_loss_weight) | |||
| self._loss = InversionLoss(self._network, loss_weights) | |||
| self._input_shape = check_param_multi_types('input_shape', input_shape, [list, tuple]) | |||
| for shape_dim in input_shape: | |||
| check_int_positive('shape_dim', shape_dim) | |||
| self._input_bound = check_param_multi_types('input_bound', input_bound, [list, tuple]) | |||
| def generate(self, target_features, iters=100): | |||
| """ | |||
| Reconstruct images based on target_features. | |||
| Args: | |||
| target_features (numpy.ndarray): Deep representations of original images. The first dimension of | |||
| target_features should be img_num. It should be noted that the shape of target_features should be | |||
| (1, dim2, dim3, ...) if img_num equals 1. | |||
| iters (int): iteration times of inversion attack, which should be positive integers. Default: 100. | |||
| Returns: | |||
| numpy.ndarray, reconstructed images, which are expected to be similar to original images. | |||
| Raises: | |||
| TypeError: If the type of target_features is not numpy.ndarray. | |||
| ValueError: If any value of iters is not positive int.Z | |||
| Examples: | |||
| >>> net = LeNet5() | |||
| >>> inversion_attack = ImageInversionAttack(net, input_shape=(1, 32, 32), input_bound=(0, 1), | |||
| >>> loss_weights=[1, 0.2, 5]) | |||
| >>> features = np.random.random((2, 10)).astype(np.float32) | |||
| >>> images = inversion_attack.generate(features, iters=10) | |||
| >>> print(images.shape) | |||
| (2, 1, 32, 32) | |||
| """ | |||
| target_features = check_numpy_param('target_features', target_features) | |||
| iters = check_int_positive('iters', iters) | |||
| # shape checking | |||
| img_num = target_features.shape[0] | |||
| test_input = np.random.random((img_num,) + self._input_shape).astype(np.float32) | |||
| test_out = self._network(Tensor(test_input)).asnumpy() | |||
| if test_out.shape != target_features.shape: | |||
| msg = "The shape of target_features ({}) is not in accordance with the shape" \ | |||
| " of network output({})".format(target_features.shape, test_out.shape) | |||
| raise ValueError(msg) | |||
| loss_net = self._loss | |||
| loss_grad = GradWrapWithLoss(loss_net) | |||
| inversion_images = [] | |||
| for i in range(img_num): | |||
| target_feature_n = target_features[i] | |||
| inversion_image_n = np.random.random((1,) + self._input_shape).astype(np.float32)*0.05 | |||
| for s in range(iters): | |||
| x_grad = loss_grad(Tensor(inversion_image_n), Tensor(target_feature_n)).asnumpy() | |||
| x_grad_sign = np.sign(x_grad) | |||
| inversion_image_n -= x_grad_sign*0.01 | |||
| inversion_image_n = np.clip(inversion_image_n, self._input_bound[0], self._input_bound[1]) | |||
| current_loss = self._loss(Tensor(inversion_image_n), Tensor(target_feature_n)) | |||
| LOGGER.info(TAG, 'iteration step: {}, loss is {}'.format(s, current_loss)) | |||
| inversion_images.append(inversion_image_n) | |||
| return np.concatenate(np.array(inversion_images)) | |||
| def evaluate(self, original_images, inversion_images): | |||
| """ | |||
| Compute the average L2 distance and SSIM value between original images and inversion images. | |||
| Args: | |||
| original_images (numpy.ndarray): Original images, whose shape should be (img_num, channels, img_width, | |||
| img_height). | |||
| inversion_images (numpy.ndarray): Inversion images, whose shape should be (img_num, channels, img_width, | |||
| img_height). | |||
| Returns: | |||
| tuple, the average l2 distance and average ssim value between original images and inversion images. | |||
| Examples: | |||
| >>> net = LeNet5() | |||
| >>> inversion_attack = ImageInversionAttack(net, input_shape=(1, 32, 32), input_bound=(0, 1), | |||
| >>> loss_weights=[1, 0.2, 5]) | |||
| >>> features = np.random.random((2, 10)).astype(np.float32) | |||
| >>> inver_images = inversion_attack.generate(features, iters=10) | |||
| >>> ori_images = np.random.random((2, 1, 32, 32)) | |||
| >>> result = inversion_attack.evaluate(ori_images, inver_images) | |||
| >>> print(len(result)) | |||
| 2 | |||
| """ | |||
| check_numpy_param('original_images', original_images) | |||
| check_numpy_param('inversion_images', inversion_images) | |||
| img_1, img_2 = check_equal_shape('original_images', original_images, 'inversion_images', inversion_images) | |||
| if (len(img_1.shape) != 4) or (img_1.shape[1] != 1 and img_1.shape[1] != 3): | |||
| msg = 'The shape format of img_1 and img_2 should be (img_num, channels, img_width, img_height),' \ | |||
| ' but got {} and {}'.format(img_1.shape, img_2.shape) | |||
| raise ValueError(msg) | |||
| total_l2_distance = 0 | |||
| total_ssim = 0 | |||
| img_1 = img_1.transpose(0, 2, 3, 1) | |||
| img_2 = img_2.transpose(0, 2, 3, 1) | |||
| for i in range(img_1.shape[0]): | |||
| _, l2_dis, _ = calculate_lp_distance(img_1[i], img_2[i]) | |||
| total_l2_distance += l2_dis | |||
| total_ssim += compute_ssim(img_1[i], img_2[i]) | |||
| avg_l2_dis = total_l2_distance / img_1.shape[0] | |||
| avg_ssim = total_ssim / img_1.shape[0] | |||
| return avg_l2_dis, avg_ssim | |||
| @@ -61,6 +61,11 @@ def check_param_multi_types(arg_name, arg_value, valid_types): | |||
| def check_int_positive(arg_name, arg_value): | |||
| """Check positive integer.""" | |||
| # 'True' is treated as int(1) in python, which is a bug. | |||
| if isinstance(arg_value, bool): | |||
| msg = '{} should not be bool value, but got {}'.format(arg_name, arg_value) | |||
| LOGGER.error(TAG, msg) | |||
| raise ValueError(msg) | |||
| arg_value = check_param_type(arg_name, arg_value, int) | |||
| if arg_value <= 0: | |||
| msg = '{} must be greater than 0, but got {}'.format(arg_name, | |||
| @@ -13,11 +13,13 @@ | |||
| # limitations under the License. | |||
| """ Util for MindArmour. """ | |||
| import numpy as np | |||
| from scipy.ndimage.filters import convolve | |||
| from mindspore import Tensor | |||
| from mindspore.nn import Cell | |||
| from mindspore.ops.composite import GradOperation | |||
| from mindarmour.utils._check_param import check_numpy_param, check_param_multi_types | |||
| from mindarmour.utils._check_param import check_numpy_param, check_param_multi_types, check_equal_shape | |||
| from .logger import LogUtil | |||
| @@ -61,7 +63,7 @@ def jacobian_matrix_for_detection(grad_wrap_net, inputs, num_boxes, num_classes) | |||
| Args: | |||
| grad_wrap_net (Cell): A network wrapped by GradWrap. | |||
| inputs (numpy.ndarray): Input samples. | |||
| num_boxes (int): Number of boxes infered by each image. | |||
| num_boxes (int): Number of boxes inferred by each image. | |||
| num_classes (int): Number of labels of model output. | |||
| Returns: | |||
| @@ -251,3 +253,109 @@ def to_tensor_tuple(inputs_ori): | |||
| else: | |||
| inputs_tensor = (Tensor(inputs_ori),) | |||
| return inputs_tensor | |||
| def calculate_lp_distance(original_image, compared_image): | |||
| """ | |||
| Calculate l0, l2 and linf distance for two images with the same shape. | |||
| Args: | |||
| original_image (np.ndarray): Original image. | |||
| compared_image (np.ndarray): Another image for comparison. | |||
| Returns: | |||
| tuple, (l0, l2 and linf) distances between two images. | |||
| Raises: | |||
| TypeError: If type of original_image or type of compared_image is not numpy.ndarray. | |||
| ValueError: If the shape of original_image and compared_image are not the same. | |||
| """ | |||
| check_numpy_param('original_image', original_image) | |||
| check_numpy_param('compared_image', compared_image) | |||
| check_equal_shape('original_image', original_image, 'compared_image', compared_image) | |||
| avoid_zero_div = 1e-14 | |||
| diff = (original_image - compared_image).flatten() | |||
| data = original_image.flatten() | |||
| l0_dist = np.linalg.norm(diff, ord=0) \ | |||
| / (np.linalg.norm(data, ord=0) + avoid_zero_div) | |||
| l2_dist = np.linalg.norm(diff, ord=2) \ | |||
| / (np.linalg.norm(data, ord=2) + avoid_zero_div) | |||
| linf_dist = np.linalg.norm(diff, ord=np.inf) \ | |||
| / (np.linalg.norm(data, ord=np.inf) + avoid_zero_div) | |||
| return l0_dist, l2_dist, linf_dist | |||
| def compute_ssim(img_1, img_2, kernel_sigma=1.5, kernel_width=11): | |||
| """ | |||
| compute structural similarity between two images. | |||
| Args: | |||
| img_1 (numpy.ndarray): The first image to be compared. The shape of img_1 should be (img_width, img_height, | |||
| channels). | |||
| img_2 (numpy.ndarray): The second image to be compared. The shape of img_2 should be (img_width, img_height, | |||
| channels). | |||
| kernel_sigma (float): Gassian kernel param. Default: 1.5. | |||
| kernel_width (int): Another Gassian kernel param. Default: 11. | |||
| Returns: | |||
| float, structural similarity. | |||
| """ | |||
| img_1, img_2 = check_equal_shape('images_1', img_1, 'images_2', img_2) | |||
| if len(img_1.shape) > 2: | |||
| if (len(img_1.shape) != 3) or (img_1.shape[2] != 1 and img_1.shape[2] != 3): | |||
| msg = 'The shape format of img_1 and img_2 should be (img_width, img_height, channels),' \ | |||
| ' but got {} and {}'.format(img_1.shape, img_2.shape) | |||
| raise ValueError(msg) | |||
| if len(img_1.shape) > 2: | |||
| total_ssim = 0 | |||
| for i in range(img_1.shape[2]): | |||
| total_ssim += compute_ssim(img_1[:, :, i], img_2[:, :, i]) | |||
| return total_ssim / 3 | |||
| # Create gaussian kernel | |||
| gaussian_kernel = np.zeros((kernel_width, kernel_width)) | |||
| for i in range(kernel_width): | |||
| for j in range(kernel_width): | |||
| gaussian_kernel[i, j] = (1 / (2*np.pi*(kernel_sigma**2)))*np.exp( | |||
| - (((i - 5)**2) + ((j - 5)**2)) / (2*(kernel_sigma**2))) | |||
| img_1 = img_1.astype(np.float32) | |||
| img_2 = img_2.astype(np.float32) | |||
| img_sq_1 = img_1**2 | |||
| img_sq_2 = img_2**2 | |||
| img_12 = img_1*img_2 | |||
| # Mean | |||
| img_mu_1 = convolve(img_1, gaussian_kernel) | |||
| img_mu_2 = convolve(img_2, gaussian_kernel) | |||
| # Mean square | |||
| img_mu_sq_1 = img_mu_1**2 | |||
| img_mu_sq_2 = img_mu_2**2 | |||
| img_mu_12 = img_mu_1*img_mu_2 | |||
| # Variances | |||
| img_sigma_sq_1 = convolve(img_sq_1, gaussian_kernel) | |||
| img_sigma_sq_2 = convolve(img_sq_2, gaussian_kernel) | |||
| # Covariance | |||
| img_sigma_12 = convolve(img_12, gaussian_kernel) | |||
| # Centered squares of variances | |||
| img_sigma_sq_1 = img_sigma_sq_1 - img_mu_sq_1 | |||
| img_sigma_sq_2 = img_sigma_sq_2 - img_mu_sq_2 | |||
| img_sigma_12 = img_sigma_12 - img_mu_12 | |||
| k_1 = 0.01 | |||
| k_2 = 0.03 | |||
| c_1 = (k_1*255)**2 | |||
| c_2 = (k_2*255)**2 | |||
| # Calculate ssim | |||
| num_ssim = (2*img_mu_12 + c_1)*(2*img_sigma_12 + c_2) | |||
| den_ssim = (img_mu_sq_1 + img_mu_sq_2 + c_1)*(img_sigma_sq_1 | |||
| + img_sigma_sq_2 + c_2) | |||
| res = np.average(num_ssim / den_ssim) | |||
| return res | |||
| @@ -0,0 +1,41 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """ | |||
| Inversion attack test | |||
| """ | |||
| import pytest | |||
| import numpy as np | |||
| import mindspore.context as context | |||
| from mindarmour.privacy.evaluation.inversion_attack import ImageInversionAttack | |||
| from ut.python.utils.mock_net import Net | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.env_onecard | |||
| @pytest.mark.component_mindarmour | |||
| def test_inversion_attack(): | |||
| net = Net() | |||
| target_features = np.random.random((2, 10)).astype(np.float32) | |||
| inversion_attack = ImageInversionAttack(net, input_shape=(1, 32, 32), input_bound=(0, 1), loss_weights=[1, 0.2, 5]) | |||
| inversion_images = inversion_attack.generate(target_features, iters=10) | |||
| assert target_features.shape[0] == inversion_images.shape[0] | |||