|
|
@@ -45,6 +45,7 @@ from mindspore.common.parameter import Parameter
from mindspore.nn.wrap.loss_scale import _grad_overflow
from mindspore.nn import Cell
from mindspore import ParameterTuple
from mindspore import numpy as mnp

from mindarmour.utils.logger import LogUtil
from mindarmour.utils._check_param import check_value_positive, check_param_type
|
|
@@ -81,6 +82,8 @@ class DPModel(Model):
            noise. Default: None.
        clip_mech (Mechanisms): The object is used to update the adaptive clip.
            Default: None.
        auto_batch (bool): Whether to use vmap for better performance.
            Default: False.
        optimizer (Cell): Optimizer used for differential privacy training, which can be original mindspore
            optimizers (for example, Momentum optimizer) or optimizers generated by DPOptimizerClassFactory.
            Default: nn.Momentum.
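
For reference, a minimal usage sketch of the new flag, assuming the existing NoiseMechanismsFactory/ClipMechanismsFactory APIs; everything except auto_batch=True is illustrative:

import mindspore.nn as nn
from mindarmour.privacy.diff_privacy import (DPModel, NoiseMechanismsFactory,
                                             ClipMechanismsFactory)

network = nn.Dense(32, 10)                      # stand-in network
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)

noise_mech = NoiseMechanismsFactory().create('Gaussian', norm_bound=1.0,
                                             initial_noise_multiplier=1.5)
clip_mech = ClipMechanismsFactory().create('Gaussian', decay_policy='Linear',
                                           learning_rate=0.001,
                                           target_unclipped_quantile=0.9,
                                           fraction_stddev=0.01)
model = DPModel(micro_batches=2,
                norm_bound=1.0,
                noise_mech=noise_mech,
                clip_mech=clip_mech,
                auto_batch=True,       # the flag added here: per-sample grads via vmap
                network=network,
                loss_fn=loss,
                optimizer=opt,
                metrics=None)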
|
|
@@ -94,7 +97,7 @@ class DPModel(Model):
    """

    def __init__(self, micro_batches=2, norm_bound=1.0, noise_mech=None,
                 clip_mech=None, optimizer=nn.Momentum, **kwargs):
                 clip_mech=None, auto_batch=False, optimizer=nn.Momentum, **kwargs):
        if micro_batches:
            self._micro_batches = check_int_positive('micro_batches',
                                                     micro_batches)
|
|
@@ -104,6 +107,7 @@ class DPModel(Model):
        norm_bound = check_value_positive('norm_bound', norm_bound)
        norm_bound = Tensor(norm_bound, mstype.float32)
        self._norm_bound = Parameter(norm_bound, 'norm_bound')
        self._auto_batch = auto_batch
        if optimizer is None:
            msg = 'Optimizer need to be set, but got None.'
            LOGGER.error(TAG, msg)
|
|
@@ -222,7 +226,8 @@ class DPModel(Model):
                                            loss_scale,
                                            micro_batches=self._micro_batches,
                                            clip_mech=self._clip_mech,
                                            noise_mech=self._noise_mech).set_train()
                                            noise_mech=self._noise_mech,
                                            auto_batch=self._auto_batch).set_train()
        return network

    def _build_train_network(self):
|
|
@@ -301,8 +306,7 @@ class _ClipGradients(nn.Cell):
            square_sum = self._add(square_sum, self._reduce_sum(self._square_all(grad)))
        cur_norm = self._sqrt(square_sum)

        if cur_norm <= clip_norm:
            return grads
        cur_norm = mnp.where((cur_norm <= clip_norm), x=clip_norm, y=cur_norm)

        new_grads = ()
        for grad in grads:
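
A brief note on the replacement above: a Python `if` on a Tensor value is data-dependent control flow that graph compilation and vmap cannot trace, while mnp.where gives the same result branch-free. A minimal, self-contained sketch of the idea, assuming the loop below rescales every gradient by clip_norm / cur_norm (the helper name is illustrative):

import mindspore.numpy as mnp

def _clip_by_global_norm_branchfree(grads, clip_norm, cur_norm):
    # When cur_norm <= clip_norm the denominator is forced to clip_norm, so the
    # scale clip_norm / cur_norm is exactly 1 and the gradients pass through
    # unchanged -- the same effect as the removed early return, but traceable.
    cur_norm = mnp.where(cur_norm <= clip_norm, x=clip_norm, y=cur_norm)
    scale = clip_norm / cur_norm
    new_grads = ()
    for grad in grads:
        new_grads = new_grads + (grad * scale,)
    return new_grads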
|
|
@@ -570,6 +574,8 @@ class _TrainOneStepCell(Cell):
            of noise. Default: None.
        clip_mech (Mechanisms): The object is used to update the adaptive clip.
            Default: None.
        auto_batch (bool): Whether to use vmap for better performance.
            Default: False.

    Inputs:
        - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
|
|
@@ -581,7 +587,7 @@ class _TrainOneStepCell(Cell):

    def __init__(self, network, optimizer, norm_bound=1.0, sens=1.0,
                 micro_batches=None,
                 noise_mech=None, clip_mech=None):
                 noise_mech=None, clip_mech=None, auto_batch=False):
        super(_TrainOneStepCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
|
|
@@ -614,6 +620,9 @@ class _TrainOneStepCell(Cell):
        self._clip_by_global_norm = _ClipGradients()
        self._noise_mech = noise_mech
        self._clip_mech = clip_mech
        self._auto_batch = auto_batch

        self._stack = P.Stack()
        self._tuple_add = _TupleAdd()
        self._add = P.Add()
        self._norm = nn.Norm()
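
For context on the new self._stack member: self._split yields a tuple of micro_batches record tensors, and P.Stack() adds the leading axis that F.vmap(..., in_axes=(0, 0)) later maps over in _calcu_grad_with_vmap. A minimal shape-level sketch (the shapes are illustrative):

import numpy as np
from mindspore import Tensor
from mindspore.ops import operations as P

# Two micro-batches as self._split would produce them, two records each.
record_datas = (Tensor(np.ones((2, 4), np.float32)),
                Tensor(np.ones((2, 4), np.float32)))
batch_datas = P.Stack()(record_datas)
# batch_datas.shape == (2, 2, 4): axis 0 is the micro-batch axis mapped by vmap.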
|
|
@@ -641,21 +650,61 @@ class _TrainOneStepCell(Cell):
        """
        construct a compute flow.
        """
        weights = self.weights
        record_datas = self._split(data)
        record_labels = self._split(label)

        if self._auto_batch is True:
            grads, total_loss, beta = self._calcu_grad_with_vmap(record_datas,
                                                                 record_labels)
        else:
            grads, total_loss, beta = self._calcu_grad_with_forloop(record_datas,
                                                                    record_labels)

        loss = self._div(total_loss, self._micro_float)

        if self._noise_mech is not None:
            grad_noise_tuple = ()
            for grad_item in grads:
                grad_noise = self._noise_mech(grad_item)
                grad_noise_tuple = grad_noise_tuple + (grad_noise,)
            grads = self._tuple_add(grads, grad_noise_tuple)
            grads = self._hyper_map(F.partial(_grad_scale, self._micro_float),
                                    grads)
            # update mech parameters
            if self._noise_mech_param_updater is not None:
                multiplier = self._noise_mech_param_updater()
                loss = F.depend(loss, multiplier)

        if self.reducer_flag:
            # apply grad reducer on grads
            grads = self.grad_reducer(grads)

        if self._clip_mech is not None:
            beta = self._div(beta, self._micro_batches)
            next_norm_bound = self._clip_mech(beta, self._norm_bound)
            self._norm_bound = self._assign(self._norm_bound, next_norm_bound)
            loss = F.depend(loss, self._norm_bound)

        return F.depend(loss, self.optimizer(grads))

    def _calcu_grad_with_forloop(self, record_datas, record_labels):
        """
        Calculate per sample gradients with for loop.
        """
        weights = self.weights
        loss = self.network(record_datas[0], record_labels[0])
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        record_grad = self.grad(self.network, weights)(record_datas[0],
                                                       record_labels[0], sens)

        # calcu norm_grad
        # calculate norm_grad
        square_sum = self._zero
        for grad in record_grad:
            square_sum = self._add(square_sum, self._reduce_sum(self._square_all(grad)))
        norm_grad = self._sqrt(square_sum)

        # calcu beta
        # calculate beta
        beta = self._zero
        if self._clip_mech is not None:
            beta = self._add(beta,
|
|
@@ -673,14 +722,14 @@ class _TrainOneStepCell(Cell):
                                                           record_labels[i],
                                                           sens)

            # calcu norm_grad
            # calculate norm_grad
            square_sum = self._zero
            for grad in record_grad:
                square_sum = self._add(square_sum,
                                       self._reduce_sum(self._square_all(grad)))
            norm_grad = self._sqrt(square_sum)

            # calcu beta
            # calculate beta
            if self._clip_mech is not None:
                beta = self._add(beta,
                                 self._cast(self._less(norm_grad, self._norm_bound),
|
|
@@ -690,29 +739,61 @@ class _TrainOneStepCell(Cell):
                                                     self._norm_bound, norm_grad)
            grads = self._tuple_add(grads, record_grad)
            total_loss = P.Add()(total_loss, loss)
        loss = self._div(total_loss, self._micro_float)
        return grads, total_loss, beta

        if self._noise_mech is not None:
            grad_noise_tuple = ()
            for grad_item in grads:
                grad_noise = self._noise_mech(grad_item)
                grad_noise_tuple = grad_noise_tuple + (grad_noise,)
            grads = self._tuple_add(grads, grad_noise_tuple)
            grads = self._hyper_map(F.partial(_grad_scale, self._micro_float),
                                    grads)
            # update mech parameters
            if self._noise_mech_param_updater is not None:
                multiplier = self._noise_mech_param_updater()
                loss = F.depend(loss, multiplier)

        if self.reducer_flag:
            # apply grad reducer on grads
            grads = self.grad_reducer(grads)
    def _calcu_batch_grad(self, batch_datas, batch_labels):
        """
        Calculate batch gradients.
        """
        loss = self.network(batch_datas, batch_labels)
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        record_grad = self.grad(self.network, self.weights)(batch_datas, batch_labels, sens)

        # calculate the norm of the gradient
        square_sum = self._zero
        for grad in record_grad:
            square_sum = self._add(square_sum, self._reduce_sum(self._square_all(grad)))
        norm_grad = self._sqrt(square_sum)

        # calculate beta
        beta = self._zero
        if self._clip_mech is not None:
            beta = self._div(beta, self._micro_batches)
            next_norm_bound = self._clip_mech(beta, self._norm_bound)
            self._norm_bound = self._assign(self._norm_bound, next_norm_bound)
            loss = F.depend(loss, self._norm_bound)
            beta = self._cast(self._less(norm_grad, self._norm_bound), mstype.float32)
        record_grad = self._clip_by_global_norm(record_grad, self._norm_bound, norm_grad)

        return F.depend(loss, self.optimizer(grads))
        return record_grad, loss, beta

    def _calcu_grad_with_vmap(self, record_datas, record_labels):
        """
        Calculate per sample gradients with vmap.
        """
        batch_datas = self._stack(record_datas)
        batch_labels = self._stack(record_labels)

        # calculate gradients in parallel with vmap
        batch_fn = F.vmap(self._calcu_batch_grad, in_axes=(0, 0))
        batch_grads, batch_loss, batch_beta = batch_fn(batch_datas, batch_labels)

        grads = _TupleReduceSum(len(batch_grads))(batch_grads)
        total_loss = self._reduce_sum(batch_loss)
        beta = self._reduce_sum(batch_beta)

        return grads, total_loss, beta


class _TupleReduceSum(nn.Cell):
    """
    Apply ReduceSum to each tensor in a tuple.
    """
    def __init__(self, tuplen):
        super(_TupleReduceSum, self).__init__()
        self.reduce_sum = P.ReduceSum()
        self.hyper_map = C.HyperMap()
        self.input2 = (0,) * tuplen

    def construct(self, input1):
        """use `HyperMap` to apply `ReduceSum` to each tensor in the tuple"""
        out = self.hyper_map(self.reduce_sum, input1, self.input2)
        return out
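
A small, self-contained sketch (not part of the patch) of what the vmap path hands to _TupleReduceSum: after F.vmap(self._calcu_batch_grad, in_axes=(0, 0)) every gradient carries an extra leading axis of length micro_batches, and _TupleReduceSum sums that axis away so each gradient recovers its parameter shape before noise is added and the optimizer runs. Shapes below are illustrative and assume the class above is importable:

import numpy as np
from mindspore import Tensor

batched_grads = (Tensor(np.ones((4, 2, 3), np.float32)),   # 4 micro-batches, param shape (2, 3)
                 Tensor(np.ones((4, 3), np.float32)))      # 4 micro-batches, param shape (3,)
summed = _TupleReduceSum(len(batched_grads))(batched_grads)
# summed[0].shape == (2, 3) and summed[1].shape == (3,);
# every entry equals 4.0, the sum over the micro-batch axis.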