|
|
@@ -65,7 +65,7 @@ class InversionLoss(Cell): |
|
|
|
Tensor, inversion attack loss of the current iteration. |
|
|
|
""" |
|
|
|
output = self._network(input_data) |
|
|
|
loss_1 = self._mse_loss(output, target_features) / self._mse_loss(target_features, 0) |
|
|
|
loss_1 = self._mse_loss(output, target_features) / self._mse_loss(target_features, self._zeros(target_features)) |
|
|
|
|
|
|
|
data_shape = self._get_shape(input_data) |
|
|
|
if self._device_target == 'CPU': |
|
|
@@ -85,7 +85,7 @@ class InversionLoss(Cell): |
|
|
|
data_copy_2[:, :, :, :(data_shape[2] - 1)] = input_data[:, :, :, 1:] |
|
|
|
loss_2 = self._mse_loss(input_data, data_copy_1) + self._mse_loss(input_data, data_copy_2) |
|
|
|
|
|
|
|
loss_3 = self._mse_loss(input_data, 0) |
|
|
|
loss_3 = self._mse_loss(input_data, self._zeros(input_data)) |
|
|
|
|
|
|
|
loss = loss_1*self._weights[0] + loss_2*self._weights[1] + loss_3*self._weights[2] |
|
|
|
return loss |
|
|
|