@@ -20,7 +20,7 @@ from easydict import EasyDict as edict
 mnist_cfg = edict({
     'num_classes': 10,  # the number of classes of model's output
-    'lr': 0.1,  # the learning rate of model's optimizer
+    'lr': 0.01,  # the learning rate of model's optimizer
     'momentum': 0.9,  # the momentum value of model's optimizer
     'epoch_size': 10,  # training epochs
     'batch_size': 256,  # batch size for training
@@ -33,8 +33,13 @@ mnist_cfg = edict({
     'dataset_sink_mode': False,  # whether deliver all training data to device one time
     'micro_batches': 16,  # the number of small batches split from an original batch
     'norm_clip': 1.0,  # the clip bound of the gradients of model's training parameters
-    'initial_noise_multiplier': 1.5,  # the initial multiplication coefficient of the noise added to training
+    'initial_noise_multiplier': 0.5,  # the initial multiplication coefficient of the noise added to training
     # parameters' gradients
-    'mechanisms': 'AdaGaussian',  # the method of adding noise in gradients while training
+    'noise_mechanisms': 'AdaGaussian',  # the method of adding noise in gradients while training
+    'clip_mechanisms': 'Gaussian',  # the method of adaptive clipping gradients while training
+    'clip_decay_policy': 'Linear',  # Decay policy of adaptive clipping, decay_policy must be in ['Linear', 'Geometric'].
+    'clip_learning_rate': 0.001,  # Learning rate of update norm clip.
+    'target_unclipped_quantile': 0.9,  # Target quantile of norm clip.
+    'fraction_stddev': 0.01,  # The stddev of Gaussian normal which used in empirical_fraction.
     'optimizer': 'Momentum'  # the base optimizer used for Differential privacy training
 })
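The two DP knobs above combine multiplicatively: the Gaussian noise added to the clipped gradient sum has standard deviation norm_clip * noise_multiplier (see AdaGaussianRandom.construct further down this diff), so lowering initial_noise_multiplier from 1.5 to 0.5 directly lowers the injected noise. A minimal numpy sketch of that relationship, with illustrative values that are not part of the patch:

    import numpy as np

    # Illustrative values taken from the config above; not MindArmour code.
    norm_clip = 1.0                  # clipping bound for each micro-batch gradient
    initial_noise_multiplier = 0.5   # noise multiplier at the start of training

    # Noise stddev scales with both knobs: sigma = norm_clip * noise_multiplier.
    sigma = norm_clip * initial_noise_multiplier
    grad = np.random.randn(8).astype(np.float32)   # stand-in for a clipped gradient
    noisy_grad = grad + np.random.normal(0.0, sigma, size=grad.shape)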
@@ -31,7 +31,8 @@ import mindspore.common.dtype as mstype
 from mindarmour.diff_privacy import DPModel
 from mindarmour.diff_privacy import PrivacyMonitorFactory
-from mindarmour.diff_privacy import MechanismsFactory
+from mindarmour.diff_privacy import NoiseMechanismsFactory
+from mindarmour.diff_privacy import ClipMechanismsFactory
 from mindarmour.utils.logger import LogUtil
 from lenet5_net import LeNet5
 from lenet5_config import mnist_cfg as cfg
@@ -87,11 +88,14 @@ def generate_mnist_dataset(data_path, batch_size=32, repeat_size=1,
 if __name__ == "__main__":
     # This configure can run both in pynative mode and graph mode
-    context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
+    context.set_context(mode=context.GRAPH_MODE,
+                        device_target=cfg.device_target)
     network = LeNet5()
-    net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
-    config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
-                                 keep_checkpoint_max=cfg.keep_checkpoint_max)
+    net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True,
+                                                reduction="mean")
+    config_ck = CheckpointConfig(
+        save_checkpoint_steps=cfg.save_checkpoint_steps,
+        keep_checkpoint_max=cfg.keep_checkpoint_max)
     ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",
                                  directory='./trained_ckpt_file/',
                                  config=config_ck)
@@ -102,17 +106,33 @@ if __name__ == "__main__":
                                        cfg.epoch_size)
     if cfg.micro_batches and cfg.batch_size % cfg.micro_batches != 0:
-        raise ValueError("Number of micro_batches should divide evenly batch_size")
-    # Create a factory class of DP mechanisms, this method is adding noise in gradients while training.
-    # Initial_noise_multiplier is suggested to be greater than 1.0, otherwise the privacy budget would be huge, which
-    # means that the privacy protection effect is weak. Mechanisms can be 'Gaussian' or 'AdaGaussian', in which noise
-    # would be decayed with 'AdaGaussian' mechanism while be constant with 'Gaussian' mechanism.
-    mech = MechanismsFactory().create(cfg.mechanisms,
-                                      norm_bound=cfg.norm_clip,
-                                      initial_noise_multiplier=cfg.initial_noise_multiplier)
-    net_opt = nn.Momentum(params=network.trainable_params(), learning_rate=cfg.lr, momentum=cfg.momentum)
-    # Create a monitor for DP training. The function of the monitor is to compute and print the privacy budget(eps
-    # and delta) while training.
+        raise ValueError(
+            "Number of micro_batches should divide evenly batch_size")
+    # Create a factory class of DP noise mechanisms, this method is adding noise
+    # in gradients while training. Initial_noise_multiplier is suggested to be
+    # greater than 1.0, otherwise the privacy budget would be huge, which means
+    # that the privacy protection effect is weak. Mechanisms can be 'Gaussian'
+    # or 'AdaGaussian', in which noise would be decayed with 'AdaGaussian'
+    # mechanism while be constant with 'Gaussian' mechanism.
+    noise_mech = NoiseMechanismsFactory().create(cfg.noise_mechanisms,
+                                                 norm_bound=cfg.norm_clip,
+                                                 initial_noise_multiplier=cfg.initial_noise_multiplier)
+    # Create a factory class of clip mechanisms, this method is to adaptive clip
+    # gradients while training, decay_policy support 'Linear' and 'Geometric',
+    # learning_rate is the learning rate to update clip_norm,
+    # target_unclipped_quantile is the target quantile of norm clip,
+    # fraction_stddev is the stddev of Gaussian normal which used in
+    # empirical_fraction, the formula is
+    # $empirical_fraction + N(0, fraction_stddev)$.
+    clip_mech = ClipMechanismsFactory().create(cfg.clip_mechanisms,
+                                               decay_policy=cfg.clip_decay_policy,
+                                               learning_rate=cfg.clip_learning_rate,
+                                               target_unclipped_quantile=cfg.target_unclipped_quantile,
+                                               fraction_stddev=cfg.fraction_stddev)
+    net_opt = nn.Momentum(params=network.trainable_params(),
+                          learning_rate=cfg.lr, momentum=cfg.momentum)
+    # Create a monitor for DP training. The function of the monitor is to
+    # compute and print the privacy budget(eps and delta) while training.
     rdp_monitor = PrivacyMonitorFactory.create('rdp',
                                                num_samples=60000,
                                                batch_size=cfg.batch_size,
@@ -121,20 +141,23 @@ if __name__ == "__main__":
     # Create the DP model for training.
     model = DPModel(micro_batches=cfg.micro_batches,
                     norm_clip=cfg.norm_clip,
-                    mech=mech,
+                    noise_mech=noise_mech,
+                    clip_mech=clip_mech,
                     network=network,
                     loss_fn=net_loss,
                     optimizer=net_opt,
                     metrics={"Accuracy": Accuracy()})
     LOGGER.info(TAG, "============== Starting Training ==============")
-    model.train(cfg['epoch_size'], ds_train, callbacks=[ckpoint_cb, LossMonitor(), rdp_monitor],
+    model.train(cfg['epoch_size'], ds_train,
+                callbacks=[ckpoint_cb, LossMonitor(), rdp_monitor],
                 dataset_sink_mode=cfg.dataset_sink_mode)
     LOGGER.info(TAG, "============== Starting Testing ==============")
     ckpt_file_name = 'trained_ckpt_file/checkpoint_lenet-10_234.ckpt'
     param_dict = load_checkpoint(ckpt_file_name)
     load_param_into_net(network, param_dict)
-    ds_eval = generate_mnist_dataset(os.path.join(cfg.data_path, 'test'), batch_size=cfg.batch_size)
+    ds_eval = generate_mnist_dataset(os.path.join(cfg.data_path, 'test'),
+                                     batch_size=cfg.batch_size)
     acc = model.eval(ds_eval, dataset_sink_mode=False)
     LOGGER.info(TAG, "============== Accuracy: %s ==============", acc)
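The DPModel built above hides the per-batch differential-privacy step: the batch is split into cfg.micro_batches slices, each slice's gradient is clipped to cfg.norm_clip, the clipped gradients are summed, Gaussian noise is added, and the result is averaged. A minimal numpy sketch of that step, illustrative only, the real logic lives in the _TrainOneStepCell changes further down this diff:

    import numpy as np

    def dp_train_step(micro_batch_grads, norm_clip=1.0, noise_multiplier=0.5):
        """Rough sketch of one DP-SGD style step as DPModel performs it."""
        summed = np.zeros_like(micro_batch_grads[0])
        for g in micro_batch_grads:
            scale = min(1.0, norm_clip / (np.linalg.norm(g) + 1e-12))
            summed += g * scale                      # clip, then accumulate
        noise = np.random.normal(0.0, norm_clip * noise_multiplier, size=summed.shape)
        return (summed + noise) / len(micro_batch_grads)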
@@ -1,16 +1,20 @@
 """
 This module provide Differential Privacy feature to protect user privacy.
 """
-from .mechanisms.mechanisms import GaussianRandom
+from .mechanisms.mechanisms import NoiseGaussianRandom
 from .mechanisms.mechanisms import AdaGaussianRandom
-from .mechanisms.mechanisms import MechanismsFactory
+from .mechanisms.mechanisms import AdaClippingWithGaussianRandom
+from .mechanisms.mechanisms import NoiseMechanismsFactory
+from .mechanisms.mechanisms import ClipMechanismsFactory
 from .monitor.monitor import PrivacyMonitorFactory
 from .optimizer.optimizer import DPOptimizerClassFactory
 from .train.model import DPModel
-__all__ = ['GaussianRandom',
+__all__ = ['NoiseGaussianRandom',
            'AdaGaussianRandom',
-           'MechanismsFactory',
+           'AdaClippingWithGaussianRandom',
+           'NoiseMechanismsFactory',
+           'ClipMechanismsFactory',
            'PrivacyMonitorFactory',
            'DPOptimizerClassFactory',
            'DPModel']
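With these exports, user code imports the renamed noise classes and the new clip classes directly from mindarmour.diff_privacy. A short usage sketch of the reworked public API, with illustrative values:

    from mindarmour.diff_privacy import NoiseMechanismsFactory, ClipMechanismsFactory

    # 'AdaGaussian' decays the noise multiplier during training; 'Gaussian' keeps it constant.
    noise_mech = NoiseMechanismsFactory().create('AdaGaussian',
                                                 norm_bound=1.0,
                                                 initial_noise_multiplier=0.5)
    # 'Gaussian' is currently the only clip mechanism; it adapts norm_clip each step.
    clip_mech = ClipMechanismsFactory().create('Gaussian',
                                               decay_policy='Linear',
                                               learning_rate=0.001,
                                               target_unclipped_quantile=0.9,
                                               fraction_stddev=0.01)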
@@ -28,11 +28,54 @@ from mindarmour.utils._check_param import check_param_in_range
 from mindarmour.utils.logger import LogUtil
 LOGGER = LogUtil.get_instance()
-TAG = 'Defense'
+TAG = 'NoiseMechanism'
-class MechanismsFactory:
-    """ Factory class of mechanisms"""
+class ClipMechanismsFactory:
+    """ Factory class of clip mechanisms"""
+    def __init__(self):
+        pass
+    @staticmethod
+    def create(mech_name, *args, **kwargs):
+        """
+        Args:
+            mech_name(str): Clip noise generated strategy, support 'Gaussian' now.
+            args(Union[float, str]): Parameters used for creating clip mechanisms.
+            kwargs(Union[float, str]): Parameters used for creating clip
+                mechanisms.
+        Raises:
+            NameError: `mech_name` must be in ['Gaussian'].
+        Returns:
+            Mechanisms, class of noise generated Mechanism.
+        Examples:
+            >>> decay_policy = 'Linear'
+            >>> beta = Tensor(0.5, mstype.float32)
+            >>> norm_clip = Tensor(1.0, mstype.float32)
+            >>> beta_stddev = 0.1
+            >>> learning_rate = 0.1
+            >>> target_unclipped_quantile = 0.3
+            >>> clip_mechanism = ClipMechanismsFactory()
+            >>> ada_clip = clip_mechanism.create('Gaussian',
+            >>>                                  decay_policy=decay_policy,
+            >>>                                  learning_rate=learning_rate,
+            >>>                                  target_unclipped_quantile=target_unclipped_quantile,
+            >>>                                  fraction_stddev=beta_stddev)
+            >>> next_norm_clip = ada_clip(beta, norm_clip)
+        """
+        if mech_name == 'Gaussian':
+            return AdaClippingWithGaussianRandom(*args, **kwargs)
+        raise NameError("The {} is not implement, please choose "
+                        "['Gaussian']".format(mech_name))
+class NoiseMechanismsFactory:
+    """ Factory class of noise mechanisms"""
     def __init__(self):
         pass
@@ -56,42 +99,38 @@ class MechanismsFactory:
             Mechanisms, class of noise generated Mechanism.
         Examples:
-            >>> class Net(nn.Cell):
-            >>>     def __init__(self):
-            >>>         super(Net, self).__init__()
-            >>>         self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')
-            >>>         self.bn = nn.BatchNorm2d(64)
-            >>>         self.relu = nn.ReLU()
-            >>>         self.flatten = nn.Flatten()
-            >>>         self.fc = nn.Dense(64*224*224, 12)  # padding=0
-            >>>
-            >>>     def construct(self, x):
-            >>>         x = self.conv(x)
-            >>>         x = self.bn(x)
-            >>>         x = self.relu(x)
-            >>>         x = self.flatten(x)
-            >>>         out = self.fc(x)
-            >>>         return out
             >>> norm_clip = 1.0
-            >>> initial_noise_multiplier = 1.5
-            >>> net = Net()
+            >>> initial_noise_multiplier = 0.01
+            >>> network = LeNet5()
+            >>> batch_size = 32
+            >>> batches = 128
+            >>> epochs = 1
             >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
-            >>> net_opt = Momentum(params=net.trainable_params(), learning_rate=0.01, momentum=0.9)
-            >>> mech = MechanismsFactory().create('Gaussian',
-            >>>                                   norm_bound=norm_clip,
-            >>>                                   initial_noise_multiplier=initial_noise_multiplier)
+            >>> noise_mech = NoiseMechanismsFactory().create('Gaussian',
+            >>>                                              norm_bound=norm_clip,
+            >>>                                              initial_noise_multiplier=initial_noise_multiplier)
+            >>> clip_mech = ClipMechanismsFactory().create('Gaussian',
+            >>>                                            decay_policy='Linear',
+            >>>                                            learning_rate=0.01,
+            >>>                                            target_unclipped_quantile=0.9,
+            >>>                                            fraction_stddev=0.01)
+            >>> net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.1,
+            >>>                       momentum=0.9)
             >>> model = DPModel(micro_batches=2,
-            >>>                 norm_clip=1.0,
-            >>>                 mech=mech,
-            >>>                 network=net,
+            >>>                 clip_mech=clip_mech,
+            >>>                 norm_clip=norm_clip,
+            >>>                 noise_mech=noise_mech,
+            >>>                 network=network,
             >>>                 loss_fn=loss,
             >>>                 optimizer=net_opt,
             >>>                 metrics=None)
-            >>> dataset = get_dataset()
-            >>> model.train(2, dataset)
+            >>> ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches),
+            >>>                             ['data', 'label'])
+            >>> ms_ds.set_dataset_size(batch_size * batches)
+            >>> model.train(epochs, ms_ds, dataset_sink_mode=False)
         """
         if policy == 'Gaussian':
-            return GaussianRandom(*args, **kwargs)
+            return NoiseGaussianRandom(*args, **kwargs)
         if policy == 'AdaGaussian':
             return AdaGaussianRandom(*args, **kwargs)
         raise NameError("The {} is not implement, please choose "
@@ -110,7 +149,7 @@ class Mechanisms(Cell):
     """
-class GaussianRandom(Mechanisms):
+class NoiseGaussianRandom(Mechanisms):
     """
     Gaussian noise generated mechanism.
@@ -133,18 +172,21 @@ class GaussianRandom(Mechanisms):
         >>> gradients = Tensor([0.2, 0.9], mstype.float32)
         >>> norm_bound = 0.5
         >>> initial_noise_multiplier = 1.5
-        >>> net = GaussianRandom(norm_bound, initial_noise_multiplier)
+        >>> net = NoiseGaussianRandom(norm_bound, initial_noise_multiplier)
        >>> res = net(gradients)
         >>> print(res)
     """
-    def __init__(self, norm_bound=0.5, initial_noise_multiplier=1.5, seed=0, policy=None):
-        super(GaussianRandom, self).__init__()
+    def __init__(self, norm_bound=0.5, initial_noise_multiplier=1.5, seed=0,
+                 policy=None):
+        super(NoiseGaussianRandom, self).__init__()
         self._norm_bound = check_value_positive('norm_bound', norm_bound)
         self._norm_bound = Tensor(norm_bound, mstype.float32)
-        self._initial_noise_multiplier = check_value_positive('initial_noise_multiplier',
-                                                               initial_noise_multiplier)
-        self._initial_noise_multiplier = Tensor(initial_noise_multiplier, mstype.float32)
+        self._initial_noise_multiplier = check_value_positive(
+            'initial_noise_multiplier',
+            initial_noise_multiplier)
+        self._initial_noise_multiplier = Tensor(initial_noise_multiplier,
+                                                mstype.float32)
         self._mean = Tensor(0, mstype.float32)
         self._normal = P.Normal(seed=seed)
         self._decay_policy = policy
@@ -201,17 +243,20 @@ class AdaGaussianRandom(Mechanisms):
                  noise_decay_rate=6e-4, decay_policy='Time', seed=0):
         super(AdaGaussianRandom, self).__init__()
         norm_bound = check_value_positive('norm_bound', norm_bound)
-        initial_noise_multiplier = check_value_positive('initial_noise_multiplier',
-                                                         initial_noise_multiplier)
+        initial_noise_multiplier = check_value_positive(
+            'initial_noise_multiplier',
+            initial_noise_multiplier)
         self._norm_bound = Tensor(norm_bound, mstype.float32)
-        initial_noise_multiplier = Tensor(initial_noise_multiplier, mstype.float32)
+        initial_noise_multiplier = Tensor(initial_noise_multiplier,
+                                          mstype.float32)
         self._initial_noise_multiplier = Parameter(initial_noise_multiplier,
                                                    name='initial_noise_multiplier')
         self._noise_multiplier = Parameter(initial_noise_multiplier,
                                            name='noise_multiplier')
         self._mean = Tensor(0, mstype.float32)
-        noise_decay_rate = check_param_type('noise_decay_rate', noise_decay_rate, float)
+        noise_decay_rate = check_param_type('noise_decay_rate',
+                                            noise_decay_rate, float)
         check_param_in_range('noise_decay_rate', noise_decay_rate, 0.0, 1.0)
         self._noise_decay_rate = Tensor(noise_decay_rate, mstype.float32)
         if decay_policy not in ['Time', 'Step', 'Exp']:
@@ -232,7 +277,9 @@ class AdaGaussianRandom(Mechanisms):
             Tensor, generated noise with shape like given gradients.
         """
         shape = P.Shape()(gradients)
-        noise = self._normal(shape, self._mean, self._mul(self._noise_multiplier, self._norm_bound))
+        noise = self._normal(shape, self._mean,
+                             self._mul(self._noise_multiplier,
+                                       self._norm_bound))
         return noise
@@ -241,10 +288,14 @@ class _MechanismsParamsUpdater(Cell):
     Update mechanisms parameters, the parameters will refresh in train period.
     Args:
-        policy(str): Pass in by the mechanisms class, mechanisms parameters update policy.
-        decay_rate(Tensor): Pass in by the mechanisms class, hyper parameter for controlling the decay size.
-        cur_noise_multiplier(Parameter): Pass in by the mechanisms class, current params value in this time.
-        init_noise_multiplier(Parameter):Pass in by the mechanisms class, initial params value to be updated.
+        policy(str): Pass in by the mechanisms class, mechanisms parameters
+            update policy.
+        decay_rate(Tensor): Pass in by the mechanisms class, hyper parameter for
+            controlling the decay size.
+        cur_noise_multiplier(Parameter): Pass in by the mechanisms class,
+            current params value in this time.
+        init_noise_multiplier(Parameter): Pass in by the mechanisms class,
+            initial params value to be updated.
     Returns:
         Tuple, next params value.
@@ -281,5 +332,100 @@ class _MechanismsParamsUpdater(Cell):
             next_noise_multiplier = self._assign(self._cur_noise_multiplier,
                                                  self._mul(temp, self._cur_noise_multiplier))
         else:
-            next_noise_multiplier = self._assign(self._cur_noise_multiplier, self._div(self._one, self._exp(self._one)))
+            next_noise_multiplier = self._assign(self._cur_noise_multiplier,
                                                 self._div(self._one, self._exp(self._one)))
        return next_noise_multiplier
+class AdaClippingWithGaussianRandom(Cell):
+    """
+    Adaptive clipping. If `decay_policy` is 'Linear', the update formula is
+    $ norm_clip = norm_clip - learning_rate*(beta - target_unclipped_quantile)$.
+    If `decay_policy` is 'Geometric', the update formula is
+    $ norm_clip = norm_clip*exp(-learning_rate*(empirical_fraction - target_unclipped_quantile))$,
+    where beta is the empirical fraction of samples with the value at most
+    `target_unclipped_quantile`.
+    Args:
+        decay_policy(str): Decay policy of adaptive clipping, decay_policy must
+            be in ['Linear', 'Geometric']. Default: Linear.
+        learning_rate(float): Learning rate of update norm clip. Default: 0.01.
+        target_unclipped_quantile(float): Target quantile of norm clip. Default: 0.9.
+        fraction_stddev(float): The stddev of Gaussian normal which used in
+            empirical_fraction, the formula is $empirical_fraction + N(0, fraction_stddev)$.
+        seed(int): Original random seed, if seed=0 random normal will use secure
+            random number. If seed!=0 random normal will generate values using
+            given seed. Default: 0.
+    Returns:
+        Tensor, updated norm clip.
+    Examples:
+        >>> decay_policy = 'Linear'
+        >>> beta = Tensor(0.5, mstype.float32)
+        >>> norm_clip = Tensor(1.0, mstype.float32)
+        >>> beta_stddev = 0.01
+        >>> learning_rate = 0.001
+        >>> target_unclipped_quantile = 0.9
+        >>> ada_clip = AdaClippingWithGaussianRandom(decay_policy=decay_policy,
+        >>>                                          learning_rate=learning_rate,
+        >>>                                          target_unclipped_quantile=target_unclipped_quantile,
+        >>>                                          fraction_stddev=beta_stddev)
+        >>> next_norm_clip = ada_clip(beta, norm_clip)
+    """
+    def __init__(self, decay_policy='Linear', learning_rate=0.001,
+                 target_unclipped_quantile=0.9, fraction_stddev=0.01, seed=0):
+        super(AdaClippingWithGaussianRandom, self).__init__()
+        if decay_policy not in ['Linear', 'Geometric']:
+            msg = "decay policy of adaptive clip must be in ['Linear', 'Geometric'], \
+                but got: {}".format(decay_policy)
+            LOGGER.error(TAG, msg)
+            raise ValueError(msg)
+        self._decay_policy = decay_policy
+        learning_rate = check_param_type('learning_rate', learning_rate, float)
+        learning_rate = check_value_positive('learning_rate', learning_rate)
+        self._learning_rate = Tensor(learning_rate, mstype.float32)
+        fraction_stddev = check_param_type('fraction_stddev', fraction_stddev, float)
+        self._fraction_stddev = Tensor(fraction_stddev, mstype.float32)
+        target_unclipped_quantile = check_param_type('target_unclipped_quantile',
+                                                     target_unclipped_quantile,
+                                                     float)
+        self._target_unclipped_quantile = Tensor(target_unclipped_quantile,
+                                                 mstype.float32)
+        self._zero = Tensor(0, mstype.float32)
+        self._add = P.TensorAdd()
+        self._sub = P.Sub()
+        self._mul = P.Mul()
+        self._exp = P.Exp()
+        self._normal = P.Normal(seed=seed)
+    def construct(self, empirical_fraction, norm_clip):
+        """
+        Update value of norm_clip.
+        Args:
+            empirical_fraction(Tensor): empirical fraction of samples with the
+                value at most `target_unclipped_quantile`.
+            norm_clip(Tensor): Clipping bound for the l2 norm of the gradients.
+        Returns:
+            Tensor, updated norm clip.
+        """
+        fraction_noise = self._normal((1,), self._zero, self._fraction_stddev)
+        empirical_fraction = self._add(empirical_fraction, fraction_noise)
+        if self._decay_policy == 'Linear':
+            grad_clip = self._sub(empirical_fraction,
+                                  self._target_unclipped_quantile)
+            next_norm_clip = self._sub(norm_clip,
+                                       self._mul(self._learning_rate, grad_clip))
+        # decay_policy == 'Geometric'
+        else:
+            grad_clip = self._sub(empirical_fraction,
+                                  self._target_unclipped_quantile)
+            grad_clip = self._exp(self._mul(-self._learning_rate, grad_clip))
+            next_norm_clip = self._mul(norm_clip, grad_clip)
+        return next_norm_clip
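Both decay policies reduce to a one-line update. A minimal numpy sketch of the same rule outside of MindSpore, names mirror the arguments of AdaClippingWithGaussianRandom above; it is an illustration, not the class itself:

    import numpy as np

    def ada_clip_update(norm_clip, empirical_fraction, decay_policy='Linear',
                        learning_rate=0.001, target_unclipped_quantile=0.9,
                        fraction_stddev=0.01):
        """Sketch of one adaptive-clipping step: noise the empirical fraction,
        then move norm_clip toward the target unclipped quantile."""
        empirical_fraction += np.random.normal(0.0, fraction_stddev)
        if decay_policy == 'Linear':
            return norm_clip - learning_rate * (empirical_fraction - target_unclipped_quantile)
        # 'Geometric'
        return norm_clip * np.exp(-learning_rate * (empirical_fraction - target_unclipped_quantile))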
@@ -22,7 +22,7 @@ from mindspore.ops import functional as F
 from mindspore.common import dtype as mstype
 from mindarmour.utils.logger import LogUtil
-from mindarmour.diff_privacy import MechanismsFactory
+from mindarmour.diff_privacy import NoiseMechanismsFactory
 from mindarmour.diff_privacy.mechanisms.mechanisms import _MechanismsParamsUpdater
 from mindarmour.utils._check_param import check_int_positive
@@ -70,7 +70,7 @@ class DPOptimizerClassFactory:
     """
     def __init__(self, micro_batches=2):
-        self._mech_factory = MechanismsFactory()
+        self._mech_factory = NoiseMechanismsFactory()
         self.mech = None
         self._micro_batches = check_int_positive('micro_batches', micro_batches)
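DPOptimizerClassFactory is the alternative to passing noise_mech into DPModel: the factory wraps a standard optimizer so the noise mechanism runs inside the optimizer itself. A short usage sketch following the pattern in the DPModel docstring further down this diff; `network` stands for any Cell being trained and the values are illustrative:

    from mindarmour.diff_privacy import DPOptimizerClassFactory

    factory_opt = DPOptimizerClassFactory(micro_batches=2)
    factory_opt.set_mechanisms('Gaussian',
                               norm_bound=1.0,
                               initial_noise_multiplier=0.5)
    # Wrap Momentum; the returned class is used like the ordinary optimizer.
    net_opt = factory_opt.create('Momentum')(network.trainable_params(),
                                             learning_rate=0.1, momentum=0.9)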
@@ -48,7 +48,8 @@ from mindspore.nn import Cell
 from mindspore import ParameterTuple
 from mindarmour.utils.logger import LogUtil
-from mindarmour.diff_privacy.mechanisms.mechanisms import _MechanismsParamsUpdater
+from mindarmour.diff_privacy.mechanisms.mechanisms import \
+    _MechanismsParamsUpdater
 from mindarmour.utils._check_param import check_param_type
 from mindarmour.utils._check_param import check_value_positive
 from mindarmour.utils._check_param import check_int_positive
@@ -64,7 +65,7 @@ _reciprocal = P.Reciprocal()
 @_grad_scale.register("Tensor", "Tensor")
 def tensor_grad_scale(scale, grad):
     """ grad scaling """
-    return grad * F.cast(_reciprocal(scale), F.dtype(grad))
+    return grad*F.cast(_reciprocal(scale), F.dtype(grad))
 class DPModel(Model):
@@ -72,9 +73,14 @@ class DPModel(Model):
     This class is overload mindspore.train.model.Model.
     Args:
-        micro_batches (int): The number of small batches split from an original batch. Default: 2.
-        norm_clip (float): Use to clip the bound, if set 1, will retun the original data. Default: 1.0.
-        mech (Mechanisms): The object can generate the different type of noise. Default: None.
+        micro_batches (int): The number of small batches split from an original
+            batch. Default: 2.
+        norm_clip (float): Use to clip the bound, if set 1, will return the
+            original data. Default: 1.0.
+        noise_mech (Mechanisms): The object can generate the different type of
+            noise. Default: None.
+        clip_mech (Mechanisms): The object is used to update the adaptive clip.
+            Default: None.
     Examples:
         >>> norm_clip = 1.0
@@ -89,63 +95,82 @@ class DPModel(Model):
         >>> factory_opt.set_mechanisms('Gaussian',
         >>>                            norm_bound=norm_clip,
         >>>                            initial_noise_multiplier=initial_noise_multiplier)
-        >>> net_opt = factory_opt.create('Momentum')(network.trainable_params(), learning_rate=0.1, momentum=0.9)
+        >>> net_opt = factory_opt.create('Momentum')(network.trainable_params(),
+        >>>                                          learning_rate=0.1, momentum=0.9)
+        >>> clip_mech = ClipMechanismsFactory().create('Gaussian',
+        >>>                                            decay_policy='Linear',
+        >>>                                            learning_rate=0.01,
+        >>>                                            target_unclipped_quantile=0.9,
+        >>>                                            fraction_stddev=0.01)
         >>> model = DPModel(micro_batches=micro_batches,
         >>>                 norm_clip=norm_clip,
-        >>>                 mech=None,
+        >>>                 clip_mech=clip_mech,
+        >>>                 noise_mech=None,
         >>>                 network=network,
         >>>                 loss_fn=loss,
         >>>                 optimizer=net_opt,
         >>>                 metrics=None)
-        >>> ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches), ['data', 'label'])
-        >>> ms_ds.set_dataset_size(batch_size * batches)
+        >>> ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches),
+        >>>                             ['data', 'label'])
+        >>> ms_ds.set_dataset_size(batch_size*batches)
         >>> model.train(epochs, ms_ds, dataset_sink_mode=False)
     """
-    def __init__(self, micro_batches=2, norm_clip=1.0, mech=None, **kwargs):
+    def __init__(self, micro_batches=2, norm_clip=1.0, noise_mech=None,
+                 clip_mech=None, **kwargs):
         if micro_batches:
-            self._micro_batches = check_int_positive('micro_batches', micro_batches)
+            self._micro_batches = check_int_positive('micro_batches',
+                                                     micro_batches)
         else:
             self._micro_batches = None
         norm_clip = check_param_type('norm_clip', norm_clip, float)
-        self._norm_clip = check_value_positive('norm_clip', norm_clip)
-        if mech is not None and "DPOptimizer" in kwargs['optimizer'].__class__.__name__:
-            msg = 'DPOptimizer is not supported while mech is not None'
+        norm_clip = check_value_positive('norm_clip', norm_clip)
+        norm_clip = Tensor(norm_clip, mstype.float32)
+        self._norm_clip = Parameter(norm_clip, 'norm_clip')
+        if noise_mech is not None and "DPOptimizer" in kwargs['optimizer'].__class__.__name__:
+            msg = 'DPOptimizer is not supported while noise_mech is not None'
             LOGGER.error(TAG, msg)
             raise ValueError(msg)
-        if mech is None:
+        if noise_mech is None:
             if "DPOptimizer" in kwargs['optimizer'].__class__.__name__:
                 if context.get_context('mode') != context.PYNATIVE_MODE:
                     msg = 'DPOptimizer just support pynative mode currently.'
                     LOGGER.error(TAG, msg)
                     raise ValueError(msg)
             else:
-                msg = 'DPModel should set mech or DPOptimizer configure, please refer to example.'
+                msg = 'DPModel should set noise_mech or DPOptimizer configure, ' \
+                      'please refer to example.'
                 LOGGER.error(TAG, msg)
                 raise ValueError(msg)
-        self._mech = mech
+        self._noise_mech = noise_mech
+        if clip_mech is not None:
+            self._clip_mech = clip_mech
         super(DPModel, self).__init__(**kwargs)
-    def _amp_build_train_network(self, network, optimizer, loss_fn=None, level='O0', **kwargs):
+    def _amp_build_train_network(self, network, optimizer, loss_fn=None,
+                                 level='O0', **kwargs):
         """
         Build the mixed precision training cell automatically.
         Args:
             network (Cell): Definition of the network.
-            loss_fn (Union[None, Cell]): Definition of the loss_fn. If None, the `network` should have the loss inside.
-                Default: None.
+            loss_fn (Union[None, Cell]): Definition of the loss_fn. If None,
+                the `network` should have the loss inside. Default: None.
             optimizer (Optimizer): Optimizer to update the Parameter.
             level (str): Supports [O0, O2]. Default: "O0".
                 - O0: Do not change.
-                - O2: Cast network to float16, keep batchnorm and `loss_fn` (if set) run in float32,
-                    using dynamic loss scale.
-            cast_model_type (:class:`mindspore.dtype`): Supports `mstype.float16` or `mstype.float32`.
-                If set to `mstype.float16`, use `float16` mode to train. If set, overwrite the level setting.
-            keep_batchnorm_fp32 (bool): Keep Batchnorm run in `float32`. If set, overwrite the level setting.
-            loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else
-                scale the loss by LossScaleManager. If set, overwrite the level setting.
+                - O2: Cast network to float16, keep batchnorm and `loss_fn`
+                    (if set) run in float32, using dynamic loss scale.
+            cast_model_type (:class:`mindspore.dtype`): Supports `mstype.float16`
+                or `mstype.float32`. If set to `mstype.float16`, use `float16`
+                mode to train. If set, overwrite the level setting.
+            keep_batchnorm_fp32 (bool): Keep Batchnorm run in `float32`. If set,
+                overwrite the level setting.
+            loss_scale_manager (Union[None, LossScaleManager]): If None, not
+                scale the loss, or else scale the loss by LossScaleManager.
+                If set, overwrite the level setting.
         """
         validator.check_value_type('network', network, nn.Cell, None)
         validator.check_value_type('optimizer', optimizer, nn.Optimizer, None)
@@ -161,9 +186,11 @@ class DPModel(Model):
             _do_keep_batchnorm_fp32(network)
         if loss_fn:
-            network = _add_loss_network(network, loss_fn, config.cast_model_type)
+            network = _add_loss_network(network, loss_fn,
+                                        config.cast_model_type)
-        if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
+        if _get_parallel_mode() in (
+                ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
             network = _VirtualDatasetCell(network)
         loss_scale = 1.0
@@ -173,9 +200,12 @@ class DPModel(Model):
             update_cell = loss_scale_manager.get_update_cell()
             if update_cell is not None:
                 # only cpu not support `TrainOneStepWithLossScaleCell` for control flow.
-                if not context.get_context("enable_ge") and context.get_context("device_target") == "CPU":
-                    msg = "Only `loss_scale_manager=None` and `loss_scale_manager=FixedLossScaleManager(drop_overflow" \
-                          "_update=False)` are supported in current version. If you use `O2` option, please use " \
+                if not context.get_context("enable_ge") and context.get_context(
+                        "device_target") == "CPU":
+                    msg = "Only `loss_scale_manager=None` and " \
+                          "`loss_scale_manager=FixedLossScaleManager(drop_overflow" \
+                          "_update=False)` are supported in current version. " \
+                          "If you use `O2` option, please use " \
                           "`loss_scale_manager=None` or `FixedLossScaleManager`"
                     LOGGER.error(TAG, msg)
                     raise ValueError(msg)
@@ -184,15 +214,17 @@ class DPModel(Model):
                                                       scale_update_cell=update_cell,
                                                       micro_batches=self._micro_batches,
                                                       norm_clip=self._norm_clip,
-                                                      mech=self._mech).set_train()
+                                                      clip_mech=self._clip_mech,
+                                                      noise_mech=self._noise_mech).set_train()
                 return network
         network = _TrainOneStepCell(network,
                                     optimizer,
+                                    self._norm_clip,
                                     loss_scale,
                                     micro_batches=self._micro_batches,
-                                    norm_clip=self._norm_clip,
-                                    mech=self._mech).set_train()
+                                    clip_mech=self._clip_mech,
+                                    noise_mech=self._noise_mech).set_train()
         return network
     def _build_train_network(self):
@@ -233,7 +265,8 @@ class DPModel(Model):
         elif self._loss_fn:
             network = nn.WithLossCell(network, self._loss_fn)
-        if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
+        if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL,
+                                   ParallelMode.AUTO_PARALLEL):
             network.set_auto_parallel()
         return network
@@ -267,11 +300,10 @@ class _ClipGradients(nn.Cell):
         new_grads = ()
         for grad in grads:
             if clip_type == 0:
-                t = C.clip_by_value(grad, F.tuple_to_array((-clip_value,)),
-                                    F.tuple_to_array((clip_value,)))
+                norm = C.clip_by_value(grad, -clip_value, clip_value)
             else:
-                t = self.clip_by_norm(grad, F.tuple_to_array((clip_value,)))
-            new_grads = new_grads + (t,)
+                norm = self.clip_by_norm(grad, clip_value)
+            new_grads = new_grads + (norm,)
         return new_grads
@@ -292,20 +324,27 @@ class _TrainOneStepWithLossScaleCell(Cell):
     r"""
     Network training with loss scaling.
-    This is a training step with loss scaling. It takes a network, an optimizer and possibly a scale update
-    Cell as args. The loss scale value can be updated in both host side or device side. The
-    TrainOneStepWithLossScaleCell will be compiled to be graph which takes `data`, `label`, `sens` as input
-    data. The `sens` is acting as loss scaling value. If you want to update it on host side, the value should
-    be provided. If `sens` is not given, the loss scale update logic should be provied by `scale_update_cell`.
-    If `scale_update_cell` is not None and `sens` is provided, the `scale_update_cell` will be ignored.
+    This is a training step with loss scaling. It takes a network, an optimizer
+    and possibly a scale update Cell as args. The loss scale value can be
+    updated in both host side or device side. The TrainOneStepWithLossScaleCell
+    will be compiled to be graph which takes `data`, `label`, `sens` as input
+    data. The `sens` is acting as loss scaling value. If you want to update it
+    on host side, the value should be provided. If `sens` is not given, the loss
+    scale update logic should be provided by `scale_update_cell`. If
+    `scale_update_cell` is not None and `sens` is provided, the
+    `scale_update_cell` will be ignored.
     Args:
         network (Cell): The training network.
         optimizer (Cell): Optimizer for updating the weights.
-        scale_update_cell(Cell): The loss scaling update logic cell. Default: None.
-        micro_batches (int): The number of small batches split from an original batch. Default: None.
-        norm_clip (float): Use to clip the bound, if set 1, will return the original data. Default: 1.0.
-        mech (Mechanisms): The object can generate the different type of noise. Default: None.
+        scale_update_cell(Cell): The loss scaling update logic cell.
+            Default: None.
+        micro_batches (int): The number of small batches split from an original
+            batch. Default: None.
+        norm_clip (Tensor): Use to clip the bound, if set 1, will return the
+            original data. Default: 1.0.
+        noise_mech (Mechanisms): The object can generate the different type of
+            noise. Default: None.
     Inputs:
         - **inputs** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
@@ -320,7 +359,9 @@ class _TrainOneStepWithLossScaleCell(Cell):
         - **loss_scale** (Tensor) - Tensor with shape :math:`()`.
     """
-    def __init__(self, network, optimizer, scale_update_cell=None, micro_batches=None, norm_clip=1.0, mech=None):
+    def __init__(self, network, optimizer, scale_update_cell=None,
+                 micro_batches=None, norm_clip=1.0, noise_mech=None,
+                 clip_mech=None):
         super(_TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
         self.network = network
         self.network.set_grad()
@@ -346,39 +387,54 @@ class _TrainOneStepWithLossScaleCell(Cell):
             self.allreduce = P.AllReduce()
         self.parallel_mode = _get_parallel_mode()
         self.grad_reducer = F.identity
-        self.reducer_flag = self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]
+        self.reducer_flag = self.parallel_mode in [ParallelMode.DATA_PARALLEL,
+                                                   ParallelMode.HYBRID_PARALLEL]
         if self.reducer_flag:
             mean = _get_mirror_mean()
             degree = _get_device_num()
-            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
+            self.grad_reducer = DistributedGradReducer(optimizer.parameters,
+                                                       mean, degree)
         self.is_distributed = self.parallel_mode != ParallelMode.STAND_ALONE
         self.loss_scale = None
         self.loss_scaling_manager = scale_update_cell
         if scale_update_cell:
-            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
-                                        name="loss_scale")
+            self.loss_scale = Parameter(
+                Tensor(scale_update_cell.get_loss_scale(),
+                       dtype=mstype.float32),
+                name="loss_scale")
         self.add_flags(has_effect=True)
         # dp params
         self._micro_batches = micro_batches
-        norm_clip = check_param_type('norm_clip', norm_clip, float)
-        self._l2_norm = check_value_positive('norm_clip', norm_clip)
+        self._norm_clip = norm_clip
         self._split = P.Split(0, self._micro_batches)
         self._clip_by_global_norm = _ClipGradients()
-        self._mech = mech
+        self._noise_mech = noise_mech
+        self._clip_mech = clip_mech
+        self._add = P.TensorAdd()
+        self._norm = nn.Norm()
         self._tuple_add = _TupleAdd()
         self._hyper_map = C.HyperMap()
         self._micro_float = Tensor(micro_batches, mstype.float32)
-        self._mech_param_updater = None
-        if self._mech is not None and self._mech._decay_policy is not None:
-            self._mech_param_updater = _MechanismsParamsUpdater(policy=self._mech._decay_policy,
-                                                                decay_rate=self._mech._noise_decay_rate,
-                                                                cur_noise_multiplier=
-                                                                self._mech._noise_multiplier,
-                                                                init_noise_multiplier=
-                                                                self._mech._initial_noise_multiplier)
+        self._zero = Tensor(0, mstype.float32)
+        self._assign = P.Assign()
+        self._div = P.Div()
+        self._sqrt = P.Sqrt()
+        self._reduce_sum = P.ReduceSum()
+        self._square_all = P.Square()
+        self._less = P.Less()
+        self._cast = P.Cast()
+        self._noise_mech_param_updater = None
+        if self._noise_mech is not None and self._noise_mech._decay_policy is not None:
+            self._noise_mech_param_updater = _MechanismsParamsUpdater(
+                policy=self._noise_mech._decay_policy,
+                decay_rate=self._noise_mech._noise_decay_rate,
+                cur_noise_multiplier=
+                self._noise_mech._noise_multiplier,
+                init_noise_multiplier=
+                self._noise_mech._initial_noise_multiplier)
     def construct(self, data, label, sens=None):
         """
@@ -402,30 +458,62 @@ class _TrainOneStepWithLossScaleCell(Cell):
         record_labels = self._split(label)
         # first index
         loss = self.network(record_datas[0], record_labels[0])
-        scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
-        record_grad = self.grad(self.network, weights)(record_datas[0], record_labels[0], scaling_sens_filled)
-        record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm)
+        scaling_sens_filled = C.ones_like(loss)*F.cast(scaling_sens,
+                                                       F.dtype(loss))
+        record_grad = self.grad(self.network, weights)(record_datas[0],
+                                                       record_labels[0],
+                                                       scaling_sens_filled)
+        beta = self._zero
+        square_sum = self._zero
+        for grad in record_grad:
+            square_sum = self._add(square_sum,
+                                   self._reduce_sum(self._square_all(grad)))
+        norm_grad = self._sqrt(square_sum)
+        beta = self._add(beta,
+                         self._cast(self._less(norm_grad, self._norm_clip),
+                                    mstype.float32))
+        record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE,
+                                                self._norm_clip)
         grads = record_grad
         total_loss = loss
         for i in range(1, self._micro_batches):
             loss = self.network(record_datas[i], record_labels[i])
-            scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
-            record_grad = self.grad(self.network, weights)(record_datas[i], record_labels[i], scaling_sens_filled)
-            record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm)
+            scaling_sens_filled = C.ones_like(loss)*F.cast(scaling_sens,
+                                                           F.dtype(loss))
+            record_grad = self.grad(self.network, weights)(record_datas[i],
+                                                           record_labels[i],
+                                                           scaling_sens_filled)
+            square_sum = self._zero
+            for grad in record_grad:
+                square_sum = self._add(square_sum,
+                                       self._reduce_sum(self._square_all(grad)))
+            norm_grad = self._sqrt(square_sum)
+            beta = self._add(beta,
+                             self._cast(self._less(norm_grad, self._norm_clip),
+                                        mstype.float32))
+            record_grad = self._clip_by_global_norm(record_grad,
+                                                    GRADIENT_CLIP_TYPE,
+                                                    self._norm_clip)
             grads = self._tuple_add(grads, record_grad)
             total_loss = P.TensorAdd()(total_loss, loss)
         loss = P.Div()(total_loss, self._micro_float)
-        if self._mech is not None:
+        beta = self._div(beta, self._micro_batches)
+        if self._noise_mech is not None:
             grad_noise_tuple = ()
             for grad_item in grads:
                 grad_noise = self._mech(grad_item)
                 grad_noise_tuple = grad_noise_tuple + (grad_noise,)
             grads = self._tuple_add(grads, grad_noise_tuple)
-            grads = self._hyper_map(F.partial(_grad_scale, self._micro_float), grads)
+            grads = self._hyper_map(F.partial(_grad_scale, self._micro_float),
+                                    grads)
             # update mech parameters
-            if self._mech_param_updater is not None:
-                multiplier = self._mech_param_updater()
+            if self._noise_mech_param_updater is not None:
+                multiplier = self._noise_mech_param_updater()
                 loss = F.depend(loss, multiplier)
         grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads)
@@ -456,6 +544,10 @@ class _TrainOneStepWithLossScaleCell(Cell):
         else:
             opt = self.optimizer(grads)
         ret = (loss, cond, scaling_sens)
+        if self._clip_mech is not None:
+            next_norm_clip = self._clip_mech(beta, self._norm_clip)
+            P.assign(self._norm_clip, next_norm_clip)
         return F.depend(ret, opt)
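The `beta` bookkeeping added in this construct is the empirical fraction that drives the adaptive clip: for each micro-batch it checks whether the global gradient norm stayed below `norm_clip`, averages those indicator values, and passes the result to the clip mechanism, which writes the next bound back into the `norm_clip` Parameter. A compact numpy sketch of the same bookkeeping, illustrative only, where `clip_mech` stands for an AdaClippingWithGaussianRandom-style callable:

    import numpy as np

    def next_clip_bound(micro_batch_grads, norm_clip, clip_mech):
        """Fraction of micro-batch gradients whose global L2 norm is below
        norm_clip, followed by one adaptive-clipping update of norm_clip."""
        below = [float(np.sqrt(sum(np.sum(g**2) for g in grads_of_one_batch)) < norm_clip)
                 for grads_of_one_batch in micro_batch_grads]
        beta = sum(below) / len(micro_batch_grads)   # empirical fraction
        return clip_mech(beta, norm_clip)            # e.g. the Linear/Geometric update above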
@@ -463,17 +555,22 @@ class _TrainOneStepCell(Cell):
     r"""
     Network training package class.
-    Wraps the network with an optimizer. The resulting Cell be trained with input data and label.
-    Backward graph will be created in the construct function to do parameter updating. Different
-    parallel modes are available to run the training.
+    Wraps the network with an optimizer. The resulting Cell be trained with
+    input data and label. Backward graph will be created in the construct
+    function to do parameter updating. Different parallel modes are available
+    to run the training.
     Args:
         network (Cell): The training network.
         optimizer (Cell): Optimizer for updating the weights.
-        sens (Number): The scaling number to be filled as the input of back propagation. Default value is 1.0.
-        micro_batches (int): The number of small batches split from an original batch. Default: None.
-        norm_clip (float): Use to clip the bound, if set 1, will return the original data. Default: 1.0.
-        mech (Mechanisms): The object can generate the different type of noise. Default: None.
+        sens (Number): The scaling number to be filled as the input of back
+            propagation. Default value is 1.0.
+        micro_batches (int): The number of small batches split from an original
+            batch. Default: None.
+        norm_clip (Tensor): Use to clip the bound, if set 1, will return the
+            original data. Default: 1.0.
+        noise_mech (Mechanisms): The object can generate the different type
+            of noise. Default: None.
     Inputs:
         - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
@@ -483,7 +580,9 @@ class _TrainOneStepCell(Cell):
         Tensor, a scalar Tensor with shape :math:`()`.
     """
-    def __init__(self, network, optimizer, sens=1.0, micro_batches=None, norm_clip=1.0, mech=None):
+    def __init__(self, network, optimizer, norm_clip=1.0, sens=1.0,
+                 micro_batches=None,
+                 noise_mech=None, clip_mech=None):
         super(_TrainOneStepCell, self).__init__(auto_prefix=False)
         self.network = network
         self.network.set_grad()
@@ -495,36 +594,51 @@ class _TrainOneStepCell(Cell): | |||||
self.reducer_flag = False | self.reducer_flag = False | ||||
self.grad_reducer = None | self.grad_reducer = None | ||||
parallel_mode = _get_parallel_mode() | parallel_mode = _get_parallel_mode() | ||||
if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL): | |||||
if parallel_mode in ( | |||||
ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL): | |||||
self.reducer_flag = True | self.reducer_flag = True | ||||
if self.reducer_flag: | if self.reducer_flag: | ||||
mean = _get_mirror_mean() | mean = _get_mirror_mean() | ||||
degree = _get_device_num() | degree = _get_device_num() | ||||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) | |||||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, | |||||
mean, degree) | |||||
# dp params | # dp params | ||||
if micro_batches is None: | if micro_batches is None: | ||||
msg = 'micro_batches must give in differential privacy, but got value: {}'.format(micro_batches) | |||||
msg = 'micro_batches must give in differential privacy, but got value: {}'.format( | |||||
micro_batches) | |||||
LOGGER.error(TAG, msg) | LOGGER.error(TAG, msg) | ||||
raise ValueError(msg) | raise ValueError(msg) | ||||
self._micro_batches = micro_batches | self._micro_batches = micro_batches | ||||
norm_clip = check_param_type('norm_clip', norm_clip, float) | |||||
self._l2_norm = check_value_positive('norm_clip', norm_clip) | |||||
self._norm_clip = norm_clip | |||||
self._split = P.Split(0, self._micro_batches) | self._split = P.Split(0, self._micro_batches) | ||||
self._clip_by_global_norm = _ClipGradients() | self._clip_by_global_norm = _ClipGradients() | ||||
self._mech = mech | |||||
self._noise_mech = noise_mech | |||||
self._clip_mech = clip_mech | |||||
self._tuple_add = _TupleAdd() | self._tuple_add = _TupleAdd() | ||||
self._add = P.TensorAdd() | |||||
self._norm = nn.Norm() | |||||
self._hyper_map = C.HyperMap() | self._hyper_map = C.HyperMap() | ||||
self._zero = Tensor(0, mstype.float32) | |||||
self._assign = P.Assign() | |||||
self._div = P.Div() | |||||
self._sqrt = P.Sqrt() | |||||
self._reduce_sum = P.ReduceSum() | |||||
self._square_all = P.Square() | |||||
self._less = P.Less() | |||||
self._cast = P.Cast() | |||||
self._micro_float = Tensor(micro_batches, mstype.float32) | self._micro_float = Tensor(micro_batches, mstype.float32) | ||||
self._mech_param_updater = None | |||||
if self._mech is not None and self._mech._decay_policy is not None: | |||||
self._mech_param_updater = _MechanismsParamsUpdater(policy=self._mech._decay_policy, | |||||
decay_rate=self._mech._noise_decay_rate, | |||||
cur_noise_multiplier= | |||||
self._mech._noise_multiplier, | |||||
init_noise_multiplier= | |||||
self._mech._initial_noise_multiplier) | |||||
self._noise_mech_param_updater = None | |||||
if self._noise_mech is not None and self._noise_mech._decay_policy is not None: | |||||
self._noise_mech_param_updater = _MechanismsParamsUpdater( | |||||
policy=self._noise_mech._decay_policy, | |||||
decay_rate=self._noise_mech._noise_decay_rate, | |||||
cur_noise_multiplier= | |||||
self._noise_mech._noise_multiplier, | |||||
init_noise_multiplier= | |||||
self._noise_mech._initial_noise_multiplier) | |||||
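The updater registered above refreshes the noise multiplier once per step according to the mechanism's decay policy and decay rate. The exact formulas live in _MechanismsParamsUpdater rather than in this hunk; the sketch below is only an assumed illustration of how 'Step' and 'Exp' style schedules are commonly computed, not the MindArmour source.

# Illustrative only: the actual per-step update is implemented in
# _MechanismsParamsUpdater and may differ in detail.
import math

def decayed_noise_multiplier(init_multiplier, decay_rate, step, policy='Exp'):
    # Assumed common forms of the 'Step' and 'Exp' decay schedules.
    if policy == 'Step':
        return init_multiplier * (1.0 - decay_rate) ** step
    return init_multiplier * math.exp(-decay_rate * step)  # 'Exp'

print(decayed_noise_multiplier(1.0, 1e-3, 100, policy='Exp'))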
def construct(self, data, label): | def construct(self, data, label): | ||||
""" | """ | ||||
@@ -535,32 +649,65 @@ class _TrainOneStepCell(Cell): | |||||
record_labels = self._split(label) | record_labels = self._split(label) | ||||
loss = self.network(record_datas[0], record_labels[0]) | loss = self.network(record_datas[0], record_labels[0]) | ||||
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) | sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) | ||||
record_grad = self.grad(self.network, weights)(record_datas[0], record_labels[0], sens) | |||||
record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm) | |||||
record_grad = self.grad(self.network, weights)(record_datas[0], | |||||
record_labels[0], sens) | |||||
beta = self._zero | |||||
square_sum = self._zero | |||||
for grad in record_grad: | |||||
square_sum = self._add(square_sum, | |||||
self._reduce_sum(self._square_all(grad))) | |||||
norm_grad = self._sqrt(square_sum) | |||||
beta = self._add(beta, | |||||
self._cast(self._less(norm_grad, self._norm_clip), | |||||
mstype.float32)) | |||||
record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, | |||||
self._norm_clip) | |||||
grads = record_grad | grads = record_grad | ||||
total_loss = loss | total_loss = loss | ||||
for i in range(1, self._micro_batches): | for i in range(1, self._micro_batches): | ||||
loss = self.network(record_datas[i], record_labels[i]) | loss = self.network(record_datas[i], record_labels[i]) | ||||
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) | sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) | ||||
record_grad = self.grad(self.network, weights)(record_datas[i], record_labels[i], sens) | |||||
record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm) | |||||
record_grad = self.grad(self.network, weights)(record_datas[i], | |||||
record_labels[i], | |||||
sens) | |||||
square_sum = self._zero | |||||
for grad in record_grad: | |||||
square_sum = self._add(square_sum, | |||||
self._reduce_sum(self._square_all(grad))) | |||||
norm_grad = self._sqrt(square_sum) | |||||
beta = self._add(beta, | |||||
self._cast(self._less(norm_grad, self._norm_clip), | |||||
mstype.float32)) | |||||
record_grad = self._clip_by_global_norm(record_grad, | |||||
GRADIENT_CLIP_TYPE, | |||||
self._norm_clip) | |||||
grads = self._tuple_add(grads, record_grad) | grads = self._tuple_add(grads, record_grad) | ||||
total_loss = P.TensorAdd()(total_loss, loss) | total_loss = P.TensorAdd()(total_loss, loss) | ||||
loss = P.Div()(total_loss, self._micro_float) | |||||
loss = self._div(total_loss, self._micro_float) | |||||
beta = self._div(beta, self._micro_batches) | |||||
if self._mech is not None: | |||||
if self._noise_mech is not None: | |||||
grad_noise_tuple = () | grad_noise_tuple = () | ||||
for grad_item in grads: | for grad_item in grads: | ||||
grad_noise = self._mech(grad_item) | |||||
grad_noise = self._noise_mech(grad_item) | |||||
grad_noise_tuple = grad_noise_tuple + (grad_noise,) | grad_noise_tuple = grad_noise_tuple + (grad_noise,) | ||||
grads = self._tuple_add(grads, grad_noise_tuple) | grads = self._tuple_add(grads, grad_noise_tuple) | ||||
grads = self._hyper_map(F.partial(_grad_scale, self._micro_float), grads) | |||||
grads = self._hyper_map(F.partial(_grad_scale, self._micro_float), | |||||
grads) | |||||
# update mech parameters | # update mech parameters | ||||
if self._mech_param_updater is not None: | |||||
multiplier = self._mech_param_updater() | |||||
if self._noise_mech_param_updater is not None: | |||||
multiplier = self._noise_mech_param_updater() | |||||
loss = F.depend(loss, multiplier) | loss = F.depend(loss, multiplier) | ||||
if self.reducer_flag: | if self.reducer_flag: | ||||
# apply grad reducer on grads | # apply grad reducer on grads | ||||
grads = self.grad_reducer(grads) | grads = self.grad_reducer(grads) | ||||
if self._clip_mech is not None: | |||||
next_norm_clip = self._clip_mech(beta, self._norm_clip) | |||||
self._norm_clip = self._assign(self._norm_clip, next_norm_clip) | |||||
loss = F.depend(loss, next_norm_clip) | |||||
return F.depend(loss, self.optimizer(grads)) | return F.depend(loss, self.optimizer(grads)) |
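Taken together, the construct changes above implement the per-micro-batch differentially private step: each micro-batch gradient is measured against the current clip bound (the fraction of unclipped micro-batches is accumulated into beta), clipped by global norm, summed, perturbed with noise from noise_mech, averaged, and finally the bound itself is refreshed by clip_mech. The plain-numpy sketch below restates that flow outside MindSpore; grad_fn and the noise scale (noise_multiplier * norm_clip) are illustrative assumptions, not MindArmour APIs.

# Plain numpy restatement of the DP step above (assumed helper names,
# not MindArmour's API).
import numpy as np

def dp_train_step(grad_fn, data, labels, micro_batches, norm_clip,
                  noise_multiplier, rng=np.random.default_rng(0)):
    data_parts = np.array_split(data, micro_batches)
    label_parts = np.array_split(labels, micro_batches)
    grads_sum = None
    beta = 0.0  # count of micro-batches whose gradient norm is under the bound
    for x, y in zip(data_parts, label_parts):
        grads = grad_fn(x, y)  # assumed: list of per-parameter gradient arrays
        norm = np.sqrt(sum(float((g ** 2).sum()) for g in grads))
        beta += float(norm < norm_clip)
        scale = min(1.0, norm_clip / (norm + 1e-12))  # clip by global norm
        clipped = [g * scale for g in grads]
        grads_sum = clipped if grads_sum is None else [
            a + b for a, b in zip(grads_sum, clipped)]
    # perturb the summed gradients, then average over micro-batches;
    # noise stddev of noise_multiplier * norm_clip is an assumption here
    noisy = [g + rng.normal(0.0, noise_multiplier * norm_clip, g.shape)
             for g in grads_sum]
    avg_grads = [g / micro_batches for g in noisy]
    beta /= micro_batches  # fraction of unclipped micro-batches, fed to clip_mech
    return avg_grads, beta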
@@ -19,9 +19,11 @@ import pytest | |||||
from mindspore import context | from mindspore import context | ||||
from mindspore import Tensor | from mindspore import Tensor | ||||
from mindspore.common import dtype as mstype | from mindspore.common import dtype as mstype | ||||
from mindarmour.diff_privacy import GaussianRandom | |||||
from mindarmour.diff_privacy import NoiseGaussianRandom | |||||
from mindarmour.diff_privacy import AdaGaussianRandom | from mindarmour.diff_privacy import AdaGaussianRandom | ||||
from mindarmour.diff_privacy import MechanismsFactory | |||||
from mindarmour.diff_privacy import AdaClippingWithGaussianRandom | |||||
from mindarmour.diff_privacy import NoiseMechanismsFactory | |||||
from mindarmour.diff_privacy import ClipMechanismsFactory | |||||
@pytest.mark.level0 | @pytest.mark.level0 | ||||
@@ -33,7 +35,7 @@ def test_graph_gaussian(): | |||||
grad = Tensor([0.3, 0.2, 0.4], mstype.float32) | grad = Tensor([0.3, 0.2, 0.4], mstype.float32) | ||||
norm_bound = 1.0 | norm_bound = 1.0 | ||||
initial_noise_multiplier = 0.1 | initial_noise_multiplier = 0.1 | ||||
net = GaussianRandom(norm_bound, initial_noise_multiplier) | |||||
net = NoiseGaussianRandom(norm_bound, initial_noise_multiplier) | |||||
res = net(grad) | res = net(grad) | ||||
print(res) | print(res) | ||||
@@ -47,7 +49,7 @@ def test_pynative_gaussian(): | |||||
grad = Tensor([0.3, 0.2, 0.4], mstype.float32) | grad = Tensor([0.3, 0.2, 0.4], mstype.float32) | ||||
norm_bound = 1.0 | norm_bound = 1.0 | ||||
initial_noise_multiplier = 0.1 | initial_noise_multiplier = 0.1 | ||||
net = GaussianRandom(norm_bound, initial_noise_multiplier) | |||||
net = NoiseGaussianRandom(norm_bound, initial_noise_multiplier) | |||||
res = net(grad) | res = net(grad) | ||||
print(res) | print(res) | ||||
@@ -80,13 +82,13 @@ def test_graph_factory(): | |||||
initial_noise_multiplier = 0.1 | initial_noise_multiplier = 0.1 | ||||
alpha = 0.5 | alpha = 0.5 | ||||
decay_policy = 'Step' | decay_policy = 'Step' | ||||
noise_mechanism = MechanismsFactory() | |||||
noise_mechanism = NoiseMechanismsFactory() | |||||
noise_construct = noise_mechanism.create('Gaussian', | noise_construct = noise_mechanism.create('Gaussian', | ||||
norm_bound, | norm_bound, | ||||
initial_noise_multiplier) | initial_noise_multiplier) | ||||
noise = noise_construct(grad) | noise = noise_construct(grad) | ||||
print('Gaussian noise: ', noise) | print('Gaussian noise: ', noise) | ||||
ada_mechanism = MechanismsFactory() | |||||
ada_mechanism = NoiseMechanismsFactory() | |||||
ada_noise_construct = ada_mechanism.create('AdaGaussian', | ada_noise_construct = ada_mechanism.create('AdaGaussian', | ||||
norm_bound, | norm_bound, | ||||
initial_noise_multiplier, | initial_noise_multiplier, | ||||
@@ -124,13 +126,13 @@ def test_pynative_factory(): | |||||
initial_noise_multiplier = 0.1 | initial_noise_multiplier = 0.1 | ||||
alpha = 0.5 | alpha = 0.5 | ||||
decay_policy = 'Step' | decay_policy = 'Step' | ||||
noise_mechanism = MechanismsFactory() | |||||
noise_mechanism = NoiseMechanismsFactory() | |||||
noise_construct = noise_mechanism.create('Gaussian', | noise_construct = noise_mechanism.create('Gaussian', | ||||
norm_bound, | norm_bound, | ||||
initial_noise_multiplier) | initial_noise_multiplier) | ||||
noise = noise_construct(grad) | noise = noise_construct(grad) | ||||
print('Gaussian noise: ', noise) | print('Gaussian noise: ', noise) | ||||
ada_mechanism = MechanismsFactory() | |||||
ada_mechanism = NoiseMechanismsFactory() | |||||
ada_noise_construct = ada_mechanism.create('AdaGaussian', | ada_noise_construct = ada_mechanism.create('AdaGaussian', | ||||
norm_bound, | norm_bound, | ||||
initial_noise_multiplier, | initial_noise_multiplier, | ||||
@@ -151,7 +153,7 @@ def test_pynative_exponential(): | |||||
initial_noise_multiplier = 0.1 | initial_noise_multiplier = 0.1 | ||||
alpha = 0.5 | alpha = 0.5 | ||||
decay_policy = 'Exp' | decay_policy = 'Exp' | ||||
ada_mechanism = MechanismsFactory() | |||||
ada_mechanism = NoiseMechanismsFactory() | |||||
ada_noise_construct = ada_mechanism.create('AdaGaussian', | ada_noise_construct = ada_mechanism.create('AdaGaussian', | ||||
norm_bound, | norm_bound, | ||||
initial_noise_multiplier, | initial_noise_multiplier, | ||||
@@ -172,7 +174,7 @@ def test_graph_exponential(): | |||||
initial_noise_multiplier = 0.1 | initial_noise_multiplier = 0.1 | ||||
alpha = 0.5 | alpha = 0.5 | ||||
decay_policy = 'Exp' | decay_policy = 'Exp' | ||||
ada_mechanism = MechanismsFactory() | |||||
ada_mechanism = NoiseMechanismsFactory() | |||||
ada_noise_construct = ada_mechanism.create('AdaGaussian', | ada_noise_construct = ada_mechanism.create('AdaGaussian', | ||||
norm_bound, | norm_bound, | ||||
initial_noise_multiplier, | initial_noise_multiplier, | ||||
@@ -180,3 +182,107 @@ def test_graph_exponential(): | |||||
decay_policy=decay_policy) | decay_policy=decay_policy) | ||||
ada_noise = ada_noise_construct(grad) | ada_noise = ada_noise_construct(grad) | ||||
print('ada noise: ', ada_noise) | print('ada noise: ', ada_noise) | ||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_x86_ascend_training | |||||
@pytest.mark.env_onecard | |||||
@pytest.mark.component_mindarmour | |||||
def test_ada_clip_gaussian_random_pynative(): | |||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||||
decay_policy = 'Linear' | |||||
beta = Tensor(0.5, mstype.float32) | |||||
norm_clip = Tensor(1.0, mstype.float32) | |||||
beta_stddev = 0.1 | |||||
learning_rate = 0.1 | |||||
target_unclipped_quantile = 0.3 | |||||
ada_clip = AdaClippingWithGaussianRandom(decay_policy=decay_policy, | |||||
learning_rate=learning_rate, | |||||
target_unclipped_quantile=target_unclipped_quantile, | |||||
fraction_stddev=beta_stddev, | |||||
seed=1) | |||||
next_norm_clip = ada_clip(beta, norm_clip) | |||||
print('Linear next norm clip:', next_norm_clip) | |||||
decay_policy = 'Geometric' | |||||
ada_clip = AdaClippingWithGaussianRandom(decay_policy=decay_policy, | |||||
learning_rate=learning_rate, | |||||
target_unclipped_quantile=target_unclipped_quantile, | |||||
fraction_stddev=beta_stddev, | |||||
seed=1) | |||||
next_norm_clip = ada_clip(beta, norm_clip) | |||||
print('Geometric next norm clip:', next_norm_clip) | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_x86_ascend_training | |||||
@pytest.mark.env_onecard | |||||
@pytest.mark.component_mindarmour | |||||
def test_ada_clip_gaussian_random_graph(): | |||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
decay_policy = 'Linear' | |||||
beta = Tensor(0.5, mstype.float32) | |||||
norm_clip = Tensor(1.0, mstype.float32) | |||||
beta_stddev = 0.1 | |||||
learning_rate = 0.1 | |||||
target_unclipped_quantile = 0.3 | |||||
ada_clip = AdaClippingWithGaussianRandom(decay_policy=decay_policy, | |||||
learning_rate=learning_rate, | |||||
target_unclipped_quantile=target_unclipped_quantile, | |||||
fraction_stddev=beta_stddev, | |||||
seed=1) | |||||
next_norm_clip = ada_clip(beta, norm_clip) | |||||
print('Linear next norm clip:', next_norm_clip) | |||||
decay_policy = 'Geometric' | |||||
ada_clip = AdaClippingWithGaussianRandom(decay_policy=decay_policy, | |||||
learning_rate=learning_rate, | |||||
target_unclipped_quantile=target_unclipped_quantile, | |||||
fraction_stddev=beta_stddev, | |||||
seed=1) | |||||
next_norm_clip = ada_clip(beta, norm_clip) | |||||
print('Geometric next norm clip:', next_norm_clip) | |||||
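The two tests above exercise the 'Linear' and 'Geometric' decay policies of AdaClippingWithGaussianRandom. A rough sketch of the update such adaptive clip-bound mechanisms typically compute is given below; the formulas are an assumption based on the standard adaptive-clipping approach (move the bound so that roughly target_unclipped_quantile of gradients stay unclipped), not copied from the MindArmour source.

# Assumed form of the adaptive clip-bound update; not MindArmour source.
import numpy as np

def update_norm_clip(beta, norm_clip, decay_policy='Linear',
                     learning_rate=0.1, target_unclipped_quantile=0.3,
                     fraction_stddev=0.1, rng=np.random.default_rng(1)):
    # beta: averaged fraction of micro-batch gradients whose norm stayed
    # under the current bound; a noisy version keeps the update private.
    noisy_beta = beta + rng.normal(0.0, fraction_stddev)
    if decay_policy == 'Linear':
        return norm_clip - learning_rate * (noisy_beta - target_unclipped_quantile)
    # 'Geometric': multiplicative update of the bound
    return norm_clip * np.exp(-learning_rate * (noisy_beta - target_unclipped_quantile))

print('Linear   :', update_norm_clip(0.5, 1.0, 'Linear'))
print('Geometric:', update_norm_clip(0.5, 1.0, 'Geometric'))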
@pytest.mark.level0 | |||||
@pytest.mark.platform_x86_ascend_training | |||||
@pytest.mark.env_onecard | |||||
@pytest.mark.component_mindarmour | |||||
def test_pynative_clip_mech_factory(): | |||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||||
decay_policy = 'Linear' | |||||
beta = Tensor(0.5, mstype.float32) | |||||
norm_clip = Tensor(1.0, mstype.float32) | |||||
beta_stddev = 0.1 | |||||
learning_rate = 0.1 | |||||
target_unclipped_quantile = 0.3 | |||||
clip_mechanism = ClipMechanismsFactory() | |||||
ada_clip = clip_mechanism.create('Gaussian', | |||||
decay_policy=decay_policy, | |||||
learning_rate=learning_rate, | |||||
target_unclipped_quantile=target_unclipped_quantile, | |||||
fraction_stddev=beta_stddev) | |||||
next_norm_clip = ada_clip(beta, norm_clip) | |||||
print('next_norm_clip: ', next_norm_clip) | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_x86_ascend_training | |||||
@pytest.mark.env_onecard | |||||
@pytest.mark.component_mindarmour | |||||
def test_graph_clip_mech_factory(): | |||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
decay_policy = 'Linear' | |||||
beta = Tensor(0.5, mstype.float32) | |||||
norm_clip = Tensor(1.0, mstype.float32) | |||||
beta_stddev = 0.1 | |||||
learning_rate = 0.1 | |||||
target_unclipped_quantile = 0.3 | |||||
clip_mechanism = ClipMechanismsFactory() | |||||
ada_clip = clip_mechanism.create('Gaussian', | |||||
decay_policy=decay_policy, | |||||
learning_rate=learning_rate, | |||||
target_unclipped_quantile=target_unclipped_quantile, | |||||
fraction_stddev=beta_stddev) | |||||
next_norm_clip = ada_clip(beta, norm_clip) | |||||
print('next_norm_clip: ', next_norm_clip) |
@@ -22,7 +22,8 @@ from mindspore import context | |||||
import mindspore.dataset as ds | import mindspore.dataset as ds | ||||
from mindarmour.diff_privacy import DPModel | from mindarmour.diff_privacy import DPModel | ||||
from mindarmour.diff_privacy import MechanismsFactory | |||||
from mindarmour.diff_privacy import NoiseMechanismsFactory | |||||
from mindarmour.diff_privacy import ClipMechanismsFactory | |||||
from mindarmour.diff_privacy import DPOptimizerClassFactory | from mindarmour.diff_privacy import DPOptimizerClassFactory | ||||
from test_network import LeNet5 | from test_network import LeNet5 | ||||
@@ -30,10 +31,12 @@ from test_network import LeNet5 | |||||
def dataset_generator(batch_size, batches): | def dataset_generator(batch_size, batches): | ||||
"""mock training data.""" | """mock training data.""" | ||||
data = np.random.random((batches * batch_size, 1, 32, 32)).astype(np.float32) | |||||
label = np.random.randint(0, 10, batches * batch_size).astype(np.int32) | |||||
data = np.random.random((batches*batch_size, 1, 32, 32)).astype( | |||||
np.float32) | |||||
label = np.random.randint(0, 10, batches*batch_size).astype(np.int32) | |||||
for i in range(batches): | for i in range(batches): | ||||
yield data[i * batch_size:(i + 1) * batch_size], label[i * batch_size:(i + 1) * batch_size] | |||||
yield data[i*batch_size:(i + 1)*batch_size],\ | |||||
label[i*batch_size:(i + 1)*batch_size] | |||||
@pytest.mark.level0 | @pytest.mark.level0 | ||||
@@ -55,16 +58,24 @@ def test_dp_model_with_pynative_mode(): | |||||
factory_opt.set_mechanisms('Gaussian', | factory_opt.set_mechanisms('Gaussian', | ||||
norm_bound=norm_clip, | norm_bound=norm_clip, | ||||
initial_noise_multiplier=initial_noise_multiplier) | initial_noise_multiplier=initial_noise_multiplier) | ||||
net_opt = factory_opt.create('Momentum')(network.trainable_params(), learning_rate=0.1, momentum=0.9) | |||||
net_opt = factory_opt.create('Momentum')(network.trainable_params(), | |||||
learning_rate=0.1, momentum=0.9) | |||||
clip_mech = ClipMechanismsFactory().create('Gaussian', | |||||
decay_policy='Linear', | |||||
learning_rate=0.01, | |||||
target_unclipped_quantile=0.9, | |||||
fraction_stddev=0.01) | |||||
model = DPModel(micro_batches=micro_batches, | model = DPModel(micro_batches=micro_batches, | ||||
norm_clip=norm_clip, | norm_clip=norm_clip, | ||||
mech=None, | |||||
clip_mech=clip_mech, | |||||
noise_mech=None, | |||||
network=network, | network=network, | ||||
loss_fn=loss, | loss_fn=loss, | ||||
optimizer=net_opt, | optimizer=net_opt, | ||||
metrics=None) | metrics=None) | ||||
ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches), ['data', 'label']) | |||||
ms_ds.set_dataset_size(batch_size * batches) | |||||
ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches), | |||||
['data', 'label']) | |||||
ms_ds.set_dataset_size(batch_size*batches) | |||||
model.train(epochs, ms_ds, dataset_sink_mode=False) | model.train(epochs, ms_ds, dataset_sink_mode=False) | ||||
@@ -82,19 +93,27 @@ def test_dp_model_with_graph_mode(): | |||||
batches = 128 | batches = 128 | ||||
epochs = 1 | epochs = 1 | ||||
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) | loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) | ||||
mech = MechanismsFactory().create('Gaussian', | |||||
norm_bound=norm_clip, | |||||
initial_noise_multiplier=initial_noise_multiplier) | |||||
net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9) | |||||
noise_mech = NoiseMechanismsFactory().create('Gaussian', | |||||
norm_bound=norm_clip, | |||||
initial_noise_multiplier=initial_noise_multiplier) | |||||
clip_mech = ClipMechanismsFactory().create('Gaussian', | |||||
decay_policy='Linear', | |||||
learning_rate=0.01, | |||||
target_unclipped_quantile=0.9, | |||||
fraction_stddev=0.01) | |||||
net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.1, | |||||
momentum=0.9) | |||||
model = DPModel(micro_batches=2, | model = DPModel(micro_batches=2, | ||||
clip_mech=clip_mech, | |||||
norm_clip=norm_clip, | norm_clip=norm_clip, | ||||
mech=mech, | |||||
noise_mech=noise_mech, | |||||
network=network, | network=network, | ||||
loss_fn=loss, | loss_fn=loss, | ||||
optimizer=net_opt, | optimizer=net_opt, | ||||
metrics=None) | metrics=None) | ||||
ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches), ['data', 'label']) | |||||
ms_ds.set_dataset_size(batch_size * batches) | |||||
ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches), | |||||
['data', 'label']) | |||||
ms_ds.set_dataset_size(batch_size*batches) | |||||
model.train(epochs, ms_ds, dataset_sink_mode=False) | model.train(epochs, ms_ds, dataset_sink_mode=False) | ||||
@@ -112,17 +131,25 @@ def test_dp_model_with_graph_mode_ada_gaussian(): | |||||
batches = 128 | batches = 128 | ||||
epochs = 1 | epochs = 1 | ||||
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) | loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) | ||||
mech = MechanismsFactory().create('AdaGaussian', | |||||
norm_bound=norm_clip, | |||||
initial_noise_multiplier=initial_noise_multiplier) | |||||
net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9) | |||||
noise_mech = NoiseMechanismsFactory().create('AdaGaussian', | |||||
norm_bound=norm_clip, | |||||
initial_noise_multiplier=initial_noise_multiplier) | |||||
clip_mech = ClipMechanismsFactory().create('Gaussian', | |||||
decay_policy='Linear', | |||||
learning_rate=0.01, | |||||
target_unclipped_quantile=0.9, | |||||
fraction_stddev=0.01) | |||||
net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.1, | |||||
momentum=0.9) | |||||
model = DPModel(micro_batches=2, | model = DPModel(micro_batches=2, | ||||
clip_mech=clip_mech, | |||||
norm_clip=norm_clip, | norm_clip=norm_clip, | ||||
mech=mech, | |||||
noise_mech=noise_mech, | |||||
network=network, | network=network, | ||||
loss_fn=loss, | loss_fn=loss, | ||||
optimizer=net_opt, | optimizer=net_opt, | ||||
metrics=None) | metrics=None) | ||||
ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches), ['data', 'label']) | |||||
ms_ds.set_dataset_size(batch_size * batches) | |||||
ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches), | |||||
['data', 'label']) | |||||
ms_ds.set_dataset_size(batch_size*batches) | |||||
model.train(epochs, ms_ds, dataset_sink_mode=False) | model.train(epochs, ms_ds, dataset_sink_mode=False) |