
add vmap for dp

pull/399/head
zlq2020 3 years ago
commit f5b01fed6c
8 changed files with 122 additions and 34 deletions
  1. examples/privacy/diff_privacy/dp_ada_gaussian_config.py (+2, -1)
  2. examples/privacy/diff_privacy/dp_ada_sgd_graph_config.py (+2, -1)
  3. examples/privacy/diff_privacy/lenet5_config.py (+2, -1)
  4. examples/privacy/diff_privacy/lenet5_dp.py (+1, -0)
  5. examples/privacy/diff_privacy/lenet5_dp_ada_gaussian.py (+1, -0)
  6. examples/privacy/diff_privacy/lenet5_dp_ada_sgd_graph.py (+1, -0)
  7. examples/privacy/diff_privacy/lenet5_dp_optimizer.py (+1, -0)
  8. mindarmour/privacy/diff_privacy/train/model.py (+112, -31)
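
In short, this commit replaces DPModel's per-micro-batch Python loop with a single vectorized call: the example configs gain an auto_batch switch, the example scripts forward it into DPModel, and _TrainOneStepCell grows a vmap-based gradient path alongside the existing for-loop path. The pattern looks roughly like the following minimal sketch (not the repository code; the toy network, shapes, and the use of ops.grad / ops.vmap from a recent MindSpore release are assumptions for illustration only):

# Minimal sketch (not the repository code) of the idea behind this commit:
# map a per-micro-batch gradient function over a stacked batch with vmap
# instead of looping over micro-batches in Python.
import numpy as np
from mindspore import Tensor, nn, ops

net = nn.Dense(4, 3)                      # toy network, illustrative only
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')

def forward(data, label):
    return loss_fn(net(data), label)

# gradients of the loss w.r.t. the trainable parameters
grad_fn = ops.grad(forward, grad_position=None, weights=net.trainable_params())

# shape (micro_batches, per_micro_batch_size, feature_dim)
data = Tensor(np.random.randn(2, 8, 4).astype(np.float32))
label = Tensor(np.random.randint(0, 3, (2, 8)).astype(np.int32))

# one vectorized call instead of a loop; every output keeps a leading
# micro-batch axis, just like _calcu_grad_with_vmap in the diff below
per_micro_grads = ops.vmap(grad_fn, in_axes=(0, 0))(data, label)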

examples/privacy/diff_privacy/dp_ada_gaussian_config.py (+2, -1)

@@ -36,5 +36,6 @@ mnist_cfg = edict({
'initial_noise_multiplier': 0.05, # the initial multiplication coefficient of the noise added to training
# parameters' gradients
'noise_mechanisms': 'AdaGaussian', # the method of adding noise in gradients while training
- 'optimizer': 'Momentum' # the base optimizer used for Differential privacy training
+ 'optimizer': 'Momentum', # the base optimizer used for Differential privacy training
+ 'auto_batch': False # whether to use vmap for better performance.
})

examples/privacy/diff_privacy/dp_ada_sgd_graph_config.py (+2, -1)

@@ -37,5 +37,6 @@ mnist_cfg = edict({
# parameters' gradients
'decay_policy': 'Step',
'noise_mechanisms': 'AdaGaussian', # the method of adding noise in gradients while training
- 'optimizer': 'SGD' # the base optimizer used for Differential privacy training
+ 'optimizer': 'SGD', # the base optimizer used for Differential privacy training
+ 'auto_batch': False # whether to use vmap for better performance.
})

examples/privacy/diff_privacy/lenet5_config.py (+2, -1)

@@ -42,5 +42,6 @@ mnist_cfg = edict({
'target_unclipped_quantile': 0.9, # Target quantile of norm clip.
'fraction_stddev': 0.01, # The stddev of Gaussian normal which used in empirical_fraction.
'optimizer': 'Momentum', # the base optimizer used for Differential privacy training
- 'target_delta': 1e-5 # the target delta budget for DP training
+ 'target_delta': 1e-5, # the target delta budget for DP training
+ 'auto_batch': False # whether to use vmap for better performance.
})
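
All three configs ship the new key with the conservative default False; turning the vmap path on is just a matter of flipping that flag before the script builds DPModel. A hypothetical usage sketch, not part of this diff:

# hypothetical usage sketch: the example configs use easydict, so the new
# key can simply be overridden before DPModel is constructed
from easydict import EasyDict as edict

cfg = edict({'auto_batch': False})   # default shipped in the example configs
cfg.auto_batch = True                # opt in to the vmap-based gradient path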

examples/privacy/diff_privacy/lenet5_dp.py (+1, -0)

@@ -141,6 +141,7 @@ if __name__ == "__main__":
norm_bound=cfg.norm_bound,
noise_mech=noise_mech,
clip_mech=clip_mech,
+ auto_batch=cfg.auto_batch,
network=network,
loss_fn=net_loss,
optimizer=net_opt,


examples/privacy/diff_privacy/lenet5_dp_ada_gaussian.py (+1, -0)

@@ -127,6 +127,7 @@ if __name__ == "__main__":
model = DPModel(micro_batches=cfg.micro_batches,
norm_bound=cfg.norm_bound,
noise_mech=noise_mech,
+ auto_batch=cfg.auto_batch,
network=network,
loss_fn=net_loss,
optimizer=net_opt,


examples/privacy/diff_privacy/lenet5_dp_ada_sgd_graph.py (+1, -0)

@@ -127,6 +127,7 @@ if __name__ == "__main__":
model = DPModel(micro_batches=cfg.micro_batches,
norm_bound=cfg.norm_bound,
noise_mech=noise_mech,
+ auto_batch=cfg.auto_batch,
network=network,
loss_fn=net_loss,
optimizer=net_opt,


examples/privacy/diff_privacy/lenet5_dp_optimizer.py (+1, -0)

@@ -135,6 +135,7 @@ if __name__ == "__main__":
norm_bound=cfg.norm_bound,
noise_mech=None,
clip_mech=clip_mech,
+ auto_batch=cfg.auto_batch,
network=network,
loss_fn=net_loss,
optimizer=net_opt,
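
Each example script forwards the flag unchanged. Put together, the construction reads as below; this is a trimmed fragment of the existing examples with the flag switched on, not new API, and network, net_loss, net_opt, noise_mech and clip_mech are built earlier in those scripts exactly as before:

# trimmed fragment of the example scripts with the new flag enabled
model = DPModel(micro_batches=cfg.micro_batches,
                norm_bound=cfg.norm_bound,
                noise_mech=noise_mech,
                clip_mech=clip_mech,
                auto_batch=True,        # or cfg.auto_batch, driven by the config key
                network=network,
                loss_fn=net_loss,
                optimizer=net_opt)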


mindarmour/privacy/diff_privacy/train/model.py (+112, -31)

@@ -45,6 +45,7 @@ from mindspore.common.parameter import Parameter
from mindspore.nn.wrap.loss_scale import _grad_overflow
from mindspore.nn import Cell
from mindspore import ParameterTuple
+ from mindspore import numpy as mnp

from mindarmour.utils.logger import LogUtil
from mindarmour.utils._check_param import check_value_positive, check_param_type
@@ -81,6 +82,8 @@ class DPModel(Model):
noise. Default: None.
clip_mech (Mechanisms): The object is used to update the adaptive clip.
Default: None.
+ auto_batch (bool): Whether to use vmap for better performance.
+ Default: False.
optimizer (Cell): Optimizer used for differential privacy training, which can be original mindspore
optimizers (for example, Momentum optimizer) or optimizers generated by DPOptimizerClassFactory.
Default: nn.Momentum.
@@ -94,7 +97,7 @@ class DPModel(Model):
"""

def __init__(self, micro_batches=2, norm_bound=1.0, noise_mech=None,
- clip_mech=None, optimizer=nn.Momentum, **kwargs):
+ clip_mech=None, auto_batch=False, optimizer=nn.Momentum, **kwargs):
if micro_batches:
self._micro_batches = check_int_positive('micro_batches',
micro_batches)
@@ -104,6 +107,7 @@ class DPModel(Model):
norm_bound = check_value_positive('norm_bound', norm_bound)
norm_bound = Tensor(norm_bound, mstype.float32)
self._norm_bound = Parameter(norm_bound, 'norm_bound')
+ self._auto_batch = auto_batch
if optimizer is None:
msg = 'Optimizer need to be set, but got None.'
LOGGER.error(TAG, msg)
@@ -222,7 +226,8 @@ class DPModel(Model):
loss_scale,
micro_batches=self._micro_batches,
clip_mech=self._clip_mech,
- noise_mech=self._noise_mech).set_train()
+ noise_mech=self._noise_mech,
+ auto_batch=self._auto_batch).set_train()
return network

def _build_train_network(self):
@@ -301,8 +306,7 @@ class _ClipGradients(nn.Cell):
square_sum = self._add(square_sum, self._reduce_sum(self._square_all(grad)))
cur_norm = self._sqrt(square_sum)

- if cur_norm <= clip_norm:
- return grads
+ cur_norm = mnp.where((cur_norm <= clip_norm), x=clip_norm, y=cur_norm)

new_grads = ()
for grad in grads:
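
The early return in _ClipGradients is dropped, presumably because a tensor-valued branch does not vectorize under vmap; mnp.where keeps the behaviour: when cur_norm is already within the bound it is rewritten to clip_norm, so the later clip_norm / cur_norm scaling leaves the gradients untouched. A small standalone check with illustrative values:

# standalone check (illustrative values only): inside the bound, the where()
# form yields a scaling factor of 1.0, matching the old early return
import mindspore as ms
from mindspore import Tensor
from mindspore import numpy as mnp

clip_norm = Tensor(1.0, ms.float32)
cur_norm = Tensor(0.4, ms.float32)      # gradient norm already <= clip_norm

safe_norm = mnp.where(cur_norm <= clip_norm, clip_norm, cur_norm)
scale = clip_norm / safe_norm           # == 1.0, so gradients stay unchanged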
@@ -570,6 +574,8 @@ class _TrainOneStepCell(Cell):
of noise. Default: None.
clip_mech (Mechanisms): The object is used to update the adaptive clip.
Default: None.
+ auto_batch (bool): Whether to use vmap for better performance.
+ Default: False.

Inputs:
- **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
@@ -581,7 +587,7 @@ class _TrainOneStepCell(Cell):

def __init__(self, network, optimizer, norm_bound=1.0, sens=1.0,
micro_batches=None,
- noise_mech=None, clip_mech=None):
+ noise_mech=None, clip_mech=None, auto_batch=False):
super(_TrainOneStepCell, self).__init__(auto_prefix=False)
self.network = network
self.network.set_grad()
@@ -614,6 +620,9 @@ class _TrainOneStepCell(Cell):
self._clip_by_global_norm = _ClipGradients()
self._noise_mech = noise_mech
self._clip_mech = clip_mech
+ self._auto_batch = auto_batch

+ self._stack = P.Stack()
self._tuple_add = _TupleAdd()
self._add = P.Add()
self._norm = nn.Norm()
@@ -641,21 +650,61 @@ class _TrainOneStepCell(Cell):
"""
construct a compute flow.
"""
- weights = self.weights
record_datas = self._split(data)
record_labels = self._split(label)

+ if self._auto_batch is True:
+ grads, total_loss, beta = self._calcu_grad_with_vmap(record_datas,
+ record_labels)
+ else:
+ grads, total_loss, beta = self._calcu_grad_with_forloop(record_datas,
+ record_labels)

+ loss = self._div(total_loss, self._micro_float)

+ if self._noise_mech is not None:
+ grad_noise_tuple = ()
+ for grad_item in grads:
+ grad_noise = self._noise_mech(grad_item)
+ grad_noise_tuple = grad_noise_tuple + (grad_noise,)
+ grads = self._tuple_add(grads, grad_noise_tuple)
+ grads = self._hyper_map(F.partial(_grad_scale, self._micro_float),
+ grads)
+ # update mech parameters
+ if self._noise_mech_param_updater is not None:
+ multiplier = self._noise_mech_param_updater()
+ loss = F.depend(loss, multiplier)

+ if self.reducer_flag:
+ # apply grad reducer on grads
+ grads = self.grad_reducer(grads)

+ if self._clip_mech is not None:
+ beta = self._div(beta, self._micro_batches)
+ next_norm_bound = self._clip_mech(beta, self._norm_bound)
+ self._norm_bound = self._assign(self._norm_bound, next_norm_bound)
+ loss = F.depend(loss, self._norm_bound)

+ return F.depend(loss, self.optimizer(grads))


+ def _calcu_grad_with_forloop(self, record_datas, record_labels):
+ """
+ Calculate per-sample gradients with a for loop.
+ """
+ weights = self.weights
loss = self.network(record_datas[0], record_labels[0])
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
record_grad = self.grad(self.network, weights)(record_datas[0],
record_labels[0], sens)

- # calcu norm_grad
+ # calculate norm_grad
square_sum = self._zero
for grad in record_grad:
square_sum = self._add(square_sum, self._reduce_sum(self._square_all(grad)))
norm_grad = self._sqrt(square_sum)

- # calcu beta
+ # calculate beta
beta = self._zero
if self._clip_mech is not None:
beta = self._add(beta,
@@ -673,14 +722,14 @@ class _TrainOneStepCell(Cell):
record_labels[i],
sens)

- # calcu norm_grad
+ # calculate norm_grad
square_sum = self._zero
for grad in record_grad:
square_sum = self._add(square_sum,
self._reduce_sum(self._square_all(grad)))
norm_grad = self._sqrt(square_sum)

- # calcu beta
+ # calculate beta
if self._clip_mech is not None:
beta = self._add(beta,
self._cast(self._less(norm_grad, self._norm_bound),
@@ -690,29 +739,61 @@ class _TrainOneStepCell(Cell):
self._norm_bound, norm_grad)
grads = self._tuple_add(grads, record_grad)
total_loss = P.Add()(total_loss, loss)
- loss = self._div(total_loss, self._micro_float)
+ return grads, total_loss, beta

- if self._noise_mech is not None:
- grad_noise_tuple = ()
- for grad_item in grads:
- grad_noise = self._noise_mech(grad_item)
- grad_noise_tuple = grad_noise_tuple + (grad_noise,)
- grads = self._tuple_add(grads, grad_noise_tuple)
- grads = self._hyper_map(F.partial(_grad_scale, self._micro_float),
- grads)
- # update mech parameters
- if self._noise_mech_param_updater is not None:
- multiplier = self._noise_mech_param_updater()
- loss = F.depend(loss, multiplier)

- if self.reducer_flag:
- # apply grad reducer on grads
- grads = self.grad_reducer(grads)
+ def _calcu_batch_grad(self, batch_datas, batch_labels):
+ """
+ Calculate batch gradients.
+ """
+ loss = self.network(batch_datas, batch_labels)
+ sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
+ record_grad = self.grad(self.network, self.weights)(batch_datas, batch_labels, sens)

+ # calculate the norm of the gradient
+ square_sum = self._zero
+ for grad in record_grad:
+ square_sum = self._add(square_sum, self._reduce_sum(self._square_all(grad)))
+ norm_grad = self._sqrt(square_sum)

+ # calculate beta
+ beta = self._zero
if self._clip_mech is not None:
- beta = self._div(beta, self._micro_batches)
- next_norm_bound = self._clip_mech(beta, self._norm_bound)
- self._norm_bound = self._assign(self._norm_bound, next_norm_bound)
- loss = F.depend(loss, self._norm_bound)
+ beta = self._cast(self._less(norm_grad, self._norm_bound), mstype.float32)
+ record_grad = self._clip_by_global_norm(record_grad, self._norm_bound, norm_grad)

- return F.depend(loss, self.optimizer(grads))
+ return record_grad, loss, beta


+ def _calcu_grad_with_vmap(self, record_datas, record_labels):
+ """
+ Calculate per-sample gradients with vmap.
+ """
+ batch_datas = self._stack(record_datas)
+ batch_labels = self._stack(record_labels)

+ # calculate gradients for all micro-batches in parallel with vmap
+ batch_fn = F.vmap(self._calcu_batch_grad, in_axes=(0, 0))
+ batch_grads, batch_loss, batch_beta = batch_fn(batch_datas, batch_labels)

+ grads = _TupleReduceSum(len(batch_grads))(batch_grads)
+ total_loss = self._reduce_sum(batch_loss)
+ beta = self._reduce_sum(batch_beta)

+ return grads, total_loss, beta


+ class _TupleReduceSum(nn.Cell):
+ """
+ Apply ReduceSum over a tuple of tensors.
+ """
+ def __init__(self, tuplen):
+ super(_TupleReduceSum, self).__init__()
+ self.reduce_sum = P.ReduceSum()
+ self.hyper_map = C.HyperMap()
+ self.input2 = (0,) * tuplen

+ def construct(self, input1):
+ """use `HyperMap` to reduce each element of the tuple over axis 0"""
+ out = self.hyper_map(self.reduce_sum, input1, self.input2)
+ return out
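
The new _TupleReduceSum cell collapses the leading micro-batch axis that vmap leaves on every gradient tensor, recovering the same accumulated gradients the for-loop path produces. The same reduction written standalone (shapes are illustrative):

# standalone sketch (illustrative shapes): sum the leading micro-batch axis
# of each vmapped gradient, which is what _TupleReduceSum's HyperMap does
import numpy as np
from mindspore import Tensor, ops

reduce_sum = ops.ReduceSum()
batch_grads = (Tensor(np.ones((4, 5, 3), np.float32)),   # (micro_batches, *weight_shape)
               Tensor(np.ones((4, 3), np.float32)))      # (micro_batches, *bias_shape)

grads = tuple(reduce_sum(g, 0) for g in batch_grads)     # shapes (5, 3) and (3,)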
