From: @pkuliuliu Reviewed-by: @jxlang910,@zhidanliu Signed-off-by: @jxlang910tags/v1.1.0
@@ -284,6 +284,21 @@ class NoiseAdaGaussianRandom(NoiseGaussianRandom): | |||||
"get {}".format(decay_policy)) | "get {}".format(decay_policy)) | ||||
self._decay_policy = decay_policy | self._decay_policy = decay_policy | ||||
def construct(self, gradients): | |||||
""" | |||||
Generated Adaptive Gaussian noise. | |||||
Args: | |||||
gradients(Tensor): The gradients. | |||||
Returns: | |||||
Tensor, generated noise with shape like given gradients. | |||||
""" | |||||
shape = P.Shape()(gradients) | |||||
stddev = P.Mul()(self._norm_bound, self._noise_multiplier) | |||||
noise = normal(shape, self._mean, stddev, self._seed) | |||||
return noise | |||||
class _MechanismsParamsUpdater(Cell): | class _MechanismsParamsUpdater(Cell): | ||||
""" | """ | ||||
@@ -515,7 +515,7 @@ class _TrainOneStepWithLossScaleCell(Cell): | |||||
if self._noise_mech is not None: | if self._noise_mech is not None: | ||||
grad_noise_tuple = () | grad_noise_tuple = () | ||||
for grad_item in grads: | for grad_item in grads: | ||||
grad_noise = self._mech(grad_item) | |||||
grad_noise = self._noise_mech(grad_item) | |||||
grad_noise_tuple = grad_noise_tuple + (grad_noise,) | grad_noise_tuple = grad_noise_tuple + (grad_noise,) | ||||
grads = self._tuple_add(grads, grad_noise_tuple) | grads = self._tuple_add(grads, grad_noise_tuple) | ||||
grads = self._hyper_map(F.partial(_grad_scale, self._micro_float), | grads = self._hyper_map(F.partial(_grad_scale, self._micro_float), | ||||