|
|
@@ -12,7 +12,6 @@ |
|
|
|
# See the License for the specific language governing permissions and |
|
|
|
# limitations under the License. |
|
|
|
""" Monitor module of differential privacy training. """ |
|
|
|
import math |
|
|
|
import numpy as np |
|
|
|
from scipy import special |
|
|
|
|
|
|
@@ -40,8 +39,9 @@ class PrivacyMonitorFactory: |
|
|
|
Create a privacy monitor class. |
|
|
|
|
|
|
|
Args: |
|
|
|
policy (str): Monitor policy, 'rdp' is supported by now. RDP means R'enyi differential privacy, |
|
|
|
which computed based on R'enyi divergence. |
|
|
|
policy (str): Monitor policy, 'rdp' is supported by now. RDP |
|
|
|
means R'enyi differential privacy, which computed based |
|
|
|
on R'enyi divergence. |
|
|
|
args (Union[int, float, numpy.ndarray, list, str]): Parameters |
|
|
|
used for creating a privacy monitor. |
|
|
|
kwargs (Union[int, float, numpy.ndarray, list, str]): Keyword |
|
|
@@ -60,9 +60,14 @@ class PrivacyMonitorFactory: |
|
|
|
|
|
|
|
|
|
|
|
class RDPMonitor(Callback): |
|
|
|
""" |
|
|
|
r""" |
|
|
|
Compute the privacy budget of DP training based on Renyi differential |
|
|
|
privacy theory. |
|
|
|
privacy (RDP) theory. According to the reference below, if a randomized |
|
|
|
mechanism is said to have ε'-Renyi differential privacy of order α, it |
|
|
|
also satisfies conventional differential privacy (ε, δ) as below: |
|
|
|
|
|
|
|
.. math:: |
|
|
|
(ε'+\frac{log(1/δ)}{α-1}, δ) |
|
|
|
|
|
|
|
Reference: `Rényi Differential Privacy of the Sampled Gaussian Mechanism |
|
|
|
<https://arxiv.org/abs/1908.10530>`_ |
|
|
@@ -70,33 +75,43 @@ class RDPMonitor(Callback): |
|
|
|
Args: |
|
|
|
num_samples (int): The total number of samples in training data sets. |
|
|
|
batch_size (int): The number of samples in a batch while training. |
|
|
|
initial_noise_multiplier (Union[float, int]): The initial |
|
|
|
multiplier of the noise added to training parameters' gradients. Default: 1.5. |
|
|
|
initial_noise_multiplier (Union[float, int]): Ratio of the standard |
|
|
|
deviation of Gaussian noise divided by the norm_bound, which will |
|
|
|
be used to calculate privacy spent. Default: 1.5. |
|
|
|
max_eps (Union[float, int, None]): The maximum acceptable epsilon |
|
|
|
budget for DP training. Default: 10.0. |
|
|
|
budget for DP training, which is used for estimating the max |
|
|
|
training epochs. Default: 10.0. |
|
|
|
target_delta (Union[float, int, None]): Target delta budget for DP |
|
|
|
training. Default: 1e-3. |
|
|
|
training. If target_delta is set to be δ, then the privacy budget |
|
|
|
δ would be fixed during the whole training process. Default: 1e-3. |
|
|
|
max_delta (Union[float, int, None]): The maximum acceptable delta |
|
|
|
budget for DP training. Max_delta must be less than 1 and |
|
|
|
suggested to be less than 1e-3, otherwise overflow would be |
|
|
|
encountered. Default: None. |
|
|
|
budget for DP training, which is used for estimating the max |
|
|
|
training epochs. Max_delta must be less than 1 and suggested |
|
|
|
to be less than 1e-3, otherwise overflow would be encountered. |
|
|
|
Default: None. |
|
|
|
target_eps (Union[float, int, None]): Target epsilon budget for DP |
|
|
|
training. Default: None. |
|
|
|
training. If target_eps is set to be ε, then the privacy budget |
|
|
|
ε would be fixed during the whole training process. Default: None. |
|
|
|
orders (Union[None, list[int, float]]): Finite orders used for |
|
|
|
computing rdp, which must be greater than 1. |
|
|
|
computing rdp, which must be greater than 1. The computation result |
|
|
|
of privacy budget would be different for various orders. In order |
|
|
|
to obtain a tighter (smaller) privacy budget estimation, a list |
|
|
|
of orders could be tried. Default: None. |
|
|
|
noise_decay_mode (str): Decay mode of adding noise while training, |
|
|
|
which can be 'no_decay', 'Time' or 'Step'. Default: 'Time'. |
|
|
|
noise_decay_rate (Union[float, None]): Decay rate of noise while |
|
|
|
training. Default: 6e-4. |
|
|
|
per_print_times (int): The interval steps of computing and printing |
|
|
|
the privacy budget. Default: 50. |
|
|
|
dataset_sink_mode (bool): If True, all training data would be passed to device(Ascend) at once. If False, |
|
|
|
training data would be passed to device after each step training. Default: False. |
|
|
|
dataset_sink_mode (bool): If True, all training data would be passed |
|
|
|
to device(Ascend) at once. If False, training data would be passed |
|
|
|
to device after each step training. Default: False. |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> rdp = PrivacyMonitorFactory.create(policy='rdp', |
|
|
|
>>> num_samples=60000, batch_size=256) |
|
|
|
>>> network = Net() |
|
|
|
>>> epochs = 2 |
|
|
|
>>> net_loss = nn.SoftmaxCrossEntropyWithLogits() |
|
|
|
>>> net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9) |
|
|
|
>>> model = Model(network, net_loss, net_opt) |
|
|
@@ -158,6 +173,15 @@ class RDPMonitor(Callback): |
|
|
|
self._noise_decay_rate = noise_decay_rate |
|
|
|
self._rdp = 0 |
|
|
|
self._per_print_times = per_print_times |
|
|
|
if self._target_eps is None and self._target_delta is None: |
|
|
|
msg = 'target eps and target delta cannot both be None' |
|
|
|
LOGGER.error(TAG, msg) |
|
|
|
raise ValueError(msg) |
|
|
|
|
|
|
|
if self._target_eps is not None and self._target_delta is not None: |
|
|
|
msg = 'One of target eps and target delta must be None' |
|
|
|
LOGGER.error(TAG, msg) |
|
|
|
raise ValueError(msg) |
|
|
|
if dataset_sink_mode: |
|
|
|
self._per_print_times = int(self._num_samples / self._batch_size) |
|
|
|
|
|
|
@@ -178,7 +202,7 @@ class RDPMonitor(Callback): |
|
|
|
while epoch < 10000: |
|
|
|
steps = self._num_samples // self._batch_size |
|
|
|
eps, delta = self._compute_privacy_steps( |
|
|
|
list(np.arange((epoch - 1) * steps, epoch * steps + 1))) |
|
|
|
list(np.arange((epoch - 1)*steps, epoch*steps + 1))) |
|
|
|
if self._max_eps is not None: |
|
|
|
if eps <= self._max_eps: |
|
|
|
epoch += 1 |
|
|
@@ -189,6 +213,7 @@ class RDPMonitor(Callback): |
|
|
|
epoch += 1 |
|
|
|
else: |
|
|
|
break |
|
|
|
# reset the rdp for model training |
|
|
|
self._rdp = 0 |
|
|
|
return epoch |
|
|
|
|
|
|
@@ -233,25 +258,15 @@ class RDPMonitor(Callback): |
|
|
|
Returns: |
|
|
|
float, privacy budget. |
|
|
|
""" |
|
|
|
if self._target_eps is None and self._target_delta is None: |
|
|
|
msg = 'target eps and target delta cannot both be None' |
|
|
|
LOGGER.error(TAG, msg) |
|
|
|
raise ValueError(msg) |
|
|
|
|
|
|
|
if self._target_eps is not None and self._target_delta is not None: |
|
|
|
msg = 'One of target eps and target delta must be None' |
|
|
|
LOGGER.error(TAG, msg) |
|
|
|
raise ValueError(msg) |
|
|
|
|
|
|
|
if self._orders is None: |
|
|
|
self._orders = ( |
|
|
|
[1.005, 1.01, 1.02, 1.08, 1.2, 2, 5, 10, 20, 40, 80]) |
|
|
|
|
|
|
|
sampling_rate = self._batch_size / self._num_samples |
|
|
|
noise_step = self._initial_noise_multiplier |
|
|
|
noise_stddev_step = self._initial_noise_multiplier |
|
|
|
|
|
|
|
if self._noise_decay_mode == 'no_decay': |
|
|
|
self._rdp += self._compute_rdp(sampling_rate, noise_step) * len( |
|
|
|
self._rdp += self._compute_rdp(sampling_rate, noise_stddev_step)*len( |
|
|
|
steps) |
|
|
|
else: |
|
|
|
if self._noise_decay_rate is None: |
|
|
@@ -260,33 +275,33 @@ class RDPMonitor(Callback): |
|
|
|
raise ValueError(msg) |
|
|
|
|
|
|
|
if self._noise_decay_mode == 'Time': |
|
|
|
noise_step = [self._initial_noise_multiplier / ( |
|
|
|
1 + self._noise_decay_rate * step) for step in steps] |
|
|
|
noise_stddev_step = [self._initial_noise_multiplier / ( |
|
|
|
1 + self._noise_decay_rate*step) for step in steps] |
|
|
|
|
|
|
|
elif self._noise_decay_mode == 'Step': |
|
|
|
noise_step = [self._initial_noise_multiplier * ( |
|
|
|
1 - self._noise_decay_rate) ** step for step in steps] |
|
|
|
noise_stddev_step = [self._initial_noise_multiplier*( |
|
|
|
1 - self._noise_decay_rate)**step for step in steps] |
|
|
|
self._rdp += sum( |
|
|
|
[self._compute_rdp(sampling_rate, noise) for noise in |
|
|
|
noise_step]) |
|
|
|
noise_stddev_step]) |
|
|
|
eps, delta = self._compute_privacy_budget(self._rdp) |
|
|
|
|
|
|
|
return eps, delta |
|
|
|
|
|
|
|
def _compute_rdp(self, q, noise): |
|
|
|
def _compute_rdp(self, sample_rate, noise_stddev): |
|
|
|
""" |
|
|
|
Compute rdp according to sampling rate, added noise and Renyi |
|
|
|
divergence orders. |
|
|
|
|
|
|
|
Args: |
|
|
|
q (float): Sampling rate of each batch of samples. |
|
|
|
noise (float): Noise multiplier. |
|
|
|
sample_rate (float): Sampling rate of each batch of samples. |
|
|
|
noise_stddev (float): Noise multiplier. |
|
|
|
|
|
|
|
Returns: |
|
|
|
float or numpy.ndarray, rdp values. |
|
|
|
""" |
|
|
|
rdp = np.array( |
|
|
|
[_compute_rdp_order(q, noise, order) for order in self._orders]) |
|
|
|
[_compute_rdp_with_order(sample_rate, noise_stddev, order) for order in self._orders]) |
|
|
|
return rdp |
|
|
|
|
|
|
|
def _compute_privacy_budget(self, rdp): |
|
|
@@ -317,14 +332,9 @@ class RDPMonitor(Callback): |
|
|
|
""" |
|
|
|
orders = np.atleast_1d(self._orders) |
|
|
|
rdps = np.atleast_1d(rdp) |
|
|
|
if len(orders) != len(rdps): |
|
|
|
msg = 'rdp lists and orders list must have the same length.' |
|
|
|
LOGGER.error(TAG, msg) |
|
|
|
raise ValueError(msg) |
|
|
|
|
|
|
|
deltas = np.exp((rdps - self._target_eps) * (orders - 1)) |
|
|
|
min_delta = min(deltas) |
|
|
|
return min(min_delta, 1.) |
|
|
|
deltas = np.exp((rdps - self._target_eps)*(orders - 1)) |
|
|
|
min_delta = np.min(deltas) |
|
|
|
return np.min([min_delta, 1.]) |
|
|
|
|
|
|
|
def _compute_eps(self, rdp): |
|
|
|
""" |
|
|
@@ -338,50 +348,46 @@ class RDPMonitor(Callback): |
|
|
|
""" |
|
|
|
orders = np.atleast_1d(self._orders) |
|
|
|
rdps = np.atleast_1d(rdp) |
|
|
|
if len(orders) != len(rdps): |
|
|
|
msg = 'rdp lists and orders list must have the same length.' |
|
|
|
LOGGER.error(TAG, msg) |
|
|
|
raise ValueError(msg) |
|
|
|
eps = rdps - math.log(self._target_delta) / (orders - 1) |
|
|
|
return min(eps) |
|
|
|
eps = rdps - np.log(self._target_delta) / (orders - 1) |
|
|
|
return np.min(eps) |
|
|
|
|
|
|
|
|
|
|
|
def _compute_rdp_order(q, sigma, alpha): |
|
|
|
def _compute_rdp_with_order(sample_rate, noise_stddev, order): |
|
|
|
""" |
|
|
|
Compute rdp for each order. |
|
|
|
|
|
|
|
Args: |
|
|
|
q (float): Sampling probability. |
|
|
|
sigma (float): Noise multiplier. |
|
|
|
alpha: The order used for computing rdp. |
|
|
|
sample_rate (float): Sampling probability. |
|
|
|
noise_stddev (float): Noise multiplier. |
|
|
|
order: The order used for computing rdp. |
|
|
|
|
|
|
|
Returns: |
|
|
|
float, rdp value. |
|
|
|
""" |
|
|
|
if float(alpha).is_integer(): |
|
|
|
if float(order).is_integer(): |
|
|
|
log_integrate = -np.inf |
|
|
|
for k in range(alpha + 1): |
|
|
|
term_k = (math.log( |
|
|
|
special.binom(alpha, k)) + k * math.log(q) + ( |
|
|
|
alpha - k) * math.log( |
|
|
|
1 - q)) + (k * k - k) / (2 * (sigma ** 2)) |
|
|
|
for k in range(order + 1): |
|
|
|
term_k = (np.log( |
|
|
|
special.binom(order, k)) + k*np.log(sample_rate) + ( |
|
|
|
order - k)*np.log( |
|
|
|
1 - sample_rate)) + (k*k - k) / (2*(noise_stddev**2)) |
|
|
|
log_integrate = _log_add(log_integrate, term_k) |
|
|
|
return float(log_integrate) / (alpha - 1) |
|
|
|
return float(log_integrate) / (order - 1) |
|
|
|
log_part_0, log_part_1 = -np.inf, -np.inf |
|
|
|
k = 0 |
|
|
|
z0 = sigma ** 2 * math.log(1 / q - 1) + 1 / 2 |
|
|
|
z0 = noise_stddev**2*np.log(1 / sample_rate - 1) + 1 / 2 |
|
|
|
while True: |
|
|
|
bi_coef = special.binom(alpha, k) |
|
|
|
log_coef = math.log(abs(bi_coef)) |
|
|
|
j = alpha - k |
|
|
|
bi_coef = special.binom(order, k) |
|
|
|
log_coef = np.log(abs(bi_coef)) |
|
|
|
j = order - k |
|
|
|
|
|
|
|
term_k_part_0 = log_coef + k * math.log(q) + j * math.log(1 - q) + ( |
|
|
|
k * k - k) / (2 * (sigma ** 2)) + special.log_ndtr( |
|
|
|
(z0 - k) / sigma) |
|
|
|
term_k_part_0 = log_coef + k*np.log(sample_rate) + j*np.log(1 - sample_rate) + ( |
|
|
|
k*k - k) / (2*(noise_stddev**2)) + special.log_ndtr( |
|
|
|
(z0 - k) / noise_stddev) |
|
|
|
|
|
|
|
term_k_part_1 = log_coef + j * math.log(q) + k * math.log(1 - q) + ( |
|
|
|
j * j - j) / (2 * (sigma ** 2)) + special.log_ndtr( |
|
|
|
(j - z0) / sigma) |
|
|
|
term_k_part_1 = log_coef + j*np.log(sample_rate) + k*np.log(1 - sample_rate) + ( |
|
|
|
j*j - j) / (2*(noise_stddev**2)) + special.log_ndtr( |
|
|
|
(j - z0) / noise_stddev) |
|
|
|
|
|
|
|
if bi_coef > 0: |
|
|
|
log_part_0 = _log_add(log_part_0, term_k_part_0) |
|
|
@@ -391,10 +397,10 @@ def _compute_rdp_order(q, sigma, alpha): |
|
|
|
log_part_1 = _log_subtract(log_part_1, term_k_part_1) |
|
|
|
|
|
|
|
k += 1 |
|
|
|
if max(term_k_part_0, term_k_part_1) < -30: |
|
|
|
if np.max([term_k_part_0, term_k_part_1]) < -30: |
|
|
|
break |
|
|
|
|
|
|
|
return _log_add(log_part_0, log_part_1) / (alpha - 1) |
|
|
|
return _log_add(log_part_0, log_part_1) / (order - 1) |
|
|
|
|
|
|
|
|
|
|
|
def _log_add(x, y): |
|
|
@@ -405,7 +411,7 @@ def _log_add(x, y): |
|
|
|
return y |
|
|
|
if y == -np.inf: |
|
|
|
return x |
|
|
|
return max(x, y) + math.log1p(math.exp(-abs(x - y))) |
|
|
|
return np.max([x, y]) + np.log1p(np.exp(-abs(x - y))) |
|
|
|
|
|
|
|
|
|
|
|
def _log_subtract(x, y): |
|
|
@@ -418,4 +424,4 @@ def _log_subtract(x, y): |
|
|
|
raise ValueError(msg) |
|
|
|
if y == -np.inf: |
|
|
|
return x |
|
|
|
return math.log1p(math.exp(y - x)) + x |
|
|
|
return np.log1p(np.exp(y - x)) + x |