diff --git a/mindarmour/privacy/evaluation/inversion_attack.py b/mindarmour/privacy/evaluation/inversion_attack.py index fe06abb..757b339 100644 --- a/mindarmour/privacy/evaluation/inversion_attack.py +++ b/mindarmour/privacy/evaluation/inversion_attack.py @@ -65,7 +65,7 @@ class InversionLoss(Cell): Tensor, inversion attack loss of the current iteration. """ output = self._network(input_data) - loss_1 = self._mse_loss(output, target_features) / self._mse_loss(target_features, 0) + loss_1 = self._mse_loss(output, target_features) / self._mse_loss(target_features, self._zeros(target_features)) data_shape = self._get_shape(input_data) if self._device_target == 'CPU': @@ -85,7 +85,7 @@ class InversionLoss(Cell): data_copy_2[:, :, :, :(data_shape[2] - 1)] = input_data[:, :, :, 1:] loss_2 = self._mse_loss(input_data, data_copy_1) + self._mse_loss(input_data, data_copy_2) - loss_3 = self._mse_loss(input_data, 0) + loss_3 = self._mse_loss(input_data, self._zeros(input_data)) loss = loss_1*self._weights[0] + loss_2*self._weights[1] + loss_3*self._weights[2] return loss