
modify the grad clipping operation

tags/v1.9.0
huangjiaqi 3 years ago
commit 2b8bdd4a2e
4 changed files with 53 additions and 41 deletions
  1. docs/api/api_python/mindarmour.reliability.rst (+2 -0)
  2. examples/privacy/diff_privacy/lenet5_config.py (+3 -2)
  3. examples/privacy/diff_privacy/lenet5_dp.py (+1 -0)
  4. mindarmour/privacy/diff_privacy/train/model.py (+47 -39)

docs/api/api_python/mindarmour.reliability.rst (+2 -0)

@@ -108,6 +108,8 @@ Reliability methods of MindArmour.


.. py:method:: get_optimal_threshold(label, ds_eval)

+    Obtain the optimal threshold. Try to find an optimal threshold for detecting OOD samples. The optimal threshold is calculated from the labeled dataset `ds_eval`.

    Parameters:
        - **label** (numpy.ndarray) - Label that marks whether an image is in-distribution or out-of-distribution.
        - **ds_eval** (numpy.ndarray) - Test dataset used to help find the threshold.
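For orientation, a minimal usage sketch of this method. The detector class name and its constructor arguments below are assumptions based on the MindArmour reliability module and are not part of this diff:

    # Hypothetical usage sketch; only get_optimal_threshold(label, ds_eval) comes from the doc above.
    from mindarmour.reliability import OodDetectorFeatureCluster  # class name assumed

    detector = OodDetectorFeatureCluster(model, ds_train, n_cluster=10, layer='output[:Tensor]')
    threshold = detector.get_optimal_threshold(ood_label, ds_eval)  # both arguments are numpy.ndarray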


examples/privacy/diff_privacy/lenet5_config.py (+3 -2)

@@ -33,7 +33,7 @@ mnist_cfg = edict({
    'dataset_sink_mode': False,  # whether deliver all training data to device one time
    'micro_batches': 32,  # the number of small batches split from an original batch
    'norm_bound': 1.0,  # the clip bound of the gradients of model's training parameters
-   'initial_noise_multiplier': 0.05,  # the initial multiplication coefficient of the noise added to training
+   'initial_noise_multiplier': 1.0,  # the initial multiplication coefficient of the noise added to training
                                      # parameters' gradients
    'noise_mechanisms': 'Gaussian',  # the method of adding noise in gradients while training
    'clip_mechanisms': 'Gaussian',  # the method of adaptive clipping gradients while training
@@ -41,5 +41,6 @@ mnist_cfg = edict({
    'clip_learning_rate': 0.001,  # Learning rate of update norm clip.
    'target_unclipped_quantile': 0.9,  # Target quantile of norm clip.
    'fraction_stddev': 0.01,  # The stddev of Gaussian normal which used in empirical_fraction.
-   'optimizer': 'Momentum'  # the base optimizer used for Differential privacy training
+   'optimizer': 'Momentum',  # the base optimizer used for Differential privacy training
+   'target_delta': 1e-5  # the target delta budget for DP training
})
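As background for the value changed above, a minimal sketch assuming the standard DP-SGD convention (not a quote of the MindArmour Gaussian mechanism): the Gaussian noise added to the clipped gradients has standard deviation norm_bound * initial_noise_multiplier, so raising the multiplier from 0.05 to 1.0 adds much stronger noise relative to the clip bound:

    # Illustration only; assumes noise stddev = norm_bound * initial_noise_multiplier.
    import numpy as np

    norm_bound = 1.0                # clip bound kept by this commit
    initial_noise_multiplier = 1.0  # new value (was 0.05)

    noise_stddev = norm_bound * initial_noise_multiplier
    clipped_grad = np.full(8, 0.3)  # stand-in for one step's clipped gradient
    noisy_grad = clipped_grad + np.random.normal(0.0, noise_stddev, clipped_grad.shape)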

examples/privacy/diff_privacy/lenet5_dp.py (+1 -0)

@@ -134,6 +134,7 @@ if __name__ == "__main__":
                                            batch_size=cfg.batch_size,
                                            initial_noise_multiplier=cfg.initial_noise_multiplier,
                                            per_print_times=234,
+                                           target_delta=cfg.target_delta,
                                            noise_decay_mode=None)
    # Create the DP model for training.
    model = DPModel(micro_batches=cfg.micro_batches,


mindarmour/privacy/diff_privacy/train/model.py (+47 -39)

@@ -54,7 +54,6 @@ from ..mechanisms.mechanisms import _MechanismsParamsUpdater
LOGGER = LogUtil.get_instance()
TAG = 'DP model'

-GRADIENT_CLIP_TYPE = 1
_grad_scale = C.MultitypeFuncGraph("grad_scale")
_reciprocal = P.Reciprocal()

@@ -76,8 +75,8 @@ class DPModel(Model):
    Args:
        micro_batches (int): The number of small batches split from an original
            batch. Default: 2.
-       norm_bound (float): Use to clip the bound, if set 1, will return the
-           original data. Default: 1.0.
+       norm_bound (float): The norm bound that is used to clip the gradient of
+           each sample. Default: 1.0.
        noise_mech (Mechanisms): The object can generate the different type of
            noise. Default: None.
        clip_mech (Mechanisms): The object is used to update the adaptive clip.
@@ -275,9 +274,10 @@ class _ClipGradients(nn.Cell):
    Clip gradients.

    Inputs:
-       grads (tuple[Tensor]): Gradients.
-       clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
-       clip_value (float): Specifies how much to clip.
+       grads (tuple[Tensor]): Gradients to clip.
+       clip_norm (float): The l2-norm bound used to clip the gradients.
+       cur_norm (float): The l2-norm of grads. If None, the norm will be
+           calculated in this function. Default: None.

    Outputs:
        tuple[Tensor], clipped gradients.
@@ -285,24 +285,29 @@ class _ClipGradients(nn.Cell):

    def __init__(self):
        super(_ClipGradients, self).__init__()
-       self.clip_by_norm = nn.ClipByNorm()
-       self.dtype = P.DType()
+       self._add = P.Add()
+       self._reduce_sum = P.ReduceSum()
+       self._square_all = P.Square()
+       self._sqrt = P.Sqrt()

-   def construct(self, grads, clip_type, clip_value):
+   def construct(self, grads, clip_norm, cur_norm=None):
        """
        construct a compute flow.
        """
-       if clip_type not in (0, 1):
+       if cur_norm is None:
+           # calculate current l2-norm of grads
+           square_sum = Tensor(0, mstype.float32)
+           for grad in grads:
+               square_sum = self._add(square_sum, self._reduce_sum(self._square_all(grad)))
+           cur_norm = self._sqrt(square_sum)
+
+       if cur_norm <= clip_norm:
            return grads

        new_grads = ()
        for grad in grads:
-           if clip_type == 0:
-               norm = C.clip_by_value(grad, -clip_value, clip_value)
-           else:
-               norm = self.clip_by_norm(grad, clip_value)
-           new_grads = new_grads + (norm,)
+           clipped_grad = grad * (clip_norm / cur_norm)
+           new_grads = new_grads + (clipped_grad,)
        return new_grads
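In plain terms, the reworked cell no longer switches between clip-by-value and clip-by-norm; whenever the global l2-norm of the gradient tuple exceeds clip_norm, every gradient is scaled by clip_norm / cur_norm. A NumPy sketch of the same rule (illustration only, hypothetical helper, not the MindSpore cell):

    # Mirrors the clipping rule in _ClipGradients.construct, written for clarity in NumPy.
    import numpy as np

    def clip_by_global_norm(grads, clip_norm, cur_norm=None):
        if cur_norm is None:
            cur_norm = np.sqrt(sum(np.sum(np.square(g)) for g in grads))  # global l2-norm
        if cur_norm <= clip_norm:
            return grads                           # already within the bound
        scale = clip_norm / cur_norm               # shrink all gradients by one factor
        return tuple(g * scale for g in grads)

    grads = (np.array([3.0, 4.0]),)                # l2-norm = 5.0
    print(clip_by_global_norm(grads, 1.0))         # -> (array([0.6, 0.8]),)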




@@ -339,8 +344,8 @@ class _TrainOneStepWithLossScaleCell(Cell):
            Default: None.
        micro_batches (int): The number of small batches split from an original
            batch. Default: None.
-       norm_bound (Tensor): Use to clip the bound, if set 1, will return the
-           original data. Default: 1.0.
+       norm_bound (Tensor): The norm bound that is used to clip the gradient of
+           each sample. Default: 1.0.
        noise_mech (Mechanisms): The object can generate the different type of
            noise. Default: None.


@@ -466,8 +471,8 @@ class _TrainOneStepWithLossScaleCell(Cell):
            beta = self._add(beta,
                             self._cast(self._less(norm_grad, self._norm_bound),
                                        mstype.float32))
-       record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE,
-                                               self._norm_bound)
+       record_grad = self._clip_by_global_norm(record_grad,
+                                               self._norm_bound, norm_grad)
        grads = record_grad
        total_loss = loss
        for i in range(1, self._micro_batches):
@@ -488,8 +493,7 @@ class _TrainOneStepWithLossScaleCell(Cell):
                                            mstype.float32))

            record_grad = self._clip_by_global_norm(record_grad,
-                                                   GRADIENT_CLIP_TYPE,
-                                                   self._norm_bound)
+                                                   self._norm_bound, norm_grad)
            grads = self._tuple_add(grads, record_grad)
            total_loss = P.Add()(total_loss, loss)
        loss = P.Div()(total_loss, self._micro_float)
@@ -560,8 +564,8 @@ class _TrainOneStepCell(Cell):
            propagation. Default value is 1.0.
        micro_batches (int): The number of small batches split from an original
            batch. Default: None.
-       norm_bound (Tensor): Use to clip the bound, if set 1, will return the
-           original data. Default: 1.0.
+       norm_bound (Tensor): The norm bound that is used to clip the gradient of
+           each sample. Default: 1.0.
        noise_mech (Mechanisms): The object can generate the different type
            of noise. Default: None.
        clip_mech (Mechanisms): The object is used to update the adaptive clip.
@@ -644,20 +648,22 @@ class _TrainOneStepCell(Cell):
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        record_grad = self.grad(self.network, weights)(record_datas[0],
                                                       record_labels[0], sens)
-       beta = self._zero

+       # calcu norm_grad
+       square_sum = self._zero
+       for grad in record_grad:
+           square_sum = self._add(square_sum, self._reduce_sum(self._square_all(grad)))
+       norm_grad = self._sqrt(square_sum)
+
        # calcu beta
+       beta = self._zero
        if self._clip_mech is not None:
-           square_sum = self._zero
-           for grad in record_grad:
-               square_sum = self._add(square_sum,
-                                      self._reduce_sum(self._square_all(grad)))
-           norm_grad = self._sqrt(square_sum)
            beta = self._add(beta,
                             self._cast(self._less(norm_grad, self._norm_bound),
                                        mstype.float32))

-       record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE,
-                                               self._norm_bound)
+       record_grad = self._clip_by_global_norm(record_grad,
+                                               self._norm_bound, norm_grad)
        grads = record_grad
        total_loss = loss
        for i in range(1, self._micro_batches):
@@ -666,20 +672,22 @@ class _TrainOneStepCell(Cell):
            record_grad = self.grad(self.network, weights)(record_datas[i],
                                                           record_labels[i],
                                                           sens)
+           # calcu norm_grad
+           square_sum = self._zero
+           for grad in record_grad:
+               square_sum = self._add(square_sum,
+                                      self._reduce_sum(self._square_all(grad)))
+           norm_grad = self._sqrt(square_sum)
+
            # calcu beta
            if self._clip_mech is not None:
-               square_sum = self._zero
-               for grad in record_grad:
-                   square_sum = self._add(square_sum,
-                                          self._reduce_sum(self._square_all(grad)))
-               norm_grad = self._sqrt(square_sum)
                beta = self._add(beta,
                                 self._cast(self._less(norm_grad, self._norm_bound),
                                            mstype.float32))

            record_grad = self._clip_by_global_norm(record_grad,
-                                                   GRADIENT_CLIP_TYPE,
-                                                   self._norm_bound)
+                                                   self._norm_bound, norm_grad)
            grads = self._tuple_add(grads, record_grad)
            total_loss = P.Add()(total_loss, loss)
        loss = self._div(total_loss, self._micro_float)
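Taken together, the per-step flow after this change is: for each micro-batch, compute its gradients, compute their global l2-norm once, reuse that norm both for the adaptive-clip statistic beta and for clipping, accumulate the clipped gradients, then add Gaussian noise and average. A rough NumPy sketch of that flow with hypothetical names (illustration only, not MindSpore code):

    # Hypothetical illustration of the micro-batch loop described by the diff above.
    import numpy as np

    def dp_step(micro_batch_grads, norm_bound, noise_multiplier):
        accumulated, beta = None, 0.0
        for grads in micro_batch_grads:                        # grads: tuple of arrays for one micro-batch
            norm_grad = np.sqrt(sum(np.sum(np.square(g)) for g in grads))
            beta += float(norm_grad < norm_bound)              # statistic consumed by the adaptive clip mechanism
            scale = 1.0 if norm_grad <= norm_bound else norm_bound / norm_grad
            clipped = tuple(g * scale for g in grads)
            accumulated = clipped if accumulated is None else tuple(a + c for a, c in zip(accumulated, clipped))
        n = len(micro_batch_grads)
        noisy = tuple((a + np.random.normal(0.0, norm_bound * noise_multiplier, a.shape)) / n
                      for a in accumulated)
        return noisy, beta / n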

