From de88ffeb2ae19b150711954fc0114c560dd6c92b Mon Sep 17 00:00:00 2001 From: jin-xiulang Date: Thu, 13 May 2021 09:30:17 +0800 Subject: [PATCH] Fix a bug of inversion-attack --- mindarmour/privacy/evaluation/inversion_attack.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mindarmour/privacy/evaluation/inversion_attack.py b/mindarmour/privacy/evaluation/inversion_attack.py index fe06abb..757b339 100644 --- a/mindarmour/privacy/evaluation/inversion_attack.py +++ b/mindarmour/privacy/evaluation/inversion_attack.py @@ -65,7 +65,7 @@ class InversionLoss(Cell): Tensor, inversion attack loss of the current iteration. """ output = self._network(input_data) - loss_1 = self._mse_loss(output, target_features) / self._mse_loss(target_features, 0) + loss_1 = self._mse_loss(output, target_features) / self._mse_loss(target_features, self._zeros(target_features)) data_shape = self._get_shape(input_data) if self._device_target == 'CPU': @@ -85,7 +85,7 @@ class InversionLoss(Cell): data_copy_2[:, :, :, :(data_shape[2] - 1)] = input_data[:, :, :, 1:] loss_2 = self._mse_loss(input_data, data_copy_1) + self._mse_loss(input_data, data_copy_2) - loss_3 = self._mse_loss(input_data, 0) + loss_3 = self._mse_loss(input_data, self._zeros(input_data)) loss = loss_1*self._weights[0] + loss_2*self._weights[1] + loss_3*self._weights[2] return loss