|
|
@@ -47,10 +47,15 @@ from mindspore.nn.wrap.loss_scale import _grad_overflow
from mindspore.nn import Cell
from mindspore import ParameterTuple

from mindarmour.utils.logger import LogUtil
from mindarmour.diff_privacy.mechanisms.mechanisms import _MechanismsParamsUpdater
from mindarmour.utils._check_param import check_param_type
from mindarmour.utils._check_param import check_value_positive
from mindarmour.utils._check_param import check_int_positive

LOGGER = LogUtil.get_instance()
TAG = 'DP model'

GRADIENT_CLIP_TYPE = 1
_grad_scale = C.MultitypeFuncGraph("grad_scale")
_reciprocal = P.Reciprocal()
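
These module-level helpers are used further down to rescale gradients by the reciprocal of the loss scale and of the micro-batch count. The registration that normally accompanies them is outside this hunk; a sketch of what it typically looks like (an assumption, not part of this diff):

@_grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
    # multiply each gradient by 1/scale, preserving the gradient's dtype
    return grad * F.cast(_reciprocal(scale), F.dtype(grad))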
|
|
@@ -105,13 +110,19 @@ class DPModel(Model):
        norm_clip = check_param_type('norm_clip', norm_clip, float)
        self._norm_clip = check_value_positive('norm_clip', norm_clip)
        if mech is not None and "DPOptimizer" in kwargs['optimizer'].__class__.__name__:
-           raise ValueError('DPOptimizer is not supported while mech is not None')
+           msg = 'DPOptimizer is not supported while mech is not None'
+           LOGGER.error(TAG, msg)
+           raise ValueError(msg)
        if mech is None:
            if "DPOptimizer" in kwargs['optimizer'].__class__.__name__:
                if context.get_context('mode') != context.PYNATIVE_MODE:
-                   raise ValueError('DPOptimizer just support pynative mode currently.')
+                   msg = 'DPOptimizer just support pynative mode currently.'
+                   LOGGER.error(TAG, msg)
+                   raise ValueError(msg)
            else:
-               raise ValueError('DPModel should set mech or DPOptimizer configure, please refer to example.')
+               msg = 'DPModel should set mech or DPOptimizer configure, please refer to example.'
+               LOGGER.error(TAG, msg)
+               raise ValueError(msg)
        self._mech = mech
        super(DPModel, self).__init__(**kwargs)
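
The three checks above pin down the supported combinations of `mech` and the optimizer. A plain-Python restatement of the same rules (illustrative only; _check_dp_config is a hypothetical helper, not part of MindArmour):

def _check_dp_config(mech, optimizer, pynative_mode):
    """Reject unsupported combinations of noise mechanism and optimizer."""
    uses_dp_optimizer = "DPOptimizer" in optimizer.__class__.__name__
    if mech is not None and uses_dp_optimizer:
        # presumably because noise would otherwise be added twice: by the mechanism and inside the optimizer
        raise ValueError('DPOptimizer is not supported while mech is not None')
    if mech is None:
        if uses_dp_optimizer and not pynative_mode:
            raise ValueError('DPOptimizer just support pynative mode currently.')
        if not uses_dp_optimizer:
            # neither a mechanism nor a DPOptimizer: no noise source at all
            raise ValueError('DPModel should set mech or DPOptimizer configure, please refer to example.')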
|
|
|
|
|
|
@@ -163,10 +174,11 @@ class DPModel(Model):
            if update_cell is not None:
                # only cpu not support `TrainOneStepWithLossScaleCell` for control flow.
                if not context.get_context("enable_ge") and context.get_context("device_target") == "CPU":
-                   raise ValueError("Only `loss_scale_manager=None` and "
-                                    "`loss_scale_manager=FixedLossScaleManager(drop_overflow_update=False)`"
-                                    "are supported in current version. If you use `O2` option, please"
-                                    "use `loss_scale_manager=None` or `FixedLossScaleManager`")
+                   msg = "Only `loss_scale_manager=None` and `loss_scale_manager=FixedLossScaleManager(drop_overflow" \
+                         "_update=False)` are supported in current version. If you use `O2` option, please use " \
+                         "`loss_scale_manager=None` or `FixedLossScaleManager`"
+                   LOGGER.error(TAG, msg)
+                   raise ValueError(msg)
                network = _TrainOneStepWithLossScaleCell(network,
                                                         optimizer,
                                                         scale_update_cell=update_cell,
|
|
@@ -174,6 +186,7 @@ class DPModel(Model):
                                                         norm_clip=self._norm_clip,
                                                         mech=self._mech).set_train()
                return network

        network = _TrainOneStepCell(network,
                                    optimizer,
                                    loss_scale,
|
|
@@ -182,47 +195,48 @@ class DPModel(Model):
                                    mech=self._mech).set_train()
        return network

    def _build_train_network(self):
        """Build train network"""
        network = self._network
        if self._micro_batches:
            if self._optimizer:
                if self._loss_scale_manager_set:
                    network = self._amp_build_train_network(network,
                                                            self._optimizer,
                                                            self._loss_fn,
                                                            level=self._amp_level,
                                                            loss_scale_manager=self._loss_scale_manager,
                                                            keep_batchnorm_fp32=self._keep_bn_fp32)
                else:
                    network = self._amp_build_train_network(network,
                                                            self._optimizer,
                                                            self._loss_fn,
                                                            level=self._amp_level,
                                                            keep_batchnorm_fp32=self._keep_bn_fp32)
            elif self._loss_fn:
                network = nn.WithLossCell(network, self._loss_fn)
        else:
            if self._optimizer:
                if self._loss_scale_manager_set:
                    network = amp.build_train_network(network,
                                                      self._optimizer,
                                                      self._loss_fn,
                                                      level=self._amp_level,
                                                      loss_scale_manager=self._loss_scale_manager,
                                                      keep_batchnorm_fp32=self._keep_bn_fp32)
                else:
                    network = amp.build_train_network(network,
                                                      self._optimizer,
                                                      self._loss_fn,
                                                      level=self._amp_level,
                                                      keep_batchnorm_fp32=self._keep_bn_fp32)
            elif self._loss_fn:
                network = nn.WithLossCell(network, self._loss_fn)

        if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
            network.set_auto_parallel()
        return network
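
Before the gradient-clipping and train-step cells below, it helps to pin down what the `mech` object has to provide. Judging only from how it is used in this file (applied to each gradient tensor through C.HyperMap, and queried for its decay attributes), a minimal hypothetical stand-in would look like this (not a MindArmour class, for illustration only):

class _ZeroNoiseMech(Cell):
    """Toy mechanism that adds no noise; it only shows the expected interface."""

    def __init__(self):
        super(_ZeroNoiseMech, self).__init__()
        # no decay policy, so the train cells below skip _MechanismsParamsUpdater
        self._decay_policy = None
        self._zeros_like = P.ZerosLike()

    def construct(self, grad):
        # called once per gradient tensor via self._hyper_map(self._mech, grads)
        return self._zeros_like(grad)
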
class _ClipGradients(nn.Cell): |
|
|
@@ -358,6 +372,13 @@ class _TrainOneStepWithLossScaleCell(Cell):
        self._hyper_map = C.HyperMap()
        self._micro_float = Tensor(micro_batches, mstype.float32)
+
+       self._mech_param_updater = None
+       if self._mech is not None and self._mech._decay_policy is not None:
+           self._mech_param_updater = _MechanismsParamsUpdater(policy=self._mech._decay_policy,
+                                                               decay_rate=self._mech._noise_decay_rate,
+                                                               cur_params=self._mech._noise_multiplier,
+                                                               init_params=self._mech._initial_noise_multiplier)

    def construct(self, data, label, sens=None):
        """
        construct a compute flow.
|
|
@@ -380,14 +401,14 @@ class _TrainOneStepWithLossScaleCell(Cell):
        record_labels = self._split(label)
        # first index
        loss = self.network(record_datas[0], record_labels[0])
-       scaling_sens_filled = C.ones_like(loss)*F.cast(scaling_sens, F.dtype(loss))
+       scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
        record_grad = self.grad(self.network, weights)(record_datas[0], record_labels[0], scaling_sens_filled)
        record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm)
        grads = record_grad
        total_loss = loss
        for i in range(1, self._micro_batches):
            loss = self.network(record_datas[i], record_labels[i])
-           scaling_sens_filled = C.ones_like(loss)*F.cast(scaling_sens, F.dtype(loss))
+           scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
            record_grad = self.grad(self.network, weights)(record_datas[i], record_labels[i], scaling_sens_filled)
            record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm)
            grads = self._tuple_add(grads, record_grad)
|
|
@@ -398,6 +419,10 @@ class _TrainOneStepWithLossScaleCell(Cell):
            grad_noise = self._hyper_map(self._mech, grads)
            grads = self._tuple_add(grads, grad_noise)
            grads = self._hyper_map(F.partial(_grad_scale, self._micro_float), grads)
+           # update mech parameters
+           if self._mech_param_updater is not None:
+               multiplier = self._mech_param_updater()
+               loss = F.depend(loss, multiplier)

        grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads)
        # apply grad reducer on grads
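
Taken together, the loop above is the per-micro-batch core of DP-SGD: the step batch is split into micro-batches, each micro-batch gradient is clipped to the L2 bound `norm_clip`, the clipped gradients are summed, noise from `mech` is added, and the sum is averaged over the micro-batches before the loss-scale factor is removed. A NumPy sketch of the same arithmetic, treating the model's gradient as one flat array (illustrative only; the real code operates on tuples of tensors inside the graph):

import numpy as np

def dp_grad(per_micro_batch_grads, norm_clip, noise_multiplier, loss_scale=1.0):
    """per_micro_batch_grads: list of gradient arrays, one entry per micro-batch."""
    total = np.zeros_like(per_micro_batch_grads[0])
    for g in per_micro_batch_grads:
        # clip each micro-batch gradient to L2 norm <= norm_clip (the sensitivity bound)
        scale = min(1.0, norm_clip / (np.linalg.norm(g) + 1e-12))
        total += g * scale
    # Gaussian noise calibrated to the clipping bound (assumed mechanism)
    total += np.random.normal(0.0, norm_clip * noise_multiplier, size=total.shape)
    # average over micro-batches, then undo the loss-scale factor
    return total / len(per_micro_batch_grads) / loss_scale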
|
|
@@ -474,6 +499,10 @@ class _TrainOneStepCell(Cell):
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)

        # dp params
+       if micro_batches is None:
+           msg = 'micro_batches must give in differential privacy, but got value: {}'.format(micro_batches)
+           LOGGER.error(TAG, msg)
+           raise ValueError(msg)
        self._micro_batches = micro_batches
        norm_clip = check_param_type('norm_clip', norm_clip, float)
        self._l2_norm = check_value_positive('norm_clip', norm_clip)
|
|
@@ -484,6 +513,13 @@ class _TrainOneStepCell(Cell):
        self._hyper_map = C.HyperMap()
        self._micro_float = Tensor(micro_batches, mstype.float32)
+
+       self._mech_param_updater = None
+       if self._mech is not None and self._mech._decay_policy is not None:
+           self._mech_param_updater = _MechanismsParamsUpdater(policy=self._mech._decay_policy,
+                                                               decay_rate=self._mech._noise_decay_rate,
+                                                               cur_params=self._mech._noise_multiplier,
+                                                               init_params=self._mech._initial_noise_multiplier)

    def construct(self, data, label):
        """
        construct a compute flow.
|
|
@@ -510,6 +546,10 @@ class _TrainOneStepCell(Cell):
            grad_noise = self._hyper_map(self._mech, grads)
            grads = self._tuple_add(grads, grad_noise)
            grads = self._hyper_map(F.partial(_grad_scale, self._micro_float), grads)
+           # update mech parameters
+           if self._mech_param_updater is not None:
+               multiplier = self._mech_param_updater()
+               loss = F.depend(loss, multiplier)

        if self.reducer_flag:
            # apply grad reducer on grads
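
The `_mech_param_updater()` call added above advances the mechanism's noise multiplier once per step according to its decay policy; `F.depend(loss, multiplier)` only forces that update to execute before the loss is returned. The updater's implementation lives in mechanisms.py and is not part of this diff; as a rough illustration of what such a schedule computes (an assumption, not MindArmour's exact formula):

def decayed_noise_multiplier(initial_multiplier, decay_rate, step, policy='Time'):
    """Hypothetical per-step schedule for the noise multiplier."""
    if policy == 'Time':
        # inverse-time decay
        return initial_multiplier / (1.0 + decay_rate * step)
    # 'Step': exponential-style decay
    return initial_multiplier * (1.0 - decay_rate) ** step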
|
|
|