@@ -20,7 +20,7 @@ from easydict import EasyDict as edict | |||
mnist_cfg = edict({ | |||
'num_classes': 10, # the number of classes of model's output | |||
'lr': 0.1, # the learning rate of model's optimizer | |||
'lr': 0.01, # the learning rate of model's optimizer | |||
'momentum': 0.9, # the momentum value of model's optimizer | |||
'epoch_size': 10, # training epochs | |||
'batch_size': 256, # batch size for training | |||
@@ -33,8 +33,13 @@ mnist_cfg = edict({ | |||
'dataset_sink_mode': False, # whether deliver all training data to device one time | |||
'micro_batches': 16, # the number of small batches split from an original batch | |||
'norm_clip': 1.0, # the clip bound of the gradients of model's training parameters | |||
'initial_noise_multiplier': 1.5, # the initial multiplication coefficient of the noise added to training | |||
'initial_noise_multiplier': 0.5, # the initial multiplication coefficient of the noise added to training | |||
# parameters' gradients | |||
'mechanisms': 'AdaGaussian', # the method of adding noise in gradients while training | |||
'noise_mechanisms': 'AdaGaussian', # the method of adding noise in gradients while training | |||
'clip_mechanisms': 'Gaussian', # the method of adaptive clipping gradients while training | |||
'clip_decay_policy': 'Linear', # Decay policy of adaptive clipping, decay_policy must be in ['Linear', 'Geometric']. | |||
'clip_learning_rate': 0.001, # Learning rate of update norm clip. | |||
'target_unclipped_quantile': 0.9, # Target quantile of norm clip. | |||
'fraction_stddev': 0.01, # The stddev of Gaussian normal which used in empirical_fraction. | |||
'optimizer': 'Momentum' # the base optimizer used for Differential privacy training | |||
}) |
@@ -31,7 +31,8 @@ import mindspore.common.dtype as mstype | |||
from mindarmour.diff_privacy import DPModel | |||
from mindarmour.diff_privacy import PrivacyMonitorFactory | |||
from mindarmour.diff_privacy import MechanismsFactory | |||
from mindarmour.diff_privacy import NoiseMechanismsFactory | |||
from mindarmour.diff_privacy import ClipMechanismsFactory | |||
from mindarmour.utils.logger import LogUtil | |||
from lenet5_net import LeNet5 | |||
from lenet5_config import mnist_cfg as cfg | |||
@@ -87,11 +88,14 @@ def generate_mnist_dataset(data_path, batch_size=32, repeat_size=1, | |||
if __name__ == "__main__": | |||
# This configure can run both in pynative mode and graph mode | |||
context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target) | |||
context.set_context(mode=context.GRAPH_MODE, | |||
device_target=cfg.device_target) | |||
network = LeNet5() | |||
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") | |||
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps, | |||
keep_checkpoint_max=cfg.keep_checkpoint_max) | |||
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, | |||
reduction="mean") | |||
config_ck = CheckpointConfig( | |||
save_checkpoint_steps=cfg.save_checkpoint_steps, | |||
keep_checkpoint_max=cfg.keep_checkpoint_max) | |||
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", | |||
directory='./trained_ckpt_file/', | |||
config=config_ck) | |||
@@ -102,17 +106,33 @@ if __name__ == "__main__": | |||
cfg.epoch_size) | |||
if cfg.micro_batches and cfg.batch_size % cfg.micro_batches != 0: | |||
raise ValueError("Number of micro_batches should divide evenly batch_size") | |||
# Create a factory class of DP mechanisms, this method is adding noise in gradients while training. | |||
# Initial_noise_multiplier is suggested to be greater than 1.0, otherwise the privacy budget would be huge, which | |||
# means that the privacy protection effect is weak. Mechanisms can be 'Gaussian' or 'AdaGaussian', in which noise | |||
# would be decayed with 'AdaGaussian' mechanism while be constant with 'Gaussian' mechanism. | |||
mech = MechanismsFactory().create(cfg.mechanisms, | |||
norm_bound=cfg.norm_clip, | |||
initial_noise_multiplier=cfg.initial_noise_multiplier) | |||
net_opt = nn.Momentum(params=network.trainable_params(), learning_rate=cfg.lr, momentum=cfg.momentum) | |||
# Create a monitor for DP training. The function of the monitor is to compute and print the privacy budget(eps | |||
# and delta) while training. | |||
raise ValueError( | |||
"Number of micro_batches should divide evenly batch_size") | |||
# Create a factory class of DP noise mechanisms, this method is adding noise | |||
# in gradients while training. Initial_noise_multiplier is suggested to be | |||
# greater than 1.0, otherwise the privacy budget would be huge, which means | |||
# that the privacy protection effect is weak. Mechanisms can be 'Gaussian' | |||
# or 'AdaGaussian', in which noise would be decayed with 'AdaGaussian' | |||
# mechanism while be constant with 'Gaussian' mechanism. | |||
noise_mech = NoiseMechanismsFactory().create(cfg.noise_mechanisms, | |||
norm_bound=cfg.norm_clip, | |||
initial_noise_multiplier=cfg.initial_noise_multiplier) | |||
# Create a factory class of clip mechanisms, this method is to adaptive clip | |||
# gradients while training, decay_policy support 'Linear' and 'Geometric', | |||
# learning_rate is the learning rate to update clip_norm, | |||
# target_unclipped_quantile is the target quantile of norm clip, | |||
# fraction_stddev is the stddev of Gaussian normal which used in | |||
# empirical_fraction, the formula is | |||
# $empirical_fraction + N(0, fraction_stddev)$. | |||
clip_mech = ClipMechanismsFactory().create(cfg.clip_mechanisms, | |||
decay_policy=cfg.clip_decay_policy, | |||
learning_rate=cfg.clip_learning_rate, | |||
target_unclipped_quantile=cfg.target_unclipped_quantile, | |||
fraction_stddev=cfg.fraction_stddev) | |||
net_opt = nn.Momentum(params=network.trainable_params(), | |||
learning_rate=cfg.lr, momentum=cfg.momentum) | |||
# Create a monitor for DP training. The function of the monitor is to | |||
# compute and print the privacy budget(eps and delta) while training. | |||
rdp_monitor = PrivacyMonitorFactory.create('rdp', | |||
num_samples=60000, | |||
batch_size=cfg.batch_size, | |||
@@ -121,20 +141,23 @@ if __name__ == "__main__": | |||
# Create the DP model for training. | |||
model = DPModel(micro_batches=cfg.micro_batches, | |||
norm_clip=cfg.norm_clip, | |||
mech=mech, | |||
noise_mech=noise_mech, | |||
clip_mech=clip_mech, | |||
network=network, | |||
loss_fn=net_loss, | |||
optimizer=net_opt, | |||
metrics={"Accuracy": Accuracy()}) | |||
LOGGER.info(TAG, "============== Starting Training ==============") | |||
model.train(cfg['epoch_size'], ds_train, callbacks=[ckpoint_cb, LossMonitor(), rdp_monitor], | |||
model.train(cfg['epoch_size'], ds_train, | |||
callbacks=[ckpoint_cb, LossMonitor(), rdp_monitor], | |||
dataset_sink_mode=cfg.dataset_sink_mode) | |||
LOGGER.info(TAG, "============== Starting Testing ==============") | |||
ckpt_file_name = 'trained_ckpt_file/checkpoint_lenet-10_234.ckpt' | |||
param_dict = load_checkpoint(ckpt_file_name) | |||
load_param_into_net(network, param_dict) | |||
ds_eval = generate_mnist_dataset(os.path.join(cfg.data_path, 'test'), batch_size=cfg.batch_size) | |||
ds_eval = generate_mnist_dataset(os.path.join(cfg.data_path, 'test'), | |||
batch_size=cfg.batch_size) | |||
acc = model.eval(ds_eval, dataset_sink_mode=False) | |||
LOGGER.info(TAG, "============== Accuracy: %s ==============", acc) |
@@ -1,16 +1,20 @@ | |||
""" | |||
This module provide Differential Privacy feature to protect user privacy. | |||
""" | |||
from .mechanisms.mechanisms import GaussianRandom | |||
from .mechanisms.mechanisms import NoiseGaussianRandom | |||
from .mechanisms.mechanisms import AdaGaussianRandom | |||
from .mechanisms.mechanisms import MechanismsFactory | |||
from .mechanisms.mechanisms import AdaClippingWithGaussianRandom | |||
from .mechanisms.mechanisms import NoiseMechanismsFactory | |||
from .mechanisms.mechanisms import ClipMechanismsFactory | |||
from .monitor.monitor import PrivacyMonitorFactory | |||
from .optimizer.optimizer import DPOptimizerClassFactory | |||
from .train.model import DPModel | |||
__all__ = ['GaussianRandom', | |||
__all__ = ['NoiseGaussianRandom', | |||
'AdaGaussianRandom', | |||
'MechanismsFactory', | |||
'AdaClippingWithGaussianRandom', | |||
'NoiseMechanismsFactory', | |||
'ClipMechanismsFactory', | |||
'PrivacyMonitorFactory', | |||
'DPOptimizerClassFactory', | |||
'DPModel'] |
@@ -28,11 +28,54 @@ from mindarmour.utils._check_param import check_param_in_range | |||
from mindarmour.utils.logger import LogUtil | |||
LOGGER = LogUtil.get_instance() | |||
TAG = 'Defense' | |||
TAG = 'NoiseMechanism' | |||
class MechanismsFactory: | |||
""" Factory class of mechanisms""" | |||
class ClipMechanismsFactory: | |||
""" Factory class of clip mechanisms""" | |||
def __init__(self): | |||
pass | |||
@staticmethod | |||
def create(mech_name, *args, **kwargs): | |||
""" | |||
Args: | |||
mech_name(str): Clip noise generated strategy, support 'Gaussian' now. | |||
args(Union[float, str]): Parameters used for creating clip mechanisms. | |||
kwargs(Union[float, str]): Parameters used for creating clip | |||
mechanisms. | |||
Raises: | |||
NameError: `mech_name` must be in ['Gaussian']. | |||
Returns: | |||
Mechanisms, class of noise generated Mechanism. | |||
Examples: | |||
>>> decay_policy = 'Linear' | |||
>>> beta = Tensor(0.5, mstype.float32) | |||
>>> norm_clip = Tensor(1.0, mstype.float32) | |||
>>> beta_stddev = 0.1 | |||
>>> learning_rate = 0.1 | |||
>>> target_unclipped_quantile = 0.3 | |||
>>> clip_mechanism = ClipMechanismsFactory() | |||
>>> ada_clip = clip_mechanism.create('Gaussian', | |||
>>> decay_policy=decay_policy, | |||
>>> learning_rate=learning_rate, | |||
>>> target_unclipped_quantile=target_unclipped_quantile, | |||
>>> fraction_stddev=beta_stddev) | |||
>>> next_norm_clip = ada_clip(beta, norm_clip) | |||
""" | |||
if mech_name == 'Gaussian': | |||
return AdaClippingWithGaussianRandom(*args, **kwargs) | |||
raise NameError("The {} is not implement, please choose " | |||
"['Gaussian']".format(mech_name)) | |||
class NoiseMechanismsFactory: | |||
""" Factory class of noise mechanisms""" | |||
def __init__(self): | |||
pass | |||
@@ -56,42 +99,38 @@ class MechanismsFactory: | |||
Mechanisms, class of noise generated Mechanism. | |||
Examples: | |||
>>> class Net(nn.Cell): | |||
>>> def __init__(self): | |||
>>> super(Net, self).__init__() | |||
>>> self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal') | |||
>>> self.bn = nn.BatchNorm2d(64) | |||
>>> self.relu = nn.ReLU() | |||
>>> self.flatten = nn.Flatten() | |||
>>> self.fc = nn.Dense(64*224*224, 12) # padding=0 | |||
>>> | |||
>>> def construct(self, x): | |||
>>> x = self.conv(x) | |||
>>> x = self.bn(x) | |||
>>> x = self.relu(x) | |||
>>> x = self.flatten(x) | |||
>>> out = self.fc(x) | |||
>>> return out | |||
>>> norm_clip = 1.0 | |||
>>> initial_noise_multiplier = 1.5 | |||
>>> net = Net() | |||
>>> initial_noise_multiplier = 0.01 | |||
>>> network = LeNet5() | |||
>>> batch_size = 32 | |||
>>> batches = 128 | |||
>>> epochs = 1 | |||
>>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) | |||
>>> net_opt = Momentum(params=net.trainable_params(), learning_rate=0.01, momentum=0.9) | |||
>>> mech = MechanismsFactory().create('Gaussian', | |||
>>> norm_bound=norm_clip, | |||
>>> initial_noise_multiplier=initial_noise_multiplier) | |||
>>> noise_mech = NoiseMechanismsFactory().create('Gaussian', | |||
>>> norm_bound=norm_clip, | |||
>>> initial_noise_multiplier=initial_noise_multiplier) | |||
>>> clip_mech = ClipMechanismsFactory().create('Gaussian', | |||
>>> decay_policy='Linear', | |||
>>> learning_rate=0.01, | |||
>>> target_unclipped_quantile=0.9, | |||
>>> fraction_stddev=0.01) | |||
>>> net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.1, | |||
>>> momentum=0.9) | |||
>>> model = DPModel(micro_batches=2, | |||
>>> norm_clip=1.0, | |||
>>> mech=mech, | |||
>>> network=net, | |||
>>> clip_mech=clip_mech, | |||
>>> norm_clip=norm_clip, | |||
>>> noise_mech=noise_mech, | |||
>>> network=network, | |||
>>> loss_fn=loss, | |||
>>> optimizer=net_opt, | |||
>>> metrics=None) | |||
>>> dataset = get_dataset() | |||
>>> model.train(2, dataset) | |||
>>> ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches), | |||
>>> ['data', 'label']) | |||
>>> ms_ds.set_dataset_size(batch_size * batches) | |||
>>> model.train(epochs, ms_ds, dataset_sink_mode=False) | |||
""" | |||
if policy == 'Gaussian': | |||
return GaussianRandom(*args, **kwargs) | |||
return NoiseGaussianRandom(*args, **kwargs) | |||
if policy == 'AdaGaussian': | |||
return AdaGaussianRandom(*args, **kwargs) | |||
raise NameError("The {} is not implement, please choose " | |||
@@ -110,7 +149,7 @@ class Mechanisms(Cell): | |||
""" | |||
class GaussianRandom(Mechanisms): | |||
class NoiseGaussianRandom(Mechanisms): | |||
""" | |||
Gaussian noise generated mechanism. | |||
@@ -133,18 +172,21 @@ class GaussianRandom(Mechanisms): | |||
>>> gradients = Tensor([0.2, 0.9], mstype.float32) | |||
>>> norm_bound = 0.5 | |||
>>> initial_noise_multiplier = 1.5 | |||
>>> net = GaussianRandom(norm_bound, initial_noise_multiplier) | |||
>>> net = NoiseGaussianRandom(norm_bound, initial_noise_multiplier) | |||
>>> res = net(gradients) | |||
>>> print(res) | |||
""" | |||
def __init__(self, norm_bound=0.5, initial_noise_multiplier=1.5, seed=0, policy=None): | |||
super(GaussianRandom, self).__init__() | |||
def __init__(self, norm_bound=0.5, initial_noise_multiplier=1.5, seed=0, | |||
policy=None): | |||
super(NoiseGaussianRandom, self).__init__() | |||
self._norm_bound = check_value_positive('norm_bound', norm_bound) | |||
self._norm_bound = Tensor(norm_bound, mstype.float32) | |||
self._initial_noise_multiplier = check_value_positive('initial_noise_multiplier', | |||
initial_noise_multiplier) | |||
self._initial_noise_multiplier = Tensor(initial_noise_multiplier, mstype.float32) | |||
self._initial_noise_multiplier = check_value_positive( | |||
'initial_noise_multiplier', | |||
initial_noise_multiplier) | |||
self._initial_noise_multiplier = Tensor(initial_noise_multiplier, | |||
mstype.float32) | |||
self._mean = Tensor(0, mstype.float32) | |||
self._normal = P.Normal(seed=seed) | |||
self._decay_policy = policy | |||
@@ -201,17 +243,20 @@ class AdaGaussianRandom(Mechanisms): | |||
noise_decay_rate=6e-4, decay_policy='Time', seed=0): | |||
super(AdaGaussianRandom, self).__init__() | |||
norm_bound = check_value_positive('norm_bound', norm_bound) | |||
initial_noise_multiplier = check_value_positive('initial_noise_multiplier', | |||
initial_noise_multiplier) | |||
initial_noise_multiplier = check_value_positive( | |||
'initial_noise_multiplier', | |||
initial_noise_multiplier) | |||
self._norm_bound = Tensor(norm_bound, mstype.float32) | |||
initial_noise_multiplier = Tensor(initial_noise_multiplier, mstype.float32) | |||
initial_noise_multiplier = Tensor(initial_noise_multiplier, | |||
mstype.float32) | |||
self._initial_noise_multiplier = Parameter(initial_noise_multiplier, | |||
name='initial_noise_multiplier') | |||
self._noise_multiplier = Parameter(initial_noise_multiplier, | |||
name='noise_multiplier') | |||
self._mean = Tensor(0, mstype.float32) | |||
noise_decay_rate = check_param_type('noise_decay_rate', noise_decay_rate, float) | |||
noise_decay_rate = check_param_type('noise_decay_rate', | |||
noise_decay_rate, float) | |||
check_param_in_range('noise_decay_rate', noise_decay_rate, 0.0, 1.0) | |||
self._noise_decay_rate = Tensor(noise_decay_rate, mstype.float32) | |||
if decay_policy not in ['Time', 'Step', 'Exp']: | |||
@@ -232,7 +277,9 @@ class AdaGaussianRandom(Mechanisms): | |||
Tensor, generated noise with shape like given gradients. | |||
""" | |||
shape = P.Shape()(gradients) | |||
noise = self._normal(shape, self._mean, self._mul(self._noise_multiplier, self._norm_bound)) | |||
noise = self._normal(shape, self._mean, | |||
self._mul(self._noise_multiplier, | |||
self._norm_bound)) | |||
return noise | |||
@@ -241,10 +288,14 @@ class _MechanismsParamsUpdater(Cell): | |||
Update mechanisms parameters, the parameters will refresh in train period. | |||
Args: | |||
policy(str): Pass in by the mechanisms class, mechanisms parameters update policy. | |||
decay_rate(Tensor): Pass in by the mechanisms class, hyper parameter for controlling the decay size. | |||
cur_noise_multiplier(Parameter): Pass in by the mechanisms class, current params value in this time. | |||
init_noise_multiplier(Parameter):Pass in by the mechanisms class, initial params value to be updated. | |||
policy(str): Pass in by the mechanisms class, mechanisms parameters | |||
update policy. | |||
decay_rate(Tensor): Pass in by the mechanisms class, hyper parameter for | |||
controlling the decay size. | |||
cur_noise_multiplier(Parameter): Pass in by the mechanisms class, | |||
current params value in this time. | |||
init_noise_multiplier(Parameter):Pass in by the mechanisms class, | |||
initial params value to be updated. | |||
Returns: | |||
Tuple, next params value. | |||
@@ -281,5 +332,100 @@ class _MechanismsParamsUpdater(Cell): | |||
next_noise_multiplier = self._assign(self._cur_noise_multiplier, | |||
self._mul(temp, self._cur_noise_multiplier)) | |||
else: | |||
next_noise_multiplier = self._assign(self._cur_noise_multiplier, self._div(self._one, self._exp(self._one))) | |||
next_noise_multiplier = self._assign(self._cur_noise_multiplier, | |||
self._div(self._one, self._exp(self._one))) | |||
return next_noise_multiplier | |||
class AdaClippingWithGaussianRandom(Cell): | |||
""" | |||
Adaptive clipping. If `decay_policy` is 'Linear', the update formula is | |||
$ norm_clip = norm_clip - learning_rate*(beta-target_unclipped_quantile)$. | |||
`decay_policy` is 'Geometric', the update formula is | |||
$ norm_clip = norm_clip*exp(-learning_rate*(empirical_fraction-target_unclipped_quantile))$. | |||
where beta is the empirical fraction of samples with the value at most | |||
`target_unclipped_quantile`. | |||
Args: | |||
decay_policy(str): Decay policy of adaptive clipping, decay_policy must | |||
be in ['Linear', 'Geometric']. Default: Linear. | |||
learning_rate(float): Learning rate of update norm clip. Default: 0.01. | |||
target_unclipped_quantile(float): Target quantile of norm clip. Default: 0.9. | |||
fraction_stddev(float): The stddev of Gaussian normal which used in | |||
empirical_fraction, the formula is $empirical_fraction + N(0, fraction_stddev)$. | |||
seed(int): Original random seed, if seed=0 random normal will use secure | |||
random number. IF seed!=0 random normal will generate values using | |||
given seed. Default: 0. | |||
Returns: | |||
Tensor, undated norm clip . | |||
Examples: | |||
>>> decay_policy = 'Linear' | |||
>>> beta = Tensor(0.5, mstype.float32) | |||
>>> norm_clip = Tensor(1.0, mstype.float32) | |||
>>> beta_stddev = 0.01 | |||
>>> learning_rate = 0.001 | |||
>>> target_unclipped_quantile = 0.9 | |||
>>> ada_clip = AdaClippingWithGaussianRandom(decay_policy=decay_policy, | |||
>>> learning_rate=learning_rate, | |||
>>> target_unclipped_quantile=target_unclipped_quantile, | |||
>>> fraction_stddev=beta_stddev) | |||
>>> next_norm_clip = ada_clip(beta, norm_clip) | |||
""" | |||
def __init__(self, decay_policy='Linear', learning_rate=0.001, | |||
target_unclipped_quantile=0.9, fraction_stddev=0.01, seed=0): | |||
super(AdaClippingWithGaussianRandom, self).__init__() | |||
if decay_policy not in ['Linear', 'Geometric']: | |||
msg = "decay policy of adaptive clip must be in ['Linear', 'Geometric'], \ | |||
but got: {}".format(decay_policy) | |||
LOGGER.error(TAG, msg) | |||
raise ValueError(msg) | |||
self._decay_policy = decay_policy | |||
learning_rate = check_param_type('learning_rate', learning_rate, float) | |||
learning_rate = check_value_positive('learning_rate', learning_rate) | |||
self._learning_rate = Tensor(learning_rate, mstype.float32) | |||
fraction_stddev = check_param_type('fraction_stddev', fraction_stddev, float) | |||
self._fraction_stddev = Tensor(fraction_stddev, mstype.float32) | |||
target_unclipped_quantile = check_param_type('target_unclipped_quantile', | |||
target_unclipped_quantile, | |||
float) | |||
self._target_unclipped_quantile = Tensor(target_unclipped_quantile, | |||
mstype.float32) | |||
self._zero = Tensor(0, mstype.float32) | |||
self._add = P.TensorAdd() | |||
self._sub = P.Sub() | |||
self._mul = P.Mul() | |||
self._exp = P.Exp() | |||
self._normal = P.Normal(seed=seed) | |||
def construct(self, empirical_fraction, norm_clip): | |||
""" | |||
Update value of norm_clip. | |||
Args: | |||
empirical_fraction(Tensor): empirical fraction of samples with the | |||
value at most `target_unclipped_quantile`. | |||
norm_clip(Tensor): Clipping bound for the l2 norm of the gradients. | |||
Returns: | |||
Tensor, generated noise with shape like given gradients. | |||
""" | |||
fraction_noise = self._normal((1,), self._zero, self._fraction_stddev) | |||
empirical_fraction = self._add(empirical_fraction, fraction_noise) | |||
if self._decay_policy == 'Linear': | |||
grad_clip = self._sub(empirical_fraction, | |||
self._target_unclipped_quantile) | |||
next_norm_clip = self._sub(norm_clip, | |||
self._mul(self._learning_rate, grad_clip)) | |||
# decay_policy == 'Geometric' | |||
else: | |||
grad_clip = self._sub(empirical_fraction, | |||
self._target_unclipped_quantile) | |||
grad_clip = self._exp(self._mul(-self._learning_rate, grad_clip)) | |||
next_norm_clip = self._mul(norm_clip, grad_clip) | |||
return next_norm_clip |
@@ -22,7 +22,7 @@ from mindspore.ops import functional as F | |||
from mindspore.common import dtype as mstype | |||
from mindarmour.utils.logger import LogUtil | |||
from mindarmour.diff_privacy import MechanismsFactory | |||
from mindarmour.diff_privacy import NoiseMechanismsFactory | |||
from mindarmour.diff_privacy.mechanisms.mechanisms import _MechanismsParamsUpdater | |||
from mindarmour.utils._check_param import check_int_positive | |||
@@ -70,7 +70,7 @@ class DPOptimizerClassFactory: | |||
""" | |||
def __init__(self, micro_batches=2): | |||
self._mech_factory = MechanismsFactory() | |||
self._mech_factory = NoiseMechanismsFactory() | |||
self.mech = None | |||
self._micro_batches = check_int_positive('micro_batches', micro_batches) | |||
@@ -48,7 +48,8 @@ from mindspore.nn import Cell | |||
from mindspore import ParameterTuple | |||
from mindarmour.utils.logger import LogUtil | |||
from mindarmour.diff_privacy.mechanisms.mechanisms import _MechanismsParamsUpdater | |||
from mindarmour.diff_privacy.mechanisms.mechanisms import \ | |||
_MechanismsParamsUpdater | |||
from mindarmour.utils._check_param import check_param_type | |||
from mindarmour.utils._check_param import check_value_positive | |||
from mindarmour.utils._check_param import check_int_positive | |||
@@ -64,7 +65,7 @@ _reciprocal = P.Reciprocal() | |||
@_grad_scale.register("Tensor", "Tensor") | |||
def tensor_grad_scale(scale, grad): | |||
""" grad scaling """ | |||
return grad * F.cast(_reciprocal(scale), F.dtype(grad)) | |||
return grad*F.cast(_reciprocal(scale), F.dtype(grad)) | |||
class DPModel(Model): | |||
@@ -72,9 +73,14 @@ class DPModel(Model): | |||
This class is overload mindspore.train.model.Model. | |||
Args: | |||
micro_batches (int): The number of small batches split from an original batch. Default: 2. | |||
norm_clip (float): Use to clip the bound, if set 1, will retun the original data. Default: 1.0. | |||
mech (Mechanisms): The object can generate the different type of noise. Default: None. | |||
micro_batches (int): The number of small batches split from an original | |||
batch. Default: 2. | |||
norm_clip (float): Use to clip the bound, if set 1, will retun the | |||
original data. Default: 1.0. | |||
noise_mech (Mechanisms): The object can generate the different type of | |||
noise. Default: None. | |||
clip_mech (Mechanisms): The object is used to update the adaptive clip . | |||
Default: None. | |||
Examples: | |||
>>> norm_clip = 1.0 | |||
@@ -89,63 +95,82 @@ class DPModel(Model): | |||
>>> factory_opt.set_mechanisms('Gaussian', | |||
>>> norm_bound=norm_clip, | |||
>>> initial_noise_multiplier=initial_noise_multiplier) | |||
>>> net_opt = factory_opt.create('Momentum')(network.trainable_params(), learning_rate=0.1, momentum=0.9) | |||
>>> net_opt = factory_opt.create('Momentum')(network.trainable_params(), | |||
>>> learning_rate=0.1, momentum=0.9) | |||
>>> clip_mech = ClipMechanismsFactory().create('Gaussian', | |||
>>> decay_policy='Linear', | |||
>>> learning_rate=0.01, | |||
>>> target_unclipped_quantile=0.9, | |||
>>> fraction_stddev=0.01) | |||
>>> model = DPModel(micro_batches=micro_batches, | |||
>>> norm_clip=norm_clip, | |||
>>> mech=None, | |||
>>> clip_mech=clip_mech, | |||
>>> noise_mech=None, | |||
>>> network=network, | |||
>>> loss_fn=loss, | |||
>>> optimizer=net_opt, | |||
>>> metrics=None) | |||
>>> ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches), ['data', 'label']) | |||
>>> ms_ds.set_dataset_size(batch_size * batches) | |||
>>> ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches), | |||
>>> ['data', 'label']) | |||
>>> ms_ds.set_dataset_size(batch_size*batches) | |||
>>> model.train(epochs, ms_ds, dataset_sink_mode=False) | |||
""" | |||
def __init__(self, micro_batches=2, norm_clip=1.0, mech=None, **kwargs): | |||
def __init__(self, micro_batches=2, norm_clip=1.0, noise_mech=None, | |||
clip_mech=None, **kwargs): | |||
if micro_batches: | |||
self._micro_batches = check_int_positive('micro_batches', micro_batches) | |||
self._micro_batches = check_int_positive('micro_batches', | |||
micro_batches) | |||
else: | |||
self._micro_batches = None | |||
norm_clip = check_param_type('norm_clip', norm_clip, float) | |||
self._norm_clip = check_value_positive('norm_clip', norm_clip) | |||
if mech is not None and "DPOptimizer" in kwargs['optimizer'].__class__.__name__: | |||
msg = 'DPOptimizer is not supported while mech is not None' | |||
norm_clip = check_value_positive('norm_clip', norm_clip) | |||
norm_clip = Tensor(norm_clip, mstype.float32) | |||
self._norm_clip = Parameter(norm_clip, 'norm_clip') | |||
if noise_mech is not None and "DPOptimizer" in kwargs['optimizer'].__class__.__name__: | |||
msg = 'DPOptimizer is not supported while noise_mech is not None' | |||
LOGGER.error(TAG, msg) | |||
raise ValueError(msg) | |||
if mech is None: | |||
if noise_mech is None: | |||
if "DPOptimizer" in kwargs['optimizer'].__class__.__name__: | |||
if context.get_context('mode') != context.PYNATIVE_MODE: | |||
msg = 'DPOptimizer just support pynative mode currently.' | |||
LOGGER.error(TAG, msg) | |||
raise ValueError(msg) | |||
else: | |||
msg = 'DPModel should set mech or DPOptimizer configure, please refer to example.' | |||
msg = 'DPModel should set noise_mech or DPOptimizer configure, ' \ | |||
'please refer to example.' | |||
LOGGER.error(TAG, msg) | |||
raise ValueError(msg) | |||
self._mech = mech | |||
self._noise_mech = noise_mech | |||
if clip_mech is not None: | |||
self._clip_mech = clip_mech | |||
super(DPModel, self).__init__(**kwargs) | |||
def _amp_build_train_network(self, network, optimizer, loss_fn=None, level='O0', **kwargs): | |||
def _amp_build_train_network(self, network, optimizer, loss_fn=None, | |||
level='O0', **kwargs): | |||
""" | |||
Build the mixed precision training cell automatically. | |||
Args: | |||
network (Cell): Definition of the network. | |||
loss_fn (Union[None, Cell]): Definition of the loss_fn. If None, the `network` should have the loss inside. | |||
Default: None. | |||
loss_fn (Union[None, Cell]): Definition of the loss_fn. If None, | |||
the `network` should have the loss inside. Default: None. | |||
optimizer (Optimizer): Optimizer to update the Parameter. | |||
level (str): Supports [O0, O2]. Default: "O0". | |||
- O0: Do not change. | |||
- O2: Cast network to float16, keep batchnorm and `loss_fn` (if set) run in float32, | |||
using dynamic loss scale. | |||
cast_model_type (:class:`mindspore.dtype`): Supports `mstype.float16` or `mstype.float32`. | |||
If set to `mstype.float16`, use `float16` mode to train. If set, overwrite the level setting. | |||
keep_batchnorm_fp32 (bool): Keep Batchnorm run in `float32`. If set, overwrite the level setting. | |||
loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else | |||
scale the loss by LossScaleManager. If set, overwrite the level setting. | |||
- O2: Cast network to float16, keep batchnorm and `loss_fn` | |||
(if set) run in float32, using dynamic loss scale. | |||
cast_model_type (:class:`mindspore.dtype`): Supports `mstype.float16` | |||
or `mstype.float32`. If set to `mstype.float16`, use `float16` | |||
mode to train. If set, overwrite the level setting. | |||
keep_batchnorm_fp32 (bool): Keep Batchnorm run in `float32`. If set, | |||
overwrite the level setting. | |||
loss_scale_manager (Union[None, LossScaleManager]): If None, not | |||
scale the loss, or else scale the loss by LossScaleManager. | |||
If set, overwrite the level setting. | |||
""" | |||
validator.check_value_type('network', network, nn.Cell, None) | |||
validator.check_value_type('optimizer', optimizer, nn.Optimizer, None) | |||
@@ -161,9 +186,11 @@ class DPModel(Model): | |||
_do_keep_batchnorm_fp32(network) | |||
if loss_fn: | |||
network = _add_loss_network(network, loss_fn, config.cast_model_type) | |||
network = _add_loss_network(network, loss_fn, | |||
config.cast_model_type) | |||
if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): | |||
if _get_parallel_mode() in ( | |||
ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): | |||
network = _VirtualDatasetCell(network) | |||
loss_scale = 1.0 | |||
@@ -173,9 +200,12 @@ class DPModel(Model): | |||
update_cell = loss_scale_manager.get_update_cell() | |||
if update_cell is not None: | |||
# only cpu not support `TrainOneStepWithLossScaleCell` for control flow. | |||
if not context.get_context("enable_ge") and context.get_context("device_target") == "CPU": | |||
msg = "Only `loss_scale_manager=None` and `loss_scale_manager=FixedLossScaleManager(drop_overflow" \ | |||
"_update=False)` are supported in current version. If you use `O2` option, please use " \ | |||
if not context.get_context("enable_ge") and context.get_context( | |||
"device_target") == "CPU": | |||
msg = "Only `loss_scale_manager=None` and " \ | |||
"`loss_scale_manager=FixedLossScaleManager(drop_overflow" \ | |||
"_update=False)` are supported in current version. " \ | |||
"If you use `O2` option, please use " \ | |||
"`loss_scale_manager=None` or `FixedLossScaleManager`" | |||
LOGGER.error(TAG, msg) | |||
raise ValueError(msg) | |||
@@ -184,15 +214,17 @@ class DPModel(Model): | |||
scale_update_cell=update_cell, | |||
micro_batches=self._micro_batches, | |||
norm_clip=self._norm_clip, | |||
mech=self._mech).set_train() | |||
clip_mech=self._clip_mech, | |||
noise_mech=self._noise_mech).set_train() | |||
return network | |||
network = _TrainOneStepCell(network, | |||
optimizer, | |||
self._norm_clip, | |||
loss_scale, | |||
micro_batches=self._micro_batches, | |||
norm_clip=self._norm_clip, | |||
mech=self._mech).set_train() | |||
clip_mech=self._clip_mech, | |||
noise_mech=self._noise_mech).set_train() | |||
return network | |||
def _build_train_network(self): | |||
@@ -233,7 +265,8 @@ class DPModel(Model): | |||
elif self._loss_fn: | |||
network = nn.WithLossCell(network, self._loss_fn) | |||
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): | |||
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, | |||
ParallelMode.AUTO_PARALLEL): | |||
network.set_auto_parallel() | |||
return network | |||
@@ -267,11 +300,10 @@ class _ClipGradients(nn.Cell): | |||
new_grads = () | |||
for grad in grads: | |||
if clip_type == 0: | |||
t = C.clip_by_value(grad, F.tuple_to_array((-clip_value,)), | |||
F.tuple_to_array((clip_value,))) | |||
norm = C.clip_by_value(grad, -clip_value, clip_value) | |||
else: | |||
t = self.clip_by_norm(grad, F.tuple_to_array((clip_value,))) | |||
new_grads = new_grads + (t,) | |||
norm = self.clip_by_norm(grad, clip_value) | |||
new_grads = new_grads + (norm,) | |||
return new_grads | |||
@@ -292,20 +324,27 @@ class _TrainOneStepWithLossScaleCell(Cell): | |||
r""" | |||
Network training with loss scaling. | |||
This is a training step with loss scaling. It takes a network, an optimizer and possibly a scale update | |||
Cell as args. The loss scale value can be updated in both host side or device side. The | |||
TrainOneStepWithLossScaleCell will be compiled to be graph which takes `data`, `label`, `sens` as input | |||
data. The `sens` is acting as loss scaling value. If you want to update it on host side, the value should | |||
be provided. If `sens` is not given, the loss scale update logic should be provied by `scale_update_cell`. | |||
If `scale_update_cell` is not None and `sens` is provided, the `scale_update_cell` will be ignored. | |||
This is a training step with loss scaling. It takes a network, an optimizer | |||
and possibly a scale update Cell as args. The loss scale value can be | |||
updated in both host side or device side. The TrainOneStepWithLossScaleCell | |||
will be compiled to be graph which takes `data`, `label`, `sens` as input | |||
data. The `sens` is acting as loss scaling value. If you want to update it | |||
on host side, the value should be provided. If `sens` is not given, the loss | |||
scale update logic should be provied by `scale_update_cell`. If | |||
`scale_update_cell` is not None and `sens` is provided, the | |||
`scale_update_cell` will be ignored. | |||
Args: | |||
network (Cell): The training network. | |||
optimizer (Cell): Optimizer for updating the weights. | |||
scale_update_cell(Cell): The loss scaling update logic cell. Default: None. | |||
micro_batches (int): The number of small batches split from an original batch. Default: None. | |||
norm_clip (float): Use to clip the bound, if set 1, will return the original data. Default: 1.0. | |||
mech (Mechanisms): The object can generate the different type of noise. Default: None. | |||
scale_update_cell(Cell): The loss scaling update logic cell. | |||
Default: None. | |||
micro_batches (int): The number of small batches split from an original | |||
batch. Default: None. | |||
norm_clip (Tensor): Use to clip the bound, if set 1, will return the | |||
original data. Default: 1.0. | |||
noise_mech (Mechanisms): The object can generate the different type of | |||
noise. Default: None. | |||
Inputs: | |||
- **inputs** (Tensor) - Tensor of shape :math:`(N, \ldots)`. | |||
@@ -320,7 +359,9 @@ class _TrainOneStepWithLossScaleCell(Cell): | |||
- **loss_scale** (Tensor) - Tensor with shape :math:`()`. | |||
""" | |||
def __init__(self, network, optimizer, scale_update_cell=None, micro_batches=None, norm_clip=1.0, mech=None): | |||
def __init__(self, network, optimizer, scale_update_cell=None, | |||
micro_batches=None, norm_clip=1.0, noise_mech=None, | |||
clip_mech=None): | |||
super(_TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False) | |||
self.network = network | |||
self.network.set_grad() | |||
@@ -346,39 +387,54 @@ class _TrainOneStepWithLossScaleCell(Cell): | |||
self.allreduce = P.AllReduce() | |||
self.parallel_mode = _get_parallel_mode() | |||
self.grad_reducer = F.identity | |||
self.reducer_flag = self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL] | |||
self.reducer_flag = self.parallel_mode in [ParallelMode.DATA_PARALLEL, | |||
ParallelMode.HYBRID_PARALLEL] | |||
if self.reducer_flag: | |||
mean = _get_mirror_mean() | |||
degree = _get_device_num() | |||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) | |||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, | |||
mean, degree) | |||
self.is_distributed = self.parallel_mode != ParallelMode.STAND_ALONE | |||
self.loss_scale = None | |||
self.loss_scaling_manager = scale_update_cell | |||
if scale_update_cell: | |||
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32), | |||
name="loss_scale") | |||
self.loss_scale = Parameter( | |||
Tensor(scale_update_cell.get_loss_scale(), | |||
dtype=mstype.float32), | |||
name="loss_scale") | |||
self.add_flags(has_effect=True) | |||
# dp params | |||
self._micro_batches = micro_batches | |||
norm_clip = check_param_type('norm_clip', norm_clip, float) | |||
self._l2_norm = check_value_positive('norm_clip', norm_clip) | |||
self._norm_clip = norm_clip | |||
self._split = P.Split(0, self._micro_batches) | |||
self._clip_by_global_norm = _ClipGradients() | |||
self._mech = mech | |||
self._noise_mech = noise_mech | |||
self._clip_mech = clip_mech | |||
self._add = P.TensorAdd() | |||
self._norm = nn.Norm() | |||
self._tuple_add = _TupleAdd() | |||
self._hyper_map = C.HyperMap() | |||
self._micro_float = Tensor(micro_batches, mstype.float32) | |||
self._mech_param_updater = None | |||
if self._mech is not None and self._mech._decay_policy is not None: | |||
self._mech_param_updater = _MechanismsParamsUpdater(policy=self._mech._decay_policy, | |||
decay_rate=self._mech._noise_decay_rate, | |||
cur_noise_multiplier= | |||
self._mech._noise_multiplier, | |||
init_noise_multiplier= | |||
self._mech._initial_noise_multiplier) | |||
self._zero = Tensor(0, mstype.float32) | |||
self._assign = P.Assign() | |||
self._div = P.Div() | |||
self._sqrt = P.Sqrt() | |||
self._reduce_sum = P.ReduceSum() | |||
self._square_all = P.Square() | |||
self._less = P.Less() | |||
self._cast = P.Cast() | |||
self._noise_mech_param_updater = None | |||
if self._noise_mech is not None and self._noise_mech._decay_policy is not None: | |||
self._noise_mech_param_updater = _MechanismsParamsUpdater( | |||
policy=self._noise_mech._decay_policy, | |||
decay_rate=self._noise_mech._noise_decay_rate, | |||
cur_noise_multiplier= | |||
self._noise_mech._noise_multiplier, | |||
init_noise_multiplier= | |||
self._noise_mech._initial_noise_multiplier) | |||
def construct(self, data, label, sens=None): | |||
""" | |||
@@ -402,30 +458,62 @@ class _TrainOneStepWithLossScaleCell(Cell): | |||
record_labels = self._split(label) | |||
# first index | |||
loss = self.network(record_datas[0], record_labels[0]) | |||
scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss)) | |||
record_grad = self.grad(self.network, weights)(record_datas[0], record_labels[0], scaling_sens_filled) | |||
record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm) | |||
scaling_sens_filled = C.ones_like(loss)*F.cast(scaling_sens, | |||
F.dtype(loss)) | |||
record_grad = self.grad(self.network, weights)(record_datas[0], | |||
record_labels[0], | |||
scaling_sens_filled) | |||
beta = self._zero | |||
square_sum = self._zero | |||
for grad in record_grad: | |||
square_sum = self._add(square_sum, | |||
self._reduce_sum(self._square_all(grad))) | |||
norm_grad = self._sqrt(square_sum) | |||
beta = self._add(beta, | |||
self._cast(self._less(norm_grad, self._norm_clip), | |||
mstype.float32)) | |||
record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, | |||
self._norm_clip) | |||
grads = record_grad | |||
total_loss = loss | |||
for i in range(1, self._micro_batches): | |||
loss = self.network(record_datas[i], record_labels[i]) | |||
scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss)) | |||
record_grad = self.grad(self.network, weights)(record_datas[i], record_labels[i], scaling_sens_filled) | |||
record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm) | |||
scaling_sens_filled = C.ones_like(loss)*F.cast(scaling_sens, | |||
F.dtype(loss)) | |||
record_grad = self.grad(self.network, weights)(record_datas[i], | |||
record_labels[i], | |||
scaling_sens_filled) | |||
square_sum = self._zero | |||
for grad in record_grad: | |||
square_sum = self._add(square_sum, | |||
self._reduce_sum(self._square_all(grad))) | |||
norm_grad = self._sqrt(square_sum) | |||
beta = self._add(beta, | |||
self._cast(self._less(norm_grad, self._norm_clip), | |||
mstype.float32)) | |||
record_grad = self._clip_by_global_norm(record_grad, | |||
GRADIENT_CLIP_TYPE, | |||
self._norm_clip) | |||
grads = self._tuple_add(grads, record_grad) | |||
total_loss = P.TensorAdd()(total_loss, loss) | |||
loss = P.Div()(total_loss, self._micro_float) | |||
beta = self._div(beta, self._micro_batches) | |||
if self._mech is not None: | |||
if self._noise_mech is not None: | |||
grad_noise_tuple = () | |||
for grad_item in grads: | |||
grad_noise = self._mech(grad_item) | |||
grad_noise_tuple = grad_noise_tuple + (grad_noise,) | |||
grads = self._tuple_add(grads, grad_noise_tuple) | |||
grads = self._hyper_map(F.partial(_grad_scale, self._micro_float), grads) | |||
grads = self._hyper_map(F.partial(_grad_scale, self._micro_float), | |||
grads) | |||
# update mech parameters | |||
if self._mech_param_updater is not None: | |||
multiplier = self._mech_param_updater() | |||
if self._noise_mech_param_updater is not None: | |||
multiplier = self._noise_mech_param_updater() | |||
loss = F.depend(loss, multiplier) | |||
grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads) | |||
@@ -456,6 +544,10 @@ class _TrainOneStepWithLossScaleCell(Cell): | |||
else: | |||
opt = self.optimizer(grads) | |||
ret = (loss, cond, scaling_sens) | |||
if self._clip_mech is not None: | |||
next_norm_clip = self._clip_mech(beta, self._norm_clip) | |||
P.assign(self._norm_clip, next_norm_clip) | |||
return F.depend(ret, opt) | |||
@@ -463,17 +555,22 @@ class _TrainOneStepCell(Cell): | |||
r""" | |||
Network training package class. | |||
Wraps the network with an optimizer. The resulting Cell be trained with input data and label. | |||
Backward graph will be created in the construct function to do parameter updating. Different | |||
parallel modes are available to run the training. | |||
Wraps the network with an optimizer. The resulting Cell be trained with | |||
input data and label. Backward graph will be created in the construct | |||
function to do parameter updating. Different parallel modes are available | |||
to run the training. | |||
Args: | |||
network (Cell): The training network. | |||
optimizer (Cell): Optimizer for updating the weights. | |||
sens (Number): The scaling number to be filled as the input of back propagation. Default value is 1.0. | |||
micro_batches (int): The number of small batches split from an original batch. Default: None. | |||
norm_clip (float): Use to clip the bound, if set 1, will return the original data. Default: 1.0. | |||
mech (Mechanisms): The object can generate the different type of noise. Default: None. | |||
sens (Number): The scaling number to be filled as the input of back | |||
propagation. Default value is 1.0. | |||
micro_batches (int): The number of small batches split from an original | |||
batch. Default: None. | |||
norm_clip (Tensor): Use to clip the bound, if set 1, will return the | |||
original data. Default: 1.0. | |||
noise_mech (Mechanisms): The object can generate the different type | |||
of noise. Default: None. | |||
Inputs: | |||
- **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`. | |||
@@ -483,7 +580,9 @@ class _TrainOneStepCell(Cell): | |||
Tensor, a scalar Tensor with shape :math:`()`. | |||
""" | |||
def __init__(self, network, optimizer, sens=1.0, micro_batches=None, norm_clip=1.0, mech=None): | |||
def __init__(self, network, optimizer, norm_clip=1.0, sens=1.0, | |||
micro_batches=None, | |||
noise_mech=None, clip_mech=None): | |||
super(_TrainOneStepCell, self).__init__(auto_prefix=False) | |||
self.network = network | |||
self.network.set_grad() | |||
@@ -495,36 +594,51 @@ class _TrainOneStepCell(Cell): | |||
self.reducer_flag = False | |||
self.grad_reducer = None | |||
parallel_mode = _get_parallel_mode() | |||
if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL): | |||
if parallel_mode in ( | |||
ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL): | |||
self.reducer_flag = True | |||
if self.reducer_flag: | |||
mean = _get_mirror_mean() | |||
degree = _get_device_num() | |||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) | |||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, | |||
mean, degree) | |||
# dp params | |||
if micro_batches is None: | |||
msg = 'micro_batches must give in differential privacy, but got value: {}'.format(micro_batches) | |||
msg = 'micro_batches must give in differential privacy, but got value: {}'.format( | |||
micro_batches) | |||
LOGGER.error(TAG, msg) | |||
raise ValueError(msg) | |||
self._micro_batches = micro_batches | |||
norm_clip = check_param_type('norm_clip', norm_clip, float) | |||
self._l2_norm = check_value_positive('norm_clip', norm_clip) | |||
self._norm_clip = norm_clip | |||
self._split = P.Split(0, self._micro_batches) | |||
self._clip_by_global_norm = _ClipGradients() | |||
self._mech = mech | |||
self._noise_mech = noise_mech | |||
self._clip_mech = clip_mech | |||
self._tuple_add = _TupleAdd() | |||
self._add = P.TensorAdd() | |||
self._norm = nn.Norm() | |||
self._hyper_map = C.HyperMap() | |||
self._zero = Tensor(0, mstype.float32) | |||
self._assign = P.Assign() | |||
self._div = P.Div() | |||
self._sqrt = P.Sqrt() | |||
self._reduce_sum = P.ReduceSum() | |||
self._square_all = P.Square() | |||
self._less = P.Less() | |||
self._cast = P.Cast() | |||
self._micro_float = Tensor(micro_batches, mstype.float32) | |||
self._mech_param_updater = None | |||
if self._mech is not None and self._mech._decay_policy is not None: | |||
self._mech_param_updater = _MechanismsParamsUpdater(policy=self._mech._decay_policy, | |||
decay_rate=self._mech._noise_decay_rate, | |||
cur_noise_multiplier= | |||
self._mech._noise_multiplier, | |||
init_noise_multiplier= | |||
self._mech._initial_noise_multiplier) | |||
self._noise_mech_param_updater = None | |||
if self._noise_mech is not None and self._noise_mech._decay_policy is not None: | |||
self._noise_mech_param_updater = _MechanismsParamsUpdater( | |||
policy=self._noise_mech._decay_policy, | |||
decay_rate=self._noise_mech._noise_decay_rate, | |||
cur_noise_multiplier= | |||
self._noise_mech._noise_multiplier, | |||
init_noise_multiplier= | |||
self._noise_mech._initial_noise_multiplier) | |||
def construct(self, data, label): | |||
""" | |||
@@ -535,32 +649,65 @@ class _TrainOneStepCell(Cell): | |||
record_labels = self._split(label) | |||
loss = self.network(record_datas[0], record_labels[0]) | |||
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) | |||
record_grad = self.grad(self.network, weights)(record_datas[0], record_labels[0], sens) | |||
record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm) | |||
record_grad = self.grad(self.network, weights)(record_datas[0], | |||
record_labels[0], sens) | |||
beta = self._zero | |||
square_sum = self._zero | |||
for grad in record_grad: | |||
square_sum = self._add(square_sum, | |||
self._reduce_sum(self._square_all(grad))) | |||
norm_grad = self._sqrt(square_sum) | |||
beta = self._add(beta, | |||
self._cast(self._less(norm_grad, self._norm_clip), | |||
mstype.float32)) | |||
record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, | |||
self._norm_clip) | |||
grads = record_grad | |||
total_loss = loss | |||
for i in range(1, self._micro_batches): | |||
loss = self.network(record_datas[i], record_labels[i]) | |||
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) | |||
record_grad = self.grad(self.network, weights)(record_datas[i], record_labels[i], sens) | |||
record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm) | |||
record_grad = self.grad(self.network, weights)(record_datas[i], | |||
record_labels[i], | |||
sens) | |||
square_sum = self._zero | |||
for grad in record_grad: | |||
square_sum = self._add(square_sum, | |||
self._reduce_sum(self._square_all(grad))) | |||
norm_grad = self._sqrt(square_sum) | |||
beta = self._add(beta, | |||
self._cast(self._less(norm_grad, self._norm_clip), | |||
mstype.float32)) | |||
record_grad = self._clip_by_global_norm(record_grad, | |||
GRADIENT_CLIP_TYPE, | |||
self._norm_clip) | |||
grads = self._tuple_add(grads, record_grad) | |||
total_loss = P.TensorAdd()(total_loss, loss) | |||
loss = P.Div()(total_loss, self._micro_float) | |||
loss = self._div(total_loss, self._micro_float) | |||
beta = self._div(beta, self._micro_batches) | |||
if self._mech is not None: | |||
if self._noise_mech is not None: | |||
grad_noise_tuple = () | |||
for grad_item in grads: | |||
grad_noise = self._mech(grad_item) | |||
grad_noise = self._noise_mech(grad_item) | |||
grad_noise_tuple = grad_noise_tuple + (grad_noise,) | |||
grads = self._tuple_add(grads, grad_noise_tuple) | |||
grads = self._hyper_map(F.partial(_grad_scale, self._micro_float), grads) | |||
grads = self._hyper_map(F.partial(_grad_scale, self._micro_float), | |||
grads) | |||
# update mech parameters | |||
if self._mech_param_updater is not None: | |||
multiplier = self._mech_param_updater() | |||
if self._noise_mech_param_updater is not None: | |||
multiplier = self._noise_mech_param_updater() | |||
loss = F.depend(loss, multiplier) | |||
if self.reducer_flag: | |||
# apply grad reducer on grads | |||
grads = self.grad_reducer(grads) | |||
if self._clip_mech is not None: | |||
next_norm_clip = self._clip_mech(beta, self._norm_clip) | |||
self._norm_clip = self._assign(self._norm_clip, next_norm_clip) | |||
loss = F.depend(loss, next_norm_clip) | |||
return F.depend(loss, self.optimizer(grads)) |
@@ -19,9 +19,11 @@ import pytest | |||
from mindspore import context | |||
from mindspore import Tensor | |||
from mindspore.common import dtype as mstype | |||
from mindarmour.diff_privacy import GaussianRandom | |||
from mindarmour.diff_privacy import NoiseGaussianRandom | |||
from mindarmour.diff_privacy import AdaGaussianRandom | |||
from mindarmour.diff_privacy import MechanismsFactory | |||
from mindarmour.diff_privacy import AdaClippingWithGaussianRandom | |||
from mindarmour.diff_privacy import NoiseMechanismsFactory | |||
from mindarmour.diff_privacy import ClipMechanismsFactory | |||
@pytest.mark.level0 | |||
@@ -33,7 +35,7 @@ def test_graph_gaussian(): | |||
grad = Tensor([0.3, 0.2, 0.4], mstype.float32) | |||
norm_bound = 1.0 | |||
initial_noise_multiplier = 0.1 | |||
net = GaussianRandom(norm_bound, initial_noise_multiplier) | |||
net = NoiseGaussianRandom(norm_bound, initial_noise_multiplier) | |||
res = net(grad) | |||
print(res) | |||
@@ -47,7 +49,7 @@ def test_pynative_gaussian(): | |||
grad = Tensor([0.3, 0.2, 0.4], mstype.float32) | |||
norm_bound = 1.0 | |||
initial_noise_multiplier = 0.1 | |||
net = GaussianRandom(norm_bound, initial_noise_multiplier) | |||
net = NoiseGaussianRandom(norm_bound, initial_noise_multiplier) | |||
res = net(grad) | |||
print(res) | |||
@@ -80,13 +82,13 @@ def test_graph_factory(): | |||
initial_noise_multiplier = 0.1 | |||
alpha = 0.5 | |||
decay_policy = 'Step' | |||
noise_mechanism = MechanismsFactory() | |||
noise_mechanism = NoiseMechanismsFactory() | |||
noise_construct = noise_mechanism.create('Gaussian', | |||
norm_bound, | |||
initial_noise_multiplier) | |||
noise = noise_construct(grad) | |||
print('Gaussian noise: ', noise) | |||
ada_mechanism = MechanismsFactory() | |||
ada_mechanism = NoiseMechanismsFactory() | |||
ada_noise_construct = ada_mechanism.create('AdaGaussian', | |||
norm_bound, | |||
initial_noise_multiplier, | |||
@@ -124,13 +126,13 @@ def test_pynative_factory(): | |||
initial_noise_multiplier = 0.1 | |||
alpha = 0.5 | |||
decay_policy = 'Step' | |||
noise_mechanism = MechanismsFactory() | |||
noise_mechanism = NoiseMechanismsFactory() | |||
noise_construct = noise_mechanism.create('Gaussian', | |||
norm_bound, | |||
initial_noise_multiplier) | |||
noise = noise_construct(grad) | |||
print('Gaussian noise: ', noise) | |||
ada_mechanism = MechanismsFactory() | |||
ada_mechanism = NoiseMechanismsFactory() | |||
ada_noise_construct = ada_mechanism.create('AdaGaussian', | |||
norm_bound, | |||
initial_noise_multiplier, | |||
@@ -151,7 +153,7 @@ def test_pynative_exponential(): | |||
initial_noise_multiplier = 0.1 | |||
alpha = 0.5 | |||
decay_policy = 'Exp' | |||
ada_mechanism = MechanismsFactory() | |||
ada_mechanism = NoiseMechanismsFactory() | |||
ada_noise_construct = ada_mechanism.create('AdaGaussian', | |||
norm_bound, | |||
initial_noise_multiplier, | |||
@@ -172,7 +174,7 @@ def test_graph_exponential(): | |||
initial_noise_multiplier = 0.1 | |||
alpha = 0.5 | |||
decay_policy = 'Exp' | |||
ada_mechanism = MechanismsFactory() | |||
ada_mechanism = NoiseMechanismsFactory() | |||
ada_noise_construct = ada_mechanism.create('AdaGaussian', | |||
norm_bound, | |||
initial_noise_multiplier, | |||
@@ -180,3 +182,107 @@ def test_graph_exponential(): | |||
decay_policy=decay_policy) | |||
ada_noise = ada_noise_construct(grad) | |||
print('ada noise: ', ada_noise) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
@pytest.mark.component_mindarmour | |||
def test_ada_clip_gaussian_random_pynative(): | |||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||
decay_policy = 'Linear' | |||
beta = Tensor(0.5, mstype.float32) | |||
norm_clip = Tensor(1.0, mstype.float32) | |||
beta_stddev = 0.1 | |||
learning_rate = 0.1 | |||
target_unclipped_quantile = 0.3 | |||
ada_clip = AdaClippingWithGaussianRandom(decay_policy=decay_policy, | |||
learning_rate=learning_rate, | |||
target_unclipped_quantile=target_unclipped_quantile, | |||
fraction_stddev=beta_stddev, | |||
seed=1) | |||
next_norm_clip = ada_clip(beta, norm_clip) | |||
print('Liner next norm clip:', next_norm_clip) | |||
decay_policy = 'Geometric' | |||
ada_clip = AdaClippingWithGaussianRandom(decay_policy=decay_policy, | |||
learning_rate=learning_rate, | |||
target_unclipped_quantile=target_unclipped_quantile, | |||
fraction_stddev=beta_stddev, | |||
seed=1) | |||
next_norm_clip = ada_clip(beta, norm_clip) | |||
print('Geometric next norm clip:', next_norm_clip) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
@pytest.mark.component_mindarmour | |||
def test_ada_clip_gaussian_random_graph(): | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
decay_policy = 'Linear' | |||
beta = Tensor(0.5, mstype.float32) | |||
norm_clip = Tensor(1.0, mstype.float32) | |||
beta_stddev = 0.1 | |||
learning_rate = 0.1 | |||
target_unclipped_quantile = 0.3 | |||
ada_clip = AdaClippingWithGaussianRandom(decay_policy=decay_policy, | |||
learning_rate=learning_rate, | |||
target_unclipped_quantile=target_unclipped_quantile, | |||
fraction_stddev=beta_stddev, | |||
seed=1) | |||
next_norm_clip = ada_clip(beta, norm_clip) | |||
print('Liner next norm clip:', next_norm_clip) | |||
decay_policy = 'Geometric' | |||
ada_clip = AdaClippingWithGaussianRandom(decay_policy=decay_policy, | |||
learning_rate=learning_rate, | |||
target_unclipped_quantile=target_unclipped_quantile, | |||
fraction_stddev=beta_stddev, | |||
seed=1) | |||
next_norm_clip = ada_clip(beta, norm_clip) | |||
print('Geometric next norm clip:', next_norm_clip) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
@pytest.mark.component_mindarmour | |||
def test_pynative_clip_mech_factory(): | |||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||
decay_policy = 'Linear' | |||
beta = Tensor(0.5, mstype.float32) | |||
norm_clip = Tensor(1.0, mstype.float32) | |||
beta_stddev = 0.1 | |||
learning_rate = 0.1 | |||
target_unclipped_quantile = 0.3 | |||
clip_mechanism = ClipMechanismsFactory() | |||
ada_clip = clip_mechanism.create('Gaussian', | |||
decay_policy=decay_policy, | |||
learning_rate=learning_rate, | |||
target_unclipped_quantile=target_unclipped_quantile, | |||
fraction_stddev=beta_stddev) | |||
next_norm_clip = ada_clip(beta, norm_clip) | |||
print('next_norm_clip: ', next_norm_clip) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
@pytest.mark.component_mindarmour | |||
def test_graph_clip_mech_factory(): | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
decay_policy = 'Linear' | |||
beta = Tensor(0.5, mstype.float32) | |||
norm_clip = Tensor(1.0, mstype.float32) | |||
beta_stddev = 0.1 | |||
learning_rate = 0.1 | |||
target_unclipped_quantile = 0.3 | |||
clip_mechanism = ClipMechanismsFactory() | |||
ada_clip = clip_mechanism.create('Gaussian', | |||
decay_policy=decay_policy, | |||
learning_rate=learning_rate, | |||
target_unclipped_quantile=target_unclipped_quantile, | |||
fraction_stddev=beta_stddev) | |||
next_norm_clip = ada_clip(beta, norm_clip) | |||
print('next_norm_clip: ', next_norm_clip) |
@@ -22,7 +22,8 @@ from mindspore import context | |||
import mindspore.dataset as ds | |||
from mindarmour.diff_privacy import DPModel | |||
from mindarmour.diff_privacy import MechanismsFactory | |||
from mindarmour.diff_privacy import NoiseMechanismsFactory | |||
from mindarmour.diff_privacy import ClipMechanismsFactory | |||
from mindarmour.diff_privacy import DPOptimizerClassFactory | |||
from test_network import LeNet5 | |||
@@ -30,10 +31,12 @@ from test_network import LeNet5 | |||
def dataset_generator(batch_size, batches): | |||
"""mock training data.""" | |||
data = np.random.random((batches * batch_size, 1, 32, 32)).astype(np.float32) | |||
label = np.random.randint(0, 10, batches * batch_size).astype(np.int32) | |||
data = np.random.random((batches*batch_size, 1, 32, 32)).astype( | |||
np.float32) | |||
label = np.random.randint(0, 10, batches*batch_size).astype(np.int32) | |||
for i in range(batches): | |||
yield data[i * batch_size:(i + 1) * batch_size], label[i * batch_size:(i + 1) * batch_size] | |||
yield data[i*batch_size:(i + 1)*batch_size],\ | |||
label[i*batch_size:(i + 1)*batch_size] | |||
@pytest.mark.level0 | |||
@@ -55,16 +58,24 @@ def test_dp_model_with_pynative_mode(): | |||
factory_opt.set_mechanisms('Gaussian', | |||
norm_bound=norm_clip, | |||
initial_noise_multiplier=initial_noise_multiplier) | |||
net_opt = factory_opt.create('Momentum')(network.trainable_params(), learning_rate=0.1, momentum=0.9) | |||
net_opt = factory_opt.create('Momentum')(network.trainable_params(), | |||
learning_rate=0.1, momentum=0.9) | |||
clip_mech = ClipMechanismsFactory().create('Gaussian', | |||
decay_policy='Linear', | |||
learning_rate=0.01, | |||
target_unclipped_quantile=0.9, | |||
fraction_stddev=0.01) | |||
model = DPModel(micro_batches=micro_batches, | |||
norm_clip=norm_clip, | |||
mech=None, | |||
clip_mech=clip_mech, | |||
noise_mech=None, | |||
network=network, | |||
loss_fn=loss, | |||
optimizer=net_opt, | |||
metrics=None) | |||
ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches), ['data', 'label']) | |||
ms_ds.set_dataset_size(batch_size * batches) | |||
ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches), | |||
['data', 'label']) | |||
ms_ds.set_dataset_size(batch_size*batches) | |||
model.train(epochs, ms_ds, dataset_sink_mode=False) | |||
@@ -82,19 +93,27 @@ def test_dp_model_with_graph_mode(): | |||
batches = 128 | |||
epochs = 1 | |||
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) | |||
mech = MechanismsFactory().create('Gaussian', | |||
norm_bound=norm_clip, | |||
initial_noise_multiplier=initial_noise_multiplier) | |||
net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9) | |||
noise_mech = NoiseMechanismsFactory().create('Gaussian', | |||
norm_bound=norm_clip, | |||
initial_noise_multiplier=initial_noise_multiplier) | |||
clip_mech = ClipMechanismsFactory().create('Gaussian', | |||
decay_policy='Linear', | |||
learning_rate=0.01, | |||
target_unclipped_quantile=0.9, | |||
fraction_stddev=0.01) | |||
net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.1, | |||
momentum=0.9) | |||
model = DPModel(micro_batches=2, | |||
clip_mech=clip_mech, | |||
norm_clip=norm_clip, | |||
mech=mech, | |||
noise_mech=noise_mech, | |||
network=network, | |||
loss_fn=loss, | |||
optimizer=net_opt, | |||
metrics=None) | |||
ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches), ['data', 'label']) | |||
ms_ds.set_dataset_size(batch_size * batches) | |||
ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches), | |||
['data', 'label']) | |||
ms_ds.set_dataset_size(batch_size*batches) | |||
model.train(epochs, ms_ds, dataset_sink_mode=False) | |||
@@ -112,17 +131,25 @@ def test_dp_model_with_graph_mode_ada_gaussian(): | |||
batches = 128 | |||
epochs = 1 | |||
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) | |||
mech = MechanismsFactory().create('AdaGaussian', | |||
norm_bound=norm_clip, | |||
initial_noise_multiplier=initial_noise_multiplier) | |||
net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9) | |||
noise_mech = NoiseMechanismsFactory().create('AdaGaussian', | |||
norm_bound=norm_clip, | |||
initial_noise_multiplier=initial_noise_multiplier) | |||
clip_mech = ClipMechanismsFactory().create('Gaussian', | |||
decay_policy='Linear', | |||
learning_rate=0.01, | |||
target_unclipped_quantile=0.9, | |||
fraction_stddev=0.01) | |||
net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.1, | |||
momentum=0.9) | |||
model = DPModel(micro_batches=2, | |||
clip_mech=clip_mech, | |||
norm_clip=norm_clip, | |||
mech=mech, | |||
noise_mech=noise_mech, | |||
network=network, | |||
loss_fn=loss, | |||
optimizer=net_opt, | |||
metrics=None) | |||
ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches), ['data', 'label']) | |||
ms_ds.set_dataset_size(batch_size * batches) | |||
ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches), | |||
['data', 'label']) | |||
ms_ds.set_dataset_size(batch_size*batches) | |||
model.train(epochs, ms_ds, dataset_sink_mode=False) |