@@ -20,7 +20,7 @@ from easydict import EasyDict as edict | |||
mnist_cfg = edict({ | |||
'num_classes': 10, # the number of classes of model's output | |||
'lr': 0.01, # the learning rate of model's optimizer | |||
'lr': 0.1, # the learning rate of model's optimizer | |||
'momentum': 0.9, # the momentum value of model's optimizer | |||
'epoch_size': 10, # training epochs | |||
'batch_size': 256, # batch size for training | |||
@@ -33,7 +33,7 @@ mnist_cfg = edict({ | |||
'dataset_sink_mode': False, # whether to deliver all training data to the device at once | |||
'micro_batches': 16, # the number of small batches split from an original batch | |||
'l2_norm_bound': 1.0, # the clip bound of the gradients of model's training parameters | |||
'initial_noise_multiplier': 1.5, # the initial multiplication coefficient of the noise added to training | |||
'initial_noise_multiplier': 0.2, # the initial multiplication coefficient of the noise added to training | |||
# parameters' gradients | |||
'mechanisms': 'AdaGaussian', # the method of adding noise in gradients while training | |||
'optimizer': 'Momentum' # the base optimizer used for Differential privacy training | |||
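For readers unfamiliar with easydict, the short sketch below (illustrative only, not part of this patch) shows why the training scripts can mix attribute access such as cfg.lr with dictionary access such as cfg['epoch_size']: an EasyDict exposes every key both ways.

# Illustrative sketch of how mnist_cfg is consumed; values here are hypothetical.
from easydict import EasyDict as edict

cfg = edict({'lr': 0.1, 'epoch_size': 10, 'micro_batches': 16})
print(cfg.lr)             # attribute access -> 0.1
print(cfg['epoch_size'])  # dictionary access -> 10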
@@ -0,0 +1,140 @@ | |||
# Copyright 2020 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
""" | |||
python lenet5_dp.py --data_path /YourDataPath --micro_batches=2 | |||
""" | |||
import os | |||
import mindspore.nn as nn | |||
from mindspore import context | |||
from mindspore.train.callback import ModelCheckpoint | |||
from mindspore.train.callback import CheckpointConfig | |||
from mindspore.train.callback import LossMonitor | |||
from mindspore.nn.metrics import Accuracy | |||
from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
import mindspore.dataset as ds | |||
import mindspore.dataset.transforms.vision.c_transforms as CV | |||
import mindspore.dataset.transforms.c_transforms as C | |||
from mindspore.dataset.transforms.vision import Inter | |||
import mindspore.common.dtype as mstype | |||
from mindarmour.diff_privacy import DPModel | |||
from mindarmour.diff_privacy import PrivacyMonitorFactory | |||
from mindarmour.diff_privacy import MechanismsFactory | |||
from mindarmour.utils.logger import LogUtil | |||
from lenet5_net import LeNet5 | |||
from lenet5_config import mnist_cfg as cfg | |||
LOGGER = LogUtil.get_instance() | |||
LOGGER.set_level('INFO') | |||
TAG = 'Lenet5_train' | |||
def generate_mnist_dataset(data_path, batch_size=32, repeat_size=1, | |||
num_parallel_workers=1, sparse=True): | |||
""" | |||
create dataset for training or testing | |||
""" | |||
# define dataset | |||
ds1 = ds.MnistDataset(data_path) | |||
# define operation parameters | |||
resize_height, resize_width = 32, 32 | |||
rescale = 1.0 / 255.0 | |||
shift = 0.0 | |||
# define map operations | |||
resize_op = CV.Resize((resize_height, resize_width), | |||
interpolation=Inter.LINEAR) | |||
rescale_op = CV.Rescale(rescale, shift) | |||
hwc2chw_op = CV.HWC2CHW() | |||
type_cast_op = C.TypeCast(mstype.int32) | |||
# apply map operations on images | |||
if not sparse: | |||
one_hot_enco = C.OneHot(10) | |||
ds1 = ds1.map(input_columns="label", operations=one_hot_enco, | |||
num_parallel_workers=num_parallel_workers) | |||
type_cast_op = C.TypeCast(mstype.float32) | |||
ds1 = ds1.map(input_columns="label", operations=type_cast_op, | |||
num_parallel_workers=num_parallel_workers) | |||
ds1 = ds1.map(input_columns="image", operations=resize_op, | |||
num_parallel_workers=num_parallel_workers) | |||
ds1 = ds1.map(input_columns="image", operations=rescale_op, | |||
num_parallel_workers=num_parallel_workers) | |||
ds1 = ds1.map(input_columns="image", operations=hwc2chw_op, | |||
num_parallel_workers=num_parallel_workers) | |||
# apply DatasetOps | |||
buffer_size = 10000 | |||
ds1 = ds1.shuffle(buffer_size=buffer_size) | |||
ds1 = ds1.batch(batch_size, drop_remainder=True) | |||
ds1 = ds1.repeat(repeat_size) | |||
return ds1 | |||
if __name__ == "__main__": | |||
# This configuration can run in both PyNative mode and graph mode | |||
context.set_context(mode=context.PYNATIVE_MODE, device_target=cfg.device_target) | |||
network = LeNet5() | |||
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") | |||
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps, | |||
keep_checkpoint_max=cfg.keep_checkpoint_max) | |||
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", | |||
directory='./trained_ckpt_file/', | |||
config=config_ck) | |||
# get training dataset | |||
ds_train = generate_mnist_dataset(os.path.join(cfg.data_path, "train"), | |||
cfg.batch_size, | |||
cfg.epoch_size) | |||
if cfg.micro_batches and cfg.batch_size % cfg.micro_batches != 0: | |||
raise ValueError("Number of micro_batches should divide evenly batch_size") | |||
# Create a factory class of DP mechanisms; the chosen mechanism adds noise to gradients during training. | |||
# initial_noise_multiplier is suggested to be greater than 1.0; otherwise the privacy budget becomes huge, which | |||
# means the privacy protection effect is weak. mechanisms can be 'Gaussian' or 'AdaGaussian': the noise decays | |||
# over training with the 'AdaGaussian' mechanism and stays constant with the 'Gaussian' mechanism. | |||
mech = MechanismsFactory().create(cfg.mechanisms, | |||
norm_bound=cfg.l2_norm_bound, | |||
initial_noise_multiplier=cfg.initial_noise_multiplier) | |||
net_opt = nn.Momentum(params=network.trainable_params(), learning_rate=cfg.lr, momentum=cfg.momentum) | |||
# Create a monitor for DP training. The function of the monitor is to compute and print the privacy budget (eps | |||
# and delta) during training. | |||
rdp_monitor = PrivacyMonitorFactory.create('rdp', | |||
num_samples=60000, | |||
batch_size=cfg.batch_size, | |||
initial_noise_multiplier=cfg.initial_noise_multiplier*cfg.l2_norm_bound, | |||
per_print_times=10) | |||
# Create the DP model for training. | |||
model = DPModel(micro_batches=cfg.micro_batches, | |||
norm_clip=cfg.l2_norm_bound, | |||
mech=mech, | |||
network=network, | |||
loss_fn=net_loss, | |||
optimizer=net_opt, | |||
metrics={"Accuracy": Accuracy()}) | |||
LOGGER.info(TAG, "============== Starting Training ==============") | |||
model.train(cfg['epoch_size'], ds_train, callbacks=[ckpoint_cb, LossMonitor(), rdp_monitor], | |||
dataset_sink_mode=cfg.dataset_sink_mode) | |||
LOGGER.info(TAG, "============== Starting Testing ==============") | |||
ckpt_file_name = 'trained_ckpt_file/checkpoint_lenet-10_234.ckpt' | |||
param_dict = load_checkpoint(ckpt_file_name) | |||
load_param_into_net(network, param_dict) | |||
ds_eval = generate_mnist_dataset(os.path.join(cfg.data_path, 'test'), batch_size=cfg.batch_size) | |||
acc = model.eval(ds_eval, dataset_sink_mode=False) | |||
LOGGER.info(TAG, "============== Accuracy: %s ==============", acc) |
@@ -12,7 +12,7 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
""" | |||
python lenet5_dp_model_train.py --data_path /YourDataPath --micro_batches=2 | |||
python lenet5_dp_pynative_mode.py --data_path /YourDataPath --micro_batches=2 | |||
""" | |||
import os | |||
@@ -30,8 +30,8 @@ from mindspore.dataset.transforms.vision import Inter | |||
import mindspore.common.dtype as mstype | |||
from mindarmour.diff_privacy import DPModel | |||
from mindarmour.diff_privacy import DPOptimizerClassFactory | |||
from mindarmour.diff_privacy import PrivacyMonitorFactory | |||
from mindarmour.diff_privacy import DPOptimizerClassFactory | |||
from mindarmour.utils.logger import LogUtil | |||
from lenet5_net import LeNet5 | |||
from lenet5_config import mnist_cfg as cfg | |||
@@ -86,8 +86,8 @@ def generate_mnist_dataset(data_path, batch_size=32, repeat_size=1, | |||
if __name__ == "__main__": | |||
# This configuration can only run in PyNative mode. | |||
context.set_context(mode=context.PYNATIVE_MODE, device_target=cfg.device_target) | |||
network = LeNet5() | |||
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") | |||
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps, | |||
@@ -103,34 +103,26 @@ if __name__ == "__main__": | |||
if cfg.micro_batches and cfg.batch_size % cfg.micro_batches != 0: | |||
raise ValueError("Number of micro_batches should divide evenly batch_size") | |||
# Create a factory class of DP optimizer | |||
gaussian_mech = DPOptimizerClassFactory(cfg.micro_batches) | |||
# Set the method of adding noise in gradients while training. Initial_noise_multiplier is suggested to be greater | |||
# than 1.0, otherwise the privacy budget would be huge, which means that the privacy protection effect is weak. | |||
# mechanisms can be 'Gaussian' or 'AdaGaussian', in which noise would be decayed with 'AdaGaussian' mechanism while | |||
# be constant with 'Gaussian' mechanism. | |||
gaussian_mech.set_mechanisms(cfg.mechanisms, | |||
norm_bound=cfg.l2_norm_bound, | |||
initial_noise_multiplier=cfg.initial_noise_multiplier) | |||
# Wrap the base optimizer for DP training. Momentum optimizer is suggested for LeNet5. | |||
net_opt = gaussian_mech.create(cfg.optimizer)(params=network.trainable_params(), | |||
learning_rate=cfg.lr, | |||
momentum=cfg.momentum) | |||
# Create a factory class of DP mechanisms; the chosen mechanism adds noise to gradients during training. | |||
# initial_noise_multiplier is suggested to be greater than 1.0; otherwise the privacy budget becomes huge, which | |||
# means the privacy protection effect is weak. mechanisms can be 'Gaussian' or 'AdaGaussian': the noise decays | |||
# over training with the 'AdaGaussian' mechanism and stays constant with the 'Gaussian' mechanism. | |||
dp_opt = DPOptimizerClassFactory(micro_batches=cfg.micro_batches) | |||
dp_opt.set_mechanisms(cfg.mechanisms, | |||
norm_bound=cfg.l2_norm_bound, | |||
initial_noise_multiplier=cfg.initial_noise_multiplier) | |||
net_opt = dp_opt.create('Momentum')(params=network.trainable_params(), learning_rate=cfg.lr, momentum=cfg.momentum) | |||
# Create a monitor for DP training. The function of the monitor is to compute and print the privacy budget (eps | |||
# and delta) during training. | |||
rdp_monitor = PrivacyMonitorFactory.create('rdp', | |||
num_samples=60000, | |||
batch_size=cfg.batch_size, | |||
initial_noise_multiplier=cfg.initial_noise_multiplier, | |||
per_print_times=50) | |||
initial_noise_multiplier=cfg.initial_noise_multiplier*cfg.l2_norm_bound, | |||
per_print_times=10) | |||
# Create the DP model for training. | |||
model = DPModel(micro_batches=cfg.micro_batches, | |||
norm_clip=cfg.l2_norm_bound, | |||
dp_mech=gaussian_mech.mech, | |||
mech=None, | |||
network=network, | |||
loss_fn=net_loss, | |||
optimizer=net_opt, |
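The two scripts above exercise the two configurations DPModel now accepts: an explicit noise mechanism with a plain optimizer, or a DP-wrapped optimizer with mech=None (the latter restricted to PyNative mode, per the check added in model.py below). A condensed sketch of both paths; network, net_loss and the commented-out DPModel lines assume the definitions from the scripts above.

from mindarmour.diff_privacy import DPModel, MechanismsFactory, DPOptimizerClassFactory

# Option 1: explicit mechanism + plain optimizer (graph or PyNative mode).
mech = MechanismsFactory().create('Gaussian', norm_bound=1.0, initial_noise_multiplier=1.5)
# model = DPModel(micro_batches=16, norm_clip=1.0, mech=mech,
#                 network=network, loss_fn=net_loss, optimizer=net_opt)

# Option 2: DP-wrapped optimizer + mech=None (PyNative mode only).
dp_opt = DPOptimizerClassFactory(micro_batches=16)
dp_opt.set_mechanisms('Gaussian', norm_bound=1.0, initial_noise_multiplier=1.5)
# net_opt = dp_opt.create('Momentum')(params=network.trainable_params(), learning_rate=0.01, momentum=0.9)
# model = DPModel(micro_batches=16, norm_clip=1.0, mech=None,
#                 network=network, loss_fn=net_loss, optimizer=net_opt)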
@@ -14,8 +14,6 @@ | |||
""" | |||
Noise Mechanisms. | |||
""" | |||
import numpy as np | |||
from mindspore import Tensor | |||
from mindspore.nn import Cell | |||
from mindspore.ops import operations as P | |||
@@ -24,6 +22,7 @@ from mindspore.common import dtype as mstype | |||
from mindarmour.utils._check_param import check_param_type | |||
from mindarmour.utils._check_param import check_value_positive | |||
from mindarmour.utils._check_param import check_value_non_negative | |||
from mindarmour.utils._check_param import check_param_in_range | |||
@@ -62,7 +61,8 @@ class Mechanisms(Cell): | |||
""" | |||
Base class of noise generation mechanisms. | |||
""" | |||
def construct(self, shape): | |||
def construct(self, gradients): | |||
""" | |||
Construct function. | |||
""" | |||
@@ -78,41 +78,47 @@ class GaussianRandom(Mechanisms): | |||
initial_noise_multiplier(float): Ratio of the standard deviation of | |||
Gaussian noise divided by the norm_bound, which will be used to | |||
calculate privacy spent. Default: 1.5. | |||
mean(float): Average value of random noise. Default: 0.0. | |||
seed(int): Original random seed. Default: 0. | |||
Returns: | |||
Tensor, generated noise. | |||
Tensor, generated noise with shape like given gradients. | |||
Examples: | |||
>>> shape = (3, 2, 4) | |||
>>> gradients = Tensor([0.2, 0.9], mstype.float32) | |||
>>> norm_bound = 1.0 | |||
>>> initial_noise_multiplier = 1.5 | |||
>>> initial_noise_multiplier = 0.1 | |||
>>> net = GaussianRandom(norm_bound, initial_noise_multiplier) | |||
>>> res = net(shape) | |||
>>> res = net(gradients) | |||
>>> print(res) | |||
""" | |||
def __init__(self, norm_bound=1.0, initial_noise_multiplier=1.5): | |||
def __init__(self, norm_bound=1.0, initial_noise_multiplier=1.5, mean=0.0, seed=0): | |||
super(GaussianRandom, self).__init__() | |||
self._norm_bound = check_value_positive('norm_bound', norm_bound) | |||
self._norm_bound = Tensor(norm_bound, mstype.float32) | |||
self._initial_noise_multiplier = check_value_positive('initial_noise_multiplier', | |||
initial_noise_multiplier,) | |||
stddev = self._norm_bound*self._initial_noise_multiplier | |||
self._stddev = stddev | |||
self._mean = 0 | |||
def construct(self, shape): | |||
initial_noise_multiplier) | |||
self._initial_noise_multiplier = Tensor(initial_noise_multiplier, mstype.float32) | |||
mean = check_param_type('mean', mean, float) | |||
mean = check_value_non_negative('mean', mean) | |||
self._mean = Tensor(mean, mstype.float32) | |||
self._normal = P.Normal(seed=seed) | |||
def construct(self, gradients): | |||
""" | |||
Generate Gaussian noise. | |||
Args: | |||
shape(tuple): The shape of gradients. | |||
gradients(Tensor): The gradients. | |||
Returns: | |||
Tensor, generated noise. | |||
Tensor, generated noise with shape like given gradients. | |||
""" | |||
shape = check_param_type('shape', shape, tuple) | |||
noise = np.random.normal(self._mean, self._stddev, shape) | |||
return Tensor(noise, mstype.float32) | |||
shape = P.Shape()(gradients) | |||
stddev = P.Mul()(self._norm_bound, self._initial_noise_multiplier) | |||
noise = self._normal(shape, self._mean, stddev) | |||
return noise | |||
class AdaGaussianRandom(Mechanisms): | |||
@@ -126,54 +132,60 @@ class AdaGaussianRandom(Mechanisms): | |||
initial_noise_multiplier(float): Ratio of the standard deviation of | |||
Gaussian noise divided by the norm_bound, which will be used to | |||
calculate privacy spent. Default: 5.0. | |||
noise_decay_rate(float): Hyperparameter for controlling the noise decay. | |||
mean(float): Average value of random noise. Default: 0.0. | |||
noise_decay_rate(float): Hyperparameter for controlling the noise decay. | |||
Default: 6e-4. | |||
decay_policy(str): Noise decay strategy include 'Step' and 'Time'. | |||
Default: 'Time'. | |||
seed(int): Original random seed. Default: 0. | |||
Returns: | |||
Tensor, generated noise. | |||
Tensor, generated noise with shape like given gradients. | |||
Examples: | |||
>>> shape = (3, 2, 4) | |||
>>> gradients = Tensor([0.2, 0.9], mstype.float32) | |||
>>> norm_bound = 1.0 | |||
>>> initial_noise_multiplier = 0.1 | |||
>>> noise_decay_rate = 0.5 | |||
>>> initial_noise_multiplier = 5.0 | |||
>>> mean = 0.0 | |||
>>> noise_decay_rate = 6e-4 | |||
>>> decay_policy = "Time" | |||
>>> net = AdaGaussianRandom(norm_bound, initial_noise_multiplier, | |||
>>> net = AdaGaussianRandom(norm_bound, initial_noise_multiplier, mean, | |||
>>> noise_decay_rate, decay_policy) | |||
>>> res = net(shape) | |||
>>> res = net(gradients) | |||
>>> print(res) | |||
""" | |||
def __init__(self, norm_bound=1.5, initial_noise_multiplier=5.0, | |||
noise_decay_rate=6e-4, decay_policy='Time'): | |||
def __init__(self, norm_bound=1.5, initial_noise_multiplier=5.0, mean=0.0, | |||
noise_decay_rate=6e-4, decay_policy='Time', seed=0): | |||
super(AdaGaussianRandom, self).__init__() | |||
norm_bound = check_value_positive('norm_bound', norm_bound) | |||
initial_noise_multiplier = check_value_positive('initial_noise_multiplier', | |||
initial_noise_multiplier) | |||
initial_noise_multiplier = Tensor(np.array(initial_noise_multiplier, np.float32)) | |||
self._norm_bound = Tensor(norm_bound, mstype.float32) | |||
initial_noise_multiplier = Tensor(initial_noise_multiplier, mstype.float32) | |||
self._initial_noise_multiplier = Parameter(initial_noise_multiplier, | |||
name='initial_noise_multiplier') | |||
self._stddev = P.Mul()(self._norm_bound, self._initial_noise_multiplier) | |||
self._noise_multiplier = Parameter(initial_noise_multiplier, | |||
name='noise_multiplier') | |||
norm_bound = check_value_positive('norm_bound', norm_bound) | |||
self._norm_bound = Tensor(np.array(norm_bound, np.float32)) | |||
mean = check_param_type('mean', mean, float) | |||
mean = check_value_non_negative('mean', mean) | |||
self._mean = Tensor(mean, mstype.float32) | |||
noise_decay_rate = check_param_type('noise_decay_rate', noise_decay_rate, float) | |||
check_param_in_range('noise_decay_rate', noise_decay_rate, 0.0, 1.0) | |||
self._noise_decay_rate = Tensor(np.array(noise_decay_rate, np.float32)) | |||
self._noise_decay_rate = Tensor(noise_decay_rate, mstype.float32) | |||
if decay_policy not in ['Time', 'Step']: | |||
raise NameError("The decay_policy must be in ['Time', 'Step'], but " | |||
"get {}".format(decay_policy)) | |||
self._decay_policy = decay_policy | |||
self._mean = 0.0 | |||
self._sub = P.Sub() | |||
self._mul = P.Mul() | |||
self._add = P.TensorAdd() | |||
self._div = P.Div() | |||
self._stddev = self._update_stddev() | |||
self._dtype = mstype.float32 | |||
self._normal = P.Normal(seed=seed) | |||
self._assign = P.Assign() | |||
def _update_multiplier(self): | |||
""" Update multiplier. """ | |||
@@ -181,31 +193,32 @@ class AdaGaussianRandom(Mechanisms): | |||
temp = self._div(self._initial_noise_multiplier, | |||
self._noise_multiplier) | |||
temp = self._add(temp, self._noise_decay_rate) | |||
temp = self._div(self._initial_noise_multiplier, temp) | |||
self._noise_multiplier = Parameter(temp, name='noise_multiplier') | |||
self._noise_multiplier = self._assign(self._noise_multiplier, | |||
self._div(self._initial_noise_multiplier, temp)) | |||
else: | |||
one = Tensor(1, self._dtype) | |||
temp = self._sub(one, self._noise_decay_rate) | |||
temp = self._mul(temp, self._noise_multiplier) | |||
self._noise_multiplier = Parameter(temp, name='noise_multiplier') | |||
self._noise_multiplier = self._assign(self._noise_multiplier, self._mul(temp, self._noise_multiplier)) | |||
return self._noise_multiplier | |||
def _update_stddev(self): | |||
self._stddev = self._mul(self._noise_multiplier, self._norm_bound) | |||
self._stddev = self._assign(self._stddev, self._mul(self._noise_multiplier, self._norm_bound)) | |||
return self._stddev | |||
def construct(self, shape): | |||
def construct(self, gradients): | |||
""" | |||
Generate adaptive Gaussian noise. | |||
Args: | |||
shape(tuple): The shape of gradients. | |||
gradients(Tensor): The gradients. | |||
Returns: | |||
Tensor, generated noise. | |||
Tensor, generated noise with shape like given gradients. | |||
""" | |||
shape = check_param_type('shape', shape, tuple) | |||
noise = np.random.normal(self._mean, self._stddev.asnumpy(), | |||
shape) | |||
self._update_multiplier() | |||
self._update_stddev() | |||
return Tensor(noise, mstype.float32) | |||
shape = P.Shape()(gradients) | |||
noise = self._normal(shape, self._mean, self._stddev) | |||
# pylint: disable=unused-variable | |||
mt = self._update_multiplier() | |||
# pylint: disable=unused-variable | |||
std = self._update_stddev() | |||
return noise |
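The two decay policies in AdaGaussianRandom can be summarised with a small standalone sketch. The helper name is hypothetical and not part of the MindArmour API; 'Step' is shown as the simple geometric decay of the pre-existing implementation, 'Time' as read directly from the update code above.

# Rough reading of the noise-multiplier schedules; stddev then follows as
# stddev_t = multiplier_t * norm_bound.
def next_multiplier(multiplier, initial, decay_rate, policy='Time'):
    if policy == 'Time':
        # multiplier_t = initial / (initial / multiplier_{t-1} + decay_rate)
        return initial / (initial / multiplier + decay_rate)
    # 'Step': multiplier_t = (1 - decay_rate) * multiplier_{t-1}
    return (1.0 - decay_rate) * multiplier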
@@ -14,13 +14,37 @@ | |||
""" | |||
Differential privacy optimizer. | |||
""" | |||
import mindspore as ms | |||
from mindspore import nn | |||
from mindspore import Tensor | |||
from mindspore.ops import composite as C | |||
from mindspore.ops import operations as P | |||
from mindspore.ops import functional as F | |||
from mindspore.common import dtype as mstype | |||
from mindarmour.diff_privacy.mechanisms.mechanisms import MechanismsFactory | |||
from mindarmour.utils._check_param import check_int_positive | |||
_grad_scale = C.MultitypeFuncGraph("grad_scale") | |||
_reciprocal = P.Reciprocal() | |||
@_grad_scale.register("Tensor", "Tensor") | |||
def tensor_grad_scale(scale, grad): | |||
""" grad scaling """ | |||
return grad * _reciprocal(scale) | |||
class _TupleAdd(nn.Cell): | |||
def __init__(self): | |||
super(_TupleAdd, self).__init__() | |||
self.add = P.TensorAdd() | |||
self.hyper_map = C.HyperMap() | |||
def construct(self, input1, input2): | |||
"""Add two tuple of data.""" | |||
out = self.hyper_map(self.add, input1, input2) | |||
return out | |||
class DPOptimizerClassFactory: | |||
""" | |||
@@ -36,9 +60,10 @@ class DPOptimizerClassFactory: | |||
>>> GaussianSGD = DPOptimizerClassFactory(micro_batches=2) | |||
>>> GaussianSGD.set_mechanisms('Gaussian', norm_bound=1.0, initial_noise_multiplier=1.5) | |||
>>> net_opt = GaussianSGD.create('Momentum')(params=network.trainable_params(), | |||
>>> learning_rate=cfg.lr, | |||
>>> momentum=cfg.momentum) | |||
>>> learning_rate=cfg.lr, | |||
>>> momentum=cfg.momentum) | |||
""" | |||
def __init__(self, micro_batches=2): | |||
self._mech_factory = MechanismsFactory() | |||
self.mech = None | |||
@@ -78,6 +103,7 @@ class DPOptimizerClassFactory: | |||
""" | |||
Wrap the original MindSpore optimizer with `self._mech`. | |||
""" | |||
class DPOptimizer(cls): | |||
""" | |||
Initialize the DPOptimizerClass. | |||
@@ -85,23 +111,22 @@ class DPOptimizerClassFactory: | |||
Returns: | |||
Optimizer, Optimizer class. | |||
""" | |||
def __init__(self, *args, **kwargs): | |||
super(DPOptimizer, self).__init__(*args, **kwargs) | |||
self._mech = mech | |||
self._tuple_add = _TupleAdd() | |||
self._hyper_map = C.HyperMap() | |||
self._micro_float = Tensor(micro_batches, mstype.float32) | |||
def construct(self, gradients): | |||
""" | |||
construct a compute flow. | |||
""" | |||
g_len = len(gradients) | |||
gradient_noise = list(gradients) | |||
for i in range(g_len): | |||
gradient_noise[i] = gradient_noise[i].asnumpy() | |||
gradient_noise[i] = self._mech(gradient_noise[i].shape).asnumpy() + gradient_noise[i] | |||
gradient_noise[i] = gradient_noise[i] / micro_batches | |||
gradient_noise[i] = Tensor(gradient_noise[i], ms.float32) | |||
gradients = tuple(gradient_noise) | |||
gradients = super(DPOptimizer, self).construct(gradients) | |||
grad_noise = self._hyper_map(self._mech, gradients) | |||
grads = self._tuple_add(gradients, grad_noise) | |||
grads = self._hyper_map(F.partial(_grad_scale, self._micro_float), grads) | |||
gradients = super(DPOptimizer, self).construct(grads) | |||
return gradients | |||
return DPOptimizer |
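Conceptually, the rewritten DPOptimizer.construct replaces the earlier per-tensor numpy loop with graph-friendly ops: generate mechanism noise for each gradient, add it, and scale by 1/micro_batches before handing the result to the base optimizer. A plain-Python sketch of that data flow (hypothetical helper, for illustration only):

def dp_gradients(gradients, mech, micro_batches):
    # mech(g) is assumed to return noise shaped like g,
    # as GaussianRandom / AdaGaussianRandom do above.
    noised = [g + mech(g) for g in gradients]
    # average the accumulated micro-batch gradients
    return [g / micro_batches for g in noised]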
@@ -16,7 +16,6 @@ Differential privacy model. | |||
""" | |||
from easydict import EasyDict as edict | |||
import mindspore as ms | |||
from mindspore.train.model import Model | |||
from mindspore._checkparam import Validator as validator | |||
from mindspore._checkparam import Rel | |||
@@ -48,21 +47,19 @@ from mindspore.nn.wrap.loss_scale import _grad_overflow | |||
from mindspore.nn import Cell | |||
from mindspore import ParameterTuple | |||
from mindarmour.diff_privacy.mechanisms import mechanisms | |||
from mindarmour.utils._check_param import check_param_type | |||
from mindarmour.utils._check_param import check_value_positive | |||
from mindarmour.utils._check_param import check_int_positive | |||
GRADIENT_CLIP_TYPE = 1 | |||
grad_scale = C.MultitypeFuncGraph("grad_scale") | |||
reciprocal = P.Reciprocal() | |||
_grad_scale = C.MultitypeFuncGraph("grad_scale") | |||
_reciprocal = P.Reciprocal() | |||
@grad_scale.register("Tensor", "Tensor") | |||
@_grad_scale.register("Tensor", "Tensor") | |||
def tensor_grad_scale(scale, grad): | |||
""" grad scaling """ | |||
return grad*reciprocal(scale) | |||
return grad * F.cast(_reciprocal(scale), F.dtype(grad)) | |||
class DPModel(Model): | |||
@@ -72,7 +69,7 @@ class DPModel(Model): | |||
Args: | |||
micro_batches (int): The number of small batches split from an original batch. Default: 2. | |||
norm_clip (float): Used to clip the gradient norm bound; if set to 1, the original data will be returned. Default: 1.0. | |||
dp_mech (Mechanisms): The object can generate the different type of noise. Default: None. | |||
mech (Mechanisms): The object used to generate different types of noise. Default: None. | |||
Examples: | |||
>>> class Net(nn.Cell): | |||
@@ -94,32 +91,37 @@ class DPModel(Model): | |||
>>> | |||
>>> net = Net() | |||
>>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) | |||
>>> optim = Momentum(params=net.trainable_params(), learning_rate=0.01, momentum=0.9) | |||
>>> gaussian_mech = DPOptimizerClassFactory() | |||
>>> gaussian_mech.set_mechanisms('Gaussian', | |||
>>> norm_bound=args.l2_norm_bound, | |||
>>> initial_noise_multiplier=args.initial_noise_multiplier) | |||
>>> net_opt = Momentum(params=net.trainable_params(), learning_rate=0.01, momentum=0.9) | |||
>>> mech = MechanismsFactory().create('Gaussian', | |||
>>> norm_bound=args.l2_norm_bound, | |||
>>> initial_noise_multiplier=args.initial_noise_multiplier) | |||
>>> model = DPModel(micro_batches=2, | |||
>>> norm_clip=1.0, | |||
>>> dp_mech=gaussian_mech.mech, | |||
>>> mech=mech, | |||
>>> network=net, | |||
>>> loss_fn=loss, | |||
>>> optimizer=optim, | |||
>>> optimizer=net_opt, | |||
>>> metrics=None) | |||
>>> dataset = get_dataset() | |||
>>> model.train(2, dataset) | |||
""" | |||
def __init__(self, micro_batches=2, norm_clip=1.0, dp_mech=None, **kwargs): | |||
def __init__(self, micro_batches=2, norm_clip=1.0, mech=None, **kwargs): | |||
if micro_batches: | |||
self._micro_batches = check_int_positive('micro_batches', micro_batches) | |||
else: | |||
self._micro_batches = None | |||
norm_clip = check_param_type('norm_clip', norm_clip, float) | |||
self._norm_clip = check_value_positive('norm_clip', norm_clip) | |||
if isinstance(dp_mech, mechanisms.Mechanisms): | |||
self._dp_mech = dp_mech | |||
else: | |||
raise TypeError('dp mechanisms should be instance of class Mechansms, but got {}'.format(type(dp_mech))) | |||
float_norm_clip = check_param_type('l2_norm_clip', norm_clip, float) | |||
self._norm_clip = check_value_positive('l2_norm_clip', float_norm_clip) | |||
if mech is not None and "DPOptimizer" in kwargs['optimizer'].__class__.__name__: | |||
raise ValueError('DPOptimizer is not supported while mech is not None') | |||
if mech is None: | |||
if "DPOptimizer" in kwargs['optimizer'].__class__.__name__: | |||
if context.get_context('mode') != context.PYNATIVE_MODE: | |||
raise ValueError('DPOptimizer only supports PyNative mode currently.') | |||
else: | |||
raise ValueError('DPModel requires either mech or a DPOptimizer; please refer to the example.') | |||
self._mech = mech | |||
super(DPModel, self).__init__(**kwargs) | |||
def _amp_build_train_network(self, network, optimizer, loss_fn=None, level='O0', **kwargs): | |||
@@ -179,14 +181,14 @@ class DPModel(Model): | |||
scale_update_cell=update_cell, | |||
micro_batches=self._micro_batches, | |||
l2_norm_clip=self._norm_clip, | |||
mech=self._dp_mech).set_train() | |||
mech=self._mech).set_train() | |||
return network | |||
network = _TrainOneStepCell(network, | |||
optimizer, | |||
loss_scale, | |||
micro_batches=self._micro_batches, | |||
l2_norm_clip=self._norm_clip, | |||
mech=self._dp_mech).set_train() | |||
mech=self._mech).set_train() | |||
return network | |||
def _build_train_network(self): | |||
@@ -244,6 +246,7 @@ class _ClipGradients(nn.Cell): | |||
Outputs: | |||
tuple[Tensor], clipped gradients. | |||
""" | |||
def __init__(self): | |||
super(_ClipGradients, self).__init__() | |||
self.clip_by_norm = nn.ClipByNorm() | |||
@@ -253,7 +256,8 @@ class _ClipGradients(nn.Cell): | |||
""" | |||
construct a compute flow. | |||
""" | |||
if clip_type not in (0, 1): | |||
# pylint: disable=consider-using-in | |||
if clip_type != 0 and clip_type != 1: | |||
return grads | |||
new_grads = () | |||
@@ -268,6 +272,18 @@ class _ClipGradients(nn.Cell): | |||
return new_grads | |||
class _TupleAdd(nn.Cell): | |||
def __init__(self): | |||
super(_TupleAdd, self).__init__() | |||
self.add = P.TensorAdd() | |||
self.hyper_map = C.HyperMap() | |||
def construct(self, input1, input2): | |||
"""Add two tuple of data.""" | |||
out = self.hyper_map(self.add, input1, input2) | |||
return out | |||
class _TrainOneStepWithLossScaleCell(Cell): | |||
r""" | |||
Network training with loss scaling. | |||
@@ -347,6 +363,9 @@ class _TrainOneStepWithLossScaleCell(Cell): | |||
self._split = P.Split(0, self._micro_batches) | |||
self._clip_by_global_norm = _ClipGradients() | |||
self._mech = mech | |||
self._tuple_add = _TupleAdd() | |||
self._hyper_map = C.HyperMap() | |||
self._micro_float = Tensor(micro_batches, mstype.float32) | |||
def construct(self, data, label, sens=None): | |||
""" | |||
@@ -368,32 +387,28 @@ class _TrainOneStepWithLossScaleCell(Cell): | |||
weights = self.weights | |||
record_datas = self._split(data) | |||
record_labels = self._split(label) | |||
grads = () | |||
# first index | |||
loss = self.network(record_datas[0], record_labels[0]) | |||
scaling_sens_filled = C.ones_like(loss)*F.cast(scaling_sens, F.dtype(loss)) | |||
record_grad = self.grad(self.network, weights)(record_datas[0], record_labels[0], scaling_sens_filled) | |||
record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm) | |||
grad_sum = list(record_grad) | |||
grad_len = len(record_grad) | |||
for i in range(grad_len): | |||
grad_sum[i] = grad_sum[i].asnumpy() | |||
grads = record_grad | |||
total_loss = loss | |||
for i in range(1, self._micro_batches): | |||
loss = self.network(record_datas[i], record_labels[i]) | |||
scaling_sens_filled = C.ones_like(loss)*F.cast(scaling_sens, F.dtype(loss)) | |||
record_grad = self.grad(self.network, weights)(record_datas[i], record_labels[i], scaling_sens_filled) | |||
record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm) | |||
for j in range(grad_len): | |||
grad_sum[j] = grad_sum[j] + record_grad[j].asnumpy() | |||
grads = self._tuple_add(grads, record_grad) | |||
total_loss = P.TensorAdd()(total_loss, loss) | |||
loss = P.Div()(total_loss, self._micro_float) | |||
for i in range(grad_len): | |||
grad_sum[i] = Tensor(grad_sum[i], ms.float32) | |||
grads = tuple(grad_sum) | |||
loss = self.network(data, label) | |||
if self._mech is not None: | |||
grad_noise = self._hyper_map(self._mech, grads) | |||
grads = self._tuple_add(grads, grad_noise) | |||
grads = self._hyper_map(F.partial(_grad_scale, self._micro_float), grads) | |||
grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads) | |||
grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads) | |||
# apply grad reducer on grads | |||
grads = self.grad_reducer(grads) | |||
# get the overflow buffer | |||
@@ -474,6 +489,9 @@ class _TrainOneStepCell(Cell): | |||
self._split = P.Split(0, self._micro_batches) | |||
self._clip_by_global_norm = _ClipGradients() | |||
self._mech = mech | |||
self._tuple_add = _TupleAdd() | |||
self._hyper_map = C.HyperMap() | |||
self._micro_float = Tensor(micro_batches, mstype.float32) | |||
def construct(self, data, label): | |||
""" | |||
@@ -486,23 +504,21 @@ class _TrainOneStepCell(Cell): | |||
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) | |||
record_grad = self.grad(self.network, weights)(record_datas[0], record_labels[0], sens) | |||
record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm) | |||
grad_sum = list(record_grad) | |||
grad_len = len(record_grad) | |||
for i in range(grad_len): | |||
grad_sum[i] = grad_sum[i].asnumpy() | |||
grads = record_grad | |||
total_loss = loss | |||
for i in range(1, self._micro_batches): | |||
loss = self.network(record_datas[i], record_labels[i]) | |||
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) | |||
record_grad = self.grad(self.network, weights)(record_datas[i], record_labels[i], sens) | |||
record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm) | |||
for j in range(grad_len): | |||
grad_sum[j] = grad_sum[j] + record_grad[j].asnumpy() | |||
for i in range(grad_len): | |||
grad_sum[i] = Tensor(grad_sum[i], ms.float32) | |||
grads = tuple(grad_sum) | |||
loss = self.network(data, label) | |||
grads = self._tuple_add(grads, record_grad) | |||
total_loss = P.TensorAdd()(total_loss, loss) | |||
loss = P.Div()(total_loss, self._micro_float) | |||
if self._mech is not None: | |||
grad_noise = self._hyper_map(self._mech, grads) | |||
grads = self._tuple_add(grads, grad_noise) | |||
grads = self._hyper_map(F.partial(_grad_scale, self._micro_float), grads) | |||
if self.reducer_flag: | |||
# apply grad reducer on grads | |||
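The _TrainOneStepCell changes above express the usual DP-SGD recipe with tensor ops instead of numpy round-trips: split the batch into micro-batches, clip each micro-batch gradient to the L2 bound, sum them, add mechanism noise when one is configured, and divide by the number of micro-batches. A numpy sketch of that recipe (illustrative only, hypothetical names, single gradient tensor for brevity):

import numpy as np

def dp_sgd_gradient(per_micro_batch_grads, l2_norm_bound, noise_std):
    # noise_std is typically norm_bound * initial_noise_multiplier.
    rng = np.random.default_rng(0)
    total = np.zeros_like(per_micro_batch_grads[0])
    for g in per_micro_batch_grads:
        norm = np.linalg.norm(g)
        clipped = g * min(1.0, l2_norm_bound / (norm + 1e-12))  # clip to the L2 bound
        total += clipped                                         # accumulate (tuple add)
    total += rng.normal(0.0, noise_std, size=total.shape)        # add Gaussian noise once per step
    return total / len(per_micro_batch_grads)                    # average over micro-batches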
@@ -18,7 +18,7 @@ from setuptools import setup | |||
from setuptools.command.egg_info import egg_info | |||
from setuptools.command.build_py import build_py | |||
version = '0.3.0' | |||
version = '0.5.0' | |||
cur_dir = os.path.dirname(os.path.realpath(__file__)) | |||
pkg_dir = os.path.join(cur_dir, 'build') | |||
@@ -17,6 +17,8 @@ different Privacy test. | |||
import pytest | |||
from mindspore import context | |||
from mindspore import Tensor | |||
from mindspore.common import dtype as mstype | |||
from mindarmour.diff_privacy import GaussianRandom | |||
from mindarmour.diff_privacy import AdaGaussianRandom | |||
from mindarmour.diff_privacy import MechanismsFactory | |||
@@ -26,13 +28,13 @@ from mindarmour.diff_privacy import MechanismsFactory | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
@pytest.mark.component_mindarmour | |||
def test_gaussian(): | |||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||
shape = (3, 2, 4) | |||
def test_graph_gaussian(): | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
grad = Tensor([3, 2, 4], mstype.float32) | |||
norm_bound = 1.0 | |||
initial_noise_multiplier = 0.1 | |||
net = GaussianRandom(norm_bound, initial_noise_multiplier) | |||
res = net(shape) | |||
res = net(grad) | |||
print(res) | |||
@@ -40,42 +42,99 @@ def test_gaussian(): | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
@pytest.mark.component_mindarmour | |||
def test_ada_gaussian(): | |||
def test_pynative_gaussian(): | |||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||
shape = (3, 2, 4) | |||
grad = Tensor([3, 2, 4], mstype.float32) | |||
norm_bound = 1.0 | |||
initial_noise_multiplier = 0.1 | |||
net = GaussianRandom(norm_bound, initial_noise_multiplier) | |||
res = net(grad) | |||
print(res) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
@pytest.mark.component_mindarmour | |||
def test_graph_ada_gaussian(): | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
grad = Tensor([3, 2, 4], mstype.float32) | |||
norm_bound = 1.0 | |||
initial_noise_multiplier = 0.1 | |||
noise_decay_rate = 0.5 | |||
decay_policy = "Step" | |||
alpha = 0.5 | |||
decay_policy = 'Step' | |||
net = AdaGaussianRandom(norm_bound, initial_noise_multiplier, | |||
noise_decay_rate, decay_policy) | |||
res = net(shape) | |||
noise_decay_rate=alpha, decay_policy=decay_policy) | |||
res = net(grad) | |||
print(res) | |||
def test_factory(): | |||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||
shape = (3, 2, 4) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
@pytest.mark.component_mindarmour | |||
def test_graph_factory(): | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
grad = Tensor([3, 2, 4], mstype.float32) | |||
norm_bound = 1.0 | |||
initial_noise_multiplier = 0.1 | |||
noise_decay_rate = 0.5 | |||
decay_policy = "Step" | |||
alpha = 0.5 | |||
decay_policy = 'Step' | |||
noise_mechanism = MechanismsFactory() | |||
noise_construct = noise_mechanism.create('Gaussian', | |||
norm_bound, | |||
initial_noise_multiplier) | |||
noise = noise_construct(shape) | |||
noise = noise_construct(grad) | |||
print('Gaussian noise: ', noise) | |||
ada_mechanism = MechanismsFactory() | |||
ada_noise_construct = ada_mechanism.create('AdaGaussian', | |||
norm_bound, | |||
initial_noise_multiplier, | |||
noise_decay_rate, | |||
decay_policy) | |||
ada_noise = ada_noise_construct(shape) | |||
noise_decay_rate=alpha, | |||
decay_policy=decay_policy) | |||
ada_noise = ada_noise_construct(grad) | |||
print('ada noise: ', ada_noise) | |||
if __name__ == '__main__': | |||
# device_target can be "CPU", "GPU" or "Ascend" | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
@pytest.mark.component_mindarmour | |||
def test_pynative_ada_gaussian(): | |||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||
grad = Tensor([3, 2, 4], mstype.float32) | |||
norm_bound = 1.0 | |||
initial_noise_multiplier = 0.1 | |||
alpha = 0.5 | |||
decay_policy = 'Step' | |||
net = AdaGaussianRandom(norm_bound, initial_noise_multiplier, | |||
noise_decay_rate=alpha, decay_policy=decay_policy) | |||
res = net(grad) | |||
print(res) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
@pytest.mark.component_mindarmour | |||
def test_pynative_factory(): | |||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||
grad = Tensor([3, 2, 4], mstype.float32) | |||
norm_bound = 1.0 | |||
initial_noise_multiplier = 0.1 | |||
alpha = 0.5 | |||
decay_policy = 'Step' | |||
noise_mechanism = MechanismsFactory() | |||
noise_construct = noise_mechanism.create('Gaussian', | |||
norm_bound, | |||
initial_noise_multiplier) | |||
noise = noise_construct(grad) | |||
print('Gaussian noise: ', noise) | |||
ada_mechanism = MechanismsFactory() | |||
ada_noise_construct = ada_mechanism.create('AdaGaussian', | |||
norm_bound, | |||
initial_noise_multiplier, | |||
noise_decay_rate=alpha, | |||
decay_policy=decay_policy) | |||
ada_noise = ada_noise_construct(grad) | |||
print('ada noise: ', ada_noise) |
@@ -21,13 +21,15 @@ from mindspore import nn | |||
from mindspore import context | |||
import mindspore.dataset as ds | |||
from mindarmour.diff_privacy import DPOptimizerClassFactory | |||
from mindarmour.diff_privacy import DPModel | |||
from mindarmour.diff_privacy import MechanismsFactory | |||
from mindarmour.diff_privacy import DPOptimizerClassFactory | |||
from test_network import LeNet5 | |||
def dataset_generator(batch_size, batches): | |||
"""mock training data.""" | |||
data = np.random.random((batches * batch_size, 1, 32, 32)).astype(np.float32) | |||
label = np.random.randint(0, 10, batches * batch_size).astype(np.int32) | |||
for i in range(batches): | |||
@@ -39,7 +41,7 @@ def dataset_generator(batch_size, batches): | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_card | |||
@pytest.mark.component_mindarmour | |||
def test_dp_model(): | |||
def test_dp_model_pynative_mode(): | |||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||
l2_norm_bound = 1.0 | |||
initial_noise_multiplier = 0.01 | |||
@@ -47,21 +49,50 @@ def test_dp_model(): | |||
batch_size = 32 | |||
batches = 128 | |||
epochs = 1 | |||
micro_batches = 2 | |||
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) | |||
factory_opt = DPOptimizerClassFactory(micro_batches=micro_batches) | |||
factory_opt.set_mechanisms('Gaussian', | |||
norm_bound=l2_norm_bound, | |||
initial_noise_multiplier=initial_noise_multiplier) | |||
net_opt = factory_opt.create('Momentum')(network.trainable_params(), learning_rate=0.1, momentum=0.9) | |||
model = DPModel(micro_batches=micro_batches, | |||
norm_clip=l2_norm_bound, | |||
mech=None, | |||
network=network, | |||
loss_fn=loss, | |||
optimizer=net_opt, | |||
metrics=None) | |||
ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches), ['data', 'label']) | |||
ms_ds.set_dataset_size(batch_size * batches) | |||
model.train(epochs, ms_ds, dataset_sink_mode=False) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_card | |||
@pytest.mark.component_mindarmour | |||
def test_dp_model_with_graph_mode(): | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
l2_norm_bound = 1.0 | |||
initial_noise_multiplier = 0.01 | |||
network = LeNet5() | |||
batch_size = 32 | |||
batches = 128 | |||
epochs = 1 | |||
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) | |||
gaussian_mech = DPOptimizerClassFactory(micro_batches=2) | |||
gaussian_mech.set_mechanisms('Gaussian', | |||
norm_bound=l2_norm_bound, | |||
initial_noise_multiplier=initial_noise_multiplier) | |||
net_opt = gaussian_mech.create('SGD')(params=network.trainable_params(), | |||
learning_rate=0.1, | |||
momentum=0.9) | |||
mech = MechanismsFactory().create('Gaussian', | |||
norm_bound=l2_norm_bound, | |||
initial_noise_multiplier=initial_noise_multiplier) | |||
net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9) | |||
model = DPModel(micro_batches=2, | |||
norm_clip=l2_norm_bound, | |||
dp_mech=gaussian_mech.mech, | |||
mech=mech, | |||
network=network, | |||
loss_fn=loss, | |||
optimizer=net_opt, | |||
metrics=None) | |||
ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches), ['data', 'label']) | |||
ms_ds.set_dataset_size(batch_size * batches) | |||
model.train(epochs, ms_ds) | |||
model.train(epochs, ms_ds, dataset_sink_mode=False) |
@@ -21,6 +21,7 @@ from mindarmour.diff_privacy import DPOptimizerClassFactory | |||
from test_network import LeNet5 | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||