@@ -20,7 +20,7 @@ from easydict import EasyDict as edict
mnist_cfg = edict({
    'num_classes': 10, # the number of classes of model's output
    'lr': 0.01, # the learning rate of model's optimizer
    'lr': 0.1, # the learning rate of model's optimizer
    'momentum': 0.9, # the momentum value of model's optimizer
    'epoch_size': 10, # training epochs
    'batch_size': 256, # batch size for training

@@ -33,7 +33,7 @@ mnist_cfg = edict({
    'dataset_sink_mode': False, # whether deliver all training data to device one time
    'micro_batches': 16, # the number of small batches split from an original batch
    'l2_norm_bound': 1.0, # the clip bound of the gradients of model's training parameters
    'initial_noise_multiplier': 1.5, # the initial multiplication coefficient of the noise added to training
    'initial_noise_multiplier': 0.2, # the initial multiplication coefficient of the noise added to training
                                     # parameters' gradients
    'mechanisms': 'AdaGaussian', # the method of adding noise in gradients while training
    'optimizer': 'Momentum' # the base optimizer used for Differential privacy training
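# A quick sanity check on these defaults (an illustrative note, not part of the config file): with
# batch_size=256 and micro_batches=16, each micro-batch holds 256 / 16 = 16 records, which satisfies
# the divisibility check DPModel performs below, and the noise added to the summed gradients starts
# with standard deviation l2_norm_bound * initial_noise_multiplier = 1.0 * 0.2 = 0.2.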
@@ -0,0 +1,140 @@ | |||||
# Copyright 2020 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
""" | |||||
python lenet5_dp.py --data_path /YourDataPath --micro_batches=2 | |||||
""" | |||||
import os | |||||
import mindspore.nn as nn | |||||
from mindspore import context | |||||
from mindspore.train.callback import ModelCheckpoint | |||||
from mindspore.train.callback import CheckpointConfig | |||||
from mindspore.train.callback import LossMonitor | |||||
from mindspore.nn.metrics import Accuracy | |||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||||
import mindspore.dataset as ds | |||||
import mindspore.dataset.transforms.vision.c_transforms as CV | |||||
import mindspore.dataset.transforms.c_transforms as C | |||||
from mindspore.dataset.transforms.vision import Inter | |||||
import mindspore.common.dtype as mstype | |||||
from mindarmour.diff_privacy import DPModel | |||||
from mindarmour.diff_privacy import PrivacyMonitorFactory | |||||
from mindarmour.diff_privacy import MechanismsFactory | |||||
from mindarmour.utils.logger import LogUtil | |||||
from lenet5_net import LeNet5 | |||||
from lenet5_config import mnist_cfg as cfg | |||||
LOGGER = LogUtil.get_instance() | |||||
LOGGER.set_level('INFO') | |||||
TAG = 'Lenet5_train' | |||||
def generate_mnist_dataset(data_path, batch_size=32, repeat_size=1, | |||||
num_parallel_workers=1, sparse=True): | |||||
""" | |||||
create dataset for training or testing | |||||
""" | |||||
# define dataset | |||||
ds1 = ds.MnistDataset(data_path) | |||||
# define operation parameters | |||||
resize_height, resize_width = 32, 32 | |||||
rescale = 1.0 / 255.0 | |||||
shift = 0.0 | |||||
# define map operations | |||||
resize_op = CV.Resize((resize_height, resize_width), | |||||
interpolation=Inter.LINEAR) | |||||
rescale_op = CV.Rescale(rescale, shift) | |||||
hwc2chw_op = CV.HWC2CHW() | |||||
type_cast_op = C.TypeCast(mstype.int32) | |||||
# apply map operations on images | |||||
if not sparse: | |||||
one_hot_enco = C.OneHot(10) | |||||
ds1 = ds1.map(input_columns="label", operations=one_hot_enco, | |||||
num_parallel_workers=num_parallel_workers) | |||||
type_cast_op = C.TypeCast(mstype.float32) | |||||
ds1 = ds1.map(input_columns="label", operations=type_cast_op, | |||||
num_parallel_workers=num_parallel_workers) | |||||
ds1 = ds1.map(input_columns="image", operations=resize_op, | |||||
num_parallel_workers=num_parallel_workers) | |||||
ds1 = ds1.map(input_columns="image", operations=rescale_op, | |||||
num_parallel_workers=num_parallel_workers) | |||||
ds1 = ds1.map(input_columns="image", operations=hwc2chw_op, | |||||
num_parallel_workers=num_parallel_workers) | |||||
# apply DatasetOps | |||||
buffer_size = 10000 | |||||
ds1 = ds1.shuffle(buffer_size=buffer_size) | |||||
ds1 = ds1.batch(batch_size, drop_remainder=True) | |||||
ds1 = ds1.repeat(repeat_size) | |||||
return ds1 | |||||
if __name__ == "__main__": | |||||
# This configure can run both in pynative mode and graph mode | |||||
context.set_context(mode=context.PYNATIVE_MODE, device_target=cfg.device_target) | |||||
network = LeNet5() | |||||
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") | |||||
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps, | |||||
keep_checkpoint_max=cfg.keep_checkpoint_max) | |||||
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", | |||||
directory='./trained_ckpt_file/', | |||||
config=config_ck) | |||||
# get training dataset | |||||
ds_train = generate_mnist_dataset(os.path.join(cfg.data_path, "train"), | |||||
cfg.batch_size, | |||||
cfg.epoch_size) | |||||
if cfg.micro_batches and cfg.batch_size % cfg.micro_batches != 0: | |||||
raise ValueError("Number of micro_batches should divide evenly batch_size") | |||||
# Create a factory class of DP mechanisms, this method is adding noise in gradients while training. | |||||
# Initial_noise_multiplier is suggested to be greater than 1.0, otherwise the privacy budget would be huge, which | |||||
# means that the privacy protection effect is weak. Mechanisms can be 'Gaussian' or 'AdaGaussian', in which noise | |||||
# would be decayed with 'AdaGaussian' mechanism while be constant with 'Gaussian' mechanism. | |||||
mech = MechanismsFactory().create(cfg.mechanisms, | |||||
norm_bound=cfg.l2_norm_bound, | |||||
initial_noise_multiplier=cfg.initial_noise_multiplier) | |||||
net_opt = nn.Momentum(params=network.trainable_params(), learning_rate=cfg.lr, momentum=cfg.momentum) | |||||
# Create a monitor for DP training. The function of the monitor is to compute and print the privacy budget(eps | |||||
# and delta) while training. | |||||
rdp_monitor = PrivacyMonitorFactory.create('rdp', | |||||
num_samples=60000, | |||||
batch_size=cfg.batch_size, | |||||
initial_noise_multiplier=cfg.initial_noise_multiplier*cfg.l2_norm_bound, | |||||
per_print_times=10) | |||||
# Create the DP model for training. | |||||
model = DPModel(micro_batches=cfg.micro_batches, | |||||
norm_clip=cfg.l2_norm_bound, | |||||
mech=mech, | |||||
network=network, | |||||
loss_fn=net_loss, | |||||
optimizer=net_opt, | |||||
metrics={"Accuracy": Accuracy()}) | |||||
LOGGER.info(TAG, "============== Starting Training ==============") | |||||
model.train(cfg['epoch_size'], ds_train, callbacks=[ckpoint_cb, LossMonitor(), rdp_monitor], | |||||
dataset_sink_mode=cfg.dataset_sink_mode) | |||||
LOGGER.info(TAG, "============== Starting Testing ==============") | |||||
ckpt_file_name = 'trained_ckpt_file/checkpoint_lenet-10_234.ckpt' | |||||
param_dict = load_checkpoint(ckpt_file_name) | |||||
load_param_into_net(network, param_dict) | |||||
ds_eval = generate_mnist_dataset(os.path.join(cfg.data_path, 'test'), batch_size=cfg.batch_size) | |||||
acc = model.eval(ds_eval, dataset_sink_mode=False) | |||||
LOGGER.info(TAG, "============== Accuracy: %s ==============", acc) |
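# Why the evaluation step above loads 'checkpoint_lenet-10_234.ckpt': assuming the standard MNIST
# training split of 60000 images (the same num_samples given to the RDP monitor), batch_size=256
# with drop_remainder=True yields 60000 // 256 = 234 steps per epoch, so the checkpoint written at
# the end of epoch 10 is named checkpoint_lenet-10_234.ckpt by the ModelCheckpoint callback.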
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
python lenet5_dp_model_train.py --data_path /YourDataPath --micro_batches=2
python lenet5_dp_pynative_mode.py --data_path /YourDataPath --micro_batches=2
"""
import os

@@ -30,8 +30,8 @@ from mindspore.dataset.transforms.vision import Inter
import mindspore.common.dtype as mstype
from mindarmour.diff_privacy import DPModel
from mindarmour.diff_privacy import DPOptimizerClassFactory
from mindarmour.diff_privacy import PrivacyMonitorFactory
from mindarmour.diff_privacy import DPOptimizerClassFactory
from mindarmour.utils.logger import LogUtil
from lenet5_net import LeNet5
from lenet5_config import mnist_cfg as cfg

@@ -86,8 +86,8 @@ def generate_mnist_dataset(data_path, batch_size=32, repeat_size=1,
if __name__ == "__main__":
    # This configure just can run in pynative mode.
    context.set_context(mode=context.PYNATIVE_MODE, device_target=cfg.device_target)
    network = LeNet5()
    net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
    config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,

@@ -103,34 +103,26 @@ if __name__ == "__main__":
    if cfg.micro_batches and cfg.batch_size % cfg.micro_batches != 0:
        raise ValueError("Number of micro_batches should divide evenly batch_size")
    # Create a factory class of DP optimizer
    gaussian_mech = DPOptimizerClassFactory(cfg.micro_batches)
    # Set the method of adding noise in gradients while training. Initial_noise_multiplier is suggested to be greater
    # than 1.0, otherwise the privacy budget would be huge, which means that the privacy protection effect is weak.
    # mechanisms can be 'Gaussian' or 'AdaGaussian', in which noise would be decayed with 'AdaGaussian' mechanism while
    # be constant with 'Gaussian' mechanism.
    gaussian_mech.set_mechanisms(cfg.mechanisms,
                                 norm_bound=cfg.l2_norm_bound,
                                 initial_noise_multiplier=cfg.initial_noise_multiplier)
    # Wrap the base optimizer for DP training. Momentum optimizer is suggested for LenNet5.
    net_opt = gaussian_mech.create(cfg.optimizer)(params=network.trainable_params(),
                                                  learning_rate=cfg.lr,
                                                  momentum=cfg.momentum)
    # Create a factory class of DP mechanisms, this method is adding noise in gradients while training.
    # Initial_noise_multiplier is suggested to be greater than 1.0, otherwise the privacy budget would be huge, which
    # means that the privacy protection effect is weak. Mechanisms can be 'Gaussian' or 'AdaGaussian', in which noise
    # would be decayed with 'AdaGaussian' mechanism while be constant with 'Gaussian' mechanism.
    dp_opt = DPOptimizerClassFactory(micro_batches=cfg.micro_batches)
    dp_opt.set_mechanisms(cfg.mechanisms,
                          norm_bound=cfg.l2_norm_bound,
                          initial_noise_multiplier=cfg.initial_noise_multiplier)
    net_opt = dp_opt.create('Momentum')(params=network.trainable_params(), learning_rate=cfg.lr, momentum=cfg.momentum)
    # Create a monitor for DP training. The function of the monitor is to compute and print the privacy budget(eps
    # and delta) while training.
    rdp_monitor = PrivacyMonitorFactory.create('rdp',
                                               num_samples=60000,
                                               batch_size=cfg.batch_size,
                                               initial_noise_multiplier=cfg.initial_noise_multiplier,
                                               per_print_times=50)
                                               initial_noise_multiplier=cfg.initial_noise_multiplier*cfg.l2_norm_bound,
                                               per_print_times=10)
    # Create the DP model for training.
    model = DPModel(micro_batches=cfg.micro_batches,
                    norm_clip=cfg.l2_norm_bound,
                    dp_mech=gaussian_mech.mech,
                    mech=None,
                    network=network,
                    loss_fn=net_loss,
                    optimizer=net_opt,
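# The two example scripts above wire up differential privacy in two mutually exclusive ways, which
# is what the DPModel constructor later in this patch enforces (a noise mechanism and a DPOptimizer
# cannot both be supplied, and the DPOptimizer path currently works in pynative mode only). A
# condensed sketch, assuming `network`, `net_loss` and `cfg` as defined in the scripts above:
#
#     # lenet5_dp.py: plain Momentum optimizer + noise mechanism handed to DPModel via `mech`
#     mech = MechanismsFactory().create(cfg.mechanisms, norm_bound=cfg.l2_norm_bound,
#                                       initial_noise_multiplier=cfg.initial_noise_multiplier)
#     net_opt = nn.Momentum(params=network.trainable_params(), learning_rate=cfg.lr, momentum=cfg.momentum)
#     model = DPModel(micro_batches=cfg.micro_batches, norm_clip=cfg.l2_norm_bound, mech=mech,
#                     network=network, loss_fn=net_loss, optimizer=net_opt)
#
#     # lenet5_dp_pynative_mode.py: noise lives inside a DP-wrapped optimizer, so DPModel gets mech=None
#     dp_opt = DPOptimizerClassFactory(micro_batches=cfg.micro_batches)
#     dp_opt.set_mechanisms(cfg.mechanisms, norm_bound=cfg.l2_norm_bound,
#                           initial_noise_multiplier=cfg.initial_noise_multiplier)
#     net_opt = dp_opt.create('Momentum')(params=network.trainable_params(), learning_rate=cfg.lr,
#                                         momentum=cfg.momentum)
#     model = DPModel(micro_batches=cfg.micro_batches, norm_clip=cfg.l2_norm_bound, mech=None,
#                     network=network, loss_fn=net_loss, optimizer=net_opt)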
@@ -14,8 +14,6 @@
"""
Noise Mechanisms.
"""
import numpy as np

from mindspore import Tensor
from mindspore.nn import Cell
from mindspore.ops import operations as P

@@ -24,6 +22,7 @@ from mindspore.common import dtype as mstype
from mindarmour.utils._check_param import check_param_type
from mindarmour.utils._check_param import check_value_positive
from mindarmour.utils._check_param import check_value_non_negative
from mindarmour.utils._check_param import check_param_in_range

@@ -62,7 +61,8 @@ class Mechanisms(Cell):
    """
    Basic class of noise generated mechanism.
    """

    def construct(self, shape):
    def construct(self, gradients):
        """
        Construct function.
        """

@@ -78,41 +78,47 @@ class GaussianRandom(Mechanisms):
        initial_noise_multiplier(float): Ratio of the standard deviation of
            Gaussian noise divided by the norm_bound, which will be used to
            calculate privacy spent. Default: 1.5.
        mean(float): Average value of random noise. Default: 0.0.
        seed(int): Original random seed. Default: 0.

    Returns:
        Tensor, generated noise.
        Tensor, generated noise with shape like given gradients.

    Examples:
        >>> shape = (3, 2, 4)
        >>> gradients = Tensor([0.2, 0.9], mstype.float32)
        >>> norm_bound = 1.0
        >>> initial_noise_multiplier = 1.5
        >>> initial_noise_multiplier = 0.1
        >>> net = GaussianRandom(norm_bound, initial_noise_multiplier)
        >>> res = net(shape)
        >>> res = net(gradients)
        >>> print(res)
    """

    def __init__(self, norm_bound=1.0, initial_noise_multiplier=1.5):
    def __init__(self, norm_bound=1.0, initial_noise_multiplier=1.5, mean=0.0, seed=0):
        super(GaussianRandom, self).__init__()
        self._norm_bound = check_value_positive('norm_bound', norm_bound)
        self._norm_bound = Tensor(norm_bound, mstype.float32)
        self._initial_noise_multiplier = check_value_positive('initial_noise_multiplier',
                                                              initial_noise_multiplier,)
        stddev = self._norm_bound*self._initial_noise_multiplier
        self._stddev = stddev
        self._mean = 0

    def construct(self, shape):
                                                              initial_noise_multiplier)
        self._initial_noise_multiplier = Tensor(initial_noise_multiplier, mstype.float32)
        mean = check_param_type('mean', mean, float)
        mean = check_value_non_negative('mean', mean)
        self._mean = Tensor(mean, mstype.float32)
        self._normal = P.Normal(seed=seed)

    def construct(self, gradients):
        """
        Generated Gaussian noise.

        Args:
            shape(tuple): The shape of gradients.
            gradients(Tensor): The gradients.

        Returns:
            Tensor, generated noise.
            Tensor, generated noise with shape like given gradients.
        """
        shape = check_param_type('shape', shape, tuple)
        noise = np.random.normal(self._mean, self._stddev, shape)
        return Tensor(noise, mstype.float32)
        shape = P.Shape()(gradients)
        stddev = P.Mul()(self._norm_bound, self._initial_noise_multiplier)
        noise = self._normal(shape, self._mean, stddev)
        return noise


class AdaGaussianRandom(Mechanisms):

@@ -126,54 +132,60 @@ class AdaGaussianRandom(Mechanisms):
        initial_noise_multiplier(float): Ratio of the standard deviation of
            Gaussian noise divided by the norm_bound, which will be used to
            calculate privacy spent. Default: 5.0.
        noise_decay_rate(float): Hyperparameter for controlling the noise decay.
        mean(float): Average value of random noise. Default: 0.0
        noise_decay_rate(float): Hyper parameter for controlling the noise decay.
            Default: 6e-4.
        decay_policy(str): Noise decay strategy include 'Step' and 'Time'.
            Default: 'Time'.
        seed(int): Original random seed. Default: 0.

    Returns:
        Tensor, generated noise.
        Tensor, generated noise with shape like given gradients.

    Examples:
        >>> shape = (3, 2, 4)
        >>> gradients = Tensor([0.2, 0.9], mstype.float32)
        >>> norm_bound = 1.0
        >>> initial_noise_multiplier = 0.1
        >>> noise_decay_rate = 0.5
        >>> initial_noise_multiplier = 5.0
        >>> mean = 0.0
        >>> noise_decay_rate = 6e-4
        >>> decay_policy = "Time"
        >>> net = AdaGaussianRandom(norm_bound, initial_noise_multiplier,
        >>> net = AdaGaussianRandom(norm_bound, initial_noise_multiplier, mean
        >>>                         noise_decay_rate, decay_policy)
        >>> res = net(shape)
        >>> res = net(gradients)
        >>> print(res)
    """

    def __init__(self, norm_bound=1.5, initial_noise_multiplier=5.0,
                 noise_decay_rate=6e-4, decay_policy='Time'):
    def __init__(self, norm_bound=1.5, initial_noise_multiplier=5.0, mean=0.0,
                 noise_decay_rate=6e-4, decay_policy='Time', seed=0):
        super(AdaGaussianRandom, self).__init__()
        norm_bound = check_value_positive('norm_bound', norm_bound)
        initial_noise_multiplier = check_value_positive('initial_noise_multiplier',
                                                        initial_noise_multiplier)
        initial_noise_multiplier = Tensor(np.array(initial_noise_multiplier, np.float32))
        self._norm_bound = Tensor(norm_bound, mstype.float32)
        initial_noise_multiplier = Tensor(initial_noise_multiplier, mstype.float32)
        self._initial_noise_multiplier = Parameter(initial_noise_multiplier,
                                                   name='initial_noise_multiplier')
        self._stddev = P.Mul()(self._norm_bound, self._initial_noise_multiplier)
        self._noise_multiplier = Parameter(initial_noise_multiplier,
                                           name='noise_multiplier')
        norm_bound = check_value_positive('norm_bound', norm_bound)
        self._norm_bound = Tensor(np.array(norm_bound, np.float32))
        mean = check_param_type('mean', mean, float)
        mean = check_value_non_negative('mean', mean)
        self._mean = Tensor(mean, mstype.float32)
        noise_decay_rate = check_param_type('noise_decay_rate', noise_decay_rate, float)
        check_param_in_range('noise_decay_rate', noise_decay_rate, 0.0, 1.0)
        self._noise_decay_rate = Tensor(np.array(noise_decay_rate, np.float32))
        self._noise_decay_rate = Tensor(noise_decay_rate, mstype.float32)
        if decay_policy not in ['Time', 'Step']:
            raise NameError("The decay_policy must be in ['Time', 'Step'], but "
                            "get {}".format(decay_policy))
        self._decay_policy = decay_policy
        self._mean = 0.0
        self._sub = P.Sub()
        self._mul = P.Mul()
        self._add = P.TensorAdd()
        self._div = P.Div()
        self._stddev = self._update_stddev()
        self._dtype = mstype.float32
        self._normal = P.Normal(seed=seed)
        self._assign = P.Assign()

    def _update_multiplier(self):
        """ Update multiplier. """

@@ -181,31 +193,32 @@ class AdaGaussianRandom(Mechanisms):
            temp = self._div(self._initial_noise_multiplier,
                             self._noise_multiplier)
            temp = self._add(temp, self._noise_decay_rate)
            temp = self._div(self._initial_noise_multiplier, temp)
            self._noise_multiplier = Parameter(temp, name='noise_multiplier')
            self._noise_multiplier = self._assign(self._noise_multiplier,
                                                  self._div(self._initial_noise_multiplier, temp))
        else:
            one = Tensor(1, self._dtype)
            temp = self._sub(one, self._noise_decay_rate)
            temp = self._mul(temp, self._noise_multiplier)
            self._noise_multiplier = Parameter(temp, name='noise_multiplier')
            self._noise_multiplier = self._assign(self._noise_multiplier, self._mul(temp, self._noise_multiplier))
        return self._noise_multiplier

    def _update_stddev(self):
        self._stddev = self._mul(self._noise_multiplier, self._norm_bound)
        self._stddev = self._assign(self._stddev, self._mul(self._noise_multiplier, self._norm_bound))
        return self._stddev

    def construct(self, shape):
    def construct(self, gradients):
        """
        Generate adaptive Gaussian noise.

        Args:
            shape(tuple): The shape of gradients.
            gradients(Tensor): The gradients.

        Returns:
            Tensor, generated noise.
            Tensor, generated noise with shape like given gradients.
        """
        shape = check_param_type('shape', shape, tuple)
        noise = np.random.normal(self._mean, self._stddev.asnumpy(),
                                 shape)
        self._update_multiplier()
        self._update_stddev()
        return Tensor(noise, mstype.float32)
        shape = P.Shape()(gradients)
        noise = self._normal(shape, self._mean, self._stddev)
        # pylint: disable=unused-variable
        mt = self._update_multiplier()
        # pylint: disable=unused-variable
        std = self._update_stddev()
        return noise
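# A plain numpy sketch of the noise the two mechanisms above generate (an illustrative reading of
# this patch, not MindArmour API): GaussianRandom keeps a constant standard deviation
# norm_bound * initial_noise_multiplier, while AdaGaussianRandom decays its multiplier z after each
# call, either z <- z0 / (z0 / z + alpha) for the 'Time' policy or z <- (1 - alpha) * z for 'Step',
# with stddev = z * norm_bound.
import numpy as np


def gaussian_noise(grad, norm_bound=1.0, initial_noise_multiplier=1.5, mean=0.0):
    """Constant-stddev Gaussian noise with the same shape as the gradient."""
    stddev = norm_bound * initial_noise_multiplier
    return np.random.normal(mean, stddev, grad.shape)


class AdaGaussianNoise:
    """Adaptive Gaussian noise whose multiplier decays after every call."""

    def __init__(self, norm_bound=1.5, initial_noise_multiplier=5.0, mean=0.0,
                 noise_decay_rate=6e-4, decay_policy='Time'):
        self.z0 = initial_noise_multiplier
        self.z = initial_noise_multiplier
        self.norm_bound = norm_bound
        self.mean = mean
        self.alpha = noise_decay_rate
        self.decay_policy = decay_policy

    def __call__(self, grad):
        # sample with the current stddev, then decay the multiplier for the next call
        noise = np.random.normal(self.mean, self.z * self.norm_bound, grad.shape)
        if self.decay_policy == 'Time':
            self.z = self.z0 / (self.z0 / self.z + self.alpha)
        else:  # 'Step'
            self.z = (1.0 - self.alpha) * self.z
        return noise


grad = np.array([0.2, 0.9], dtype=np.float32)
print(gaussian_noise(grad, norm_bound=1.0, initial_noise_multiplier=0.1))
print(AdaGaussianNoise(norm_bound=1.0, initial_noise_multiplier=0.1,
                       noise_decay_rate=0.5, decay_policy='Step')(grad))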
@@ -14,13 +14,37 @@
"""
Differential privacy optimizer.
"""
import mindspore as ms

from mindspore import nn
from mindspore import Tensor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype

from mindarmour.diff_privacy.mechanisms.mechanisms import MechanismsFactory
from mindarmour.utils._check_param import check_int_positive

_grad_scale = C.MultitypeFuncGraph("grad_scale")
_reciprocal = P.Reciprocal()


@_grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
    """ grad scaling """
    return grad * _reciprocal(scale)


class _TupleAdd(nn.Cell):
    def __init__(self):
        super(_TupleAdd, self).__init__()
        self.add = P.TensorAdd()
        self.hyper_map = C.HyperMap()

    def construct(self, input1, input2):
        """Add two tuple of data."""
        out = self.hyper_map(self.add, input1, input2)
        return out


class DPOptimizerClassFactory:
    """

@@ -36,9 +60,10 @@ class DPOptimizerClassFactory:
        >>> GaussianSGD = DPOptimizerClassFactory(micro_batches=2)
        >>> GaussianSGD.set_mechanisms('Gaussian', norm_bound=1.0, initial_noise_multiplier=1.5)
        >>> net_opt = GaussianSGD.create('Momentum')(params=network.trainable_params(),
        >>>                                             learning_rate=cfg.lr,
        >>>                                             momentum=cfg.momentum)
        >>>                                          learning_rate=cfg.lr,
        >>>                                          momentum=cfg.momentum)
    """

    def __init__(self, micro_batches=2):
        self._mech_factory = MechanismsFactory()
        self.mech = None

@@ -78,6 +103,7 @@ class DPOptimizerClassFactory:
        """
        Wrap original mindspore optimizer with `self._mech`.
        """

        class DPOptimizer(cls):
            """
            Initialize the DPOptimizerClass.

@@ -85,23 +111,22 @@ class DPOptimizerClassFactory:
            Returns:
                Optimizer, Optimizer class.
            """

            def __init__(self, *args, **kwargs):
                super(DPOptimizer, self).__init__(*args, **kwargs)
                self._mech = mech
                self._tuple_add = _TupleAdd()
                self._hyper_map = C.HyperMap()
                self._micro_float = Tensor(micro_batches, mstype.float32)

            def construct(self, gradients):
                """
                construct a compute flow.
                """
                g_len = len(gradients)
                gradient_noise = list(gradients)
                for i in range(g_len):
                    gradient_noise[i] = gradient_noise[i].asnumpy()
                    gradient_noise[i] = self._mech(gradient_noise[i].shape).asnumpy() + gradient_noise[i]
                    gradient_noise[i] = gradient_noise[i] / micro_batches
                    gradient_noise[i] = Tensor(gradient_noise[i], ms.float32)
                gradients = tuple(gradient_noise)
                gradients = super(DPOptimizer, self).construct(gradients)
                grad_noise = self._hyper_map(self._mech, gradients)
                grads = self._tuple_add(gradients, grad_noise)
                grads = self._hyper_map(F.partial(_grad_scale, self._micro_float), grads)
                gradients = super(DPOptimizer, self).construct(grads)
                return gradients
        return DPOptimizer
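# A plain-Python reading of what the DPOptimizer.construct defined above does before handing the
# result to the wrapped MindSpore optimizer (illustrative sketch, not the MindArmour API): each
# incoming gradient is a sum over micro-batches, so the wrapper adds mechanism noise and then
# divides by micro_batches.
import numpy as np


def dp_optimizer_preprocess(gradients, mech, micro_batches):
    """gradients: summed per-parameter gradient arrays; mech: callable returning noise shaped like its input."""
    processed = []
    for grad in gradients:
        noisy = grad + mech(grad)                 # add DP noise per parameter tensor
        processed.append(noisy / micro_batches)   # average the micro-batch sum
    return processed                              # the base optimizer (e.g. Momentum) applies these


# Example with a constant-stddev Gaussian mechanism (norm_bound=1.0, multiplier=0.1):
mech = lambda g: np.random.normal(0.0, 1.0 * 0.1, g.shape)
grads = [np.ones((2, 3), dtype=np.float32), np.ones(4, dtype=np.float32)]
print(dp_optimizer_preprocess(grads, mech, micro_batches=2))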
@@ -16,7 +16,6 @@ Differential privacy model.
"""
from easydict import EasyDict as edict

import mindspore as ms
from mindspore.train.model import Model
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel

@@ -48,21 +47,19 @@ from mindspore.nn.wrap.loss_scale import _grad_overflow
from mindspore.nn import Cell
from mindspore import ParameterTuple

from mindarmour.diff_privacy.mechanisms import mechanisms
from mindarmour.utils._check_param import check_param_type
from mindarmour.utils._check_param import check_value_positive
from mindarmour.utils._check_param import check_int_positive

GRADIENT_CLIP_TYPE = 1
grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()
_grad_scale = C.MultitypeFuncGraph("grad_scale")
_reciprocal = P.Reciprocal()


@grad_scale.register("Tensor", "Tensor")
@_grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
    """ grad scaling """
    return grad*reciprocal(scale)
    return grad * F.cast(_reciprocal(scale), F.dtype(grad))


class DPModel(Model):

@@ -72,7 +69,7 @@ class DPModel(Model):
    Args:
        micro_batches (int): The number of small batches split from an original batch. Default: 2.
        norm_clip (float): Use to clip the bound, if set 1, will retun the original data. Default: 1.0.
        dp_mech (Mechanisms): The object can generate the different type of noise. Default: None.
        mech (Mechanisms): The object can generate the different type of noise. Default: None.

    Examples:
        >>> class Net(nn.Cell):

@@ -94,32 +91,37 @@ class DPModel(Model):
        >>>
        >>> net = Net()
        >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
        >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.01, momentum=0.9)
        >>> gaussian_mech = DPOptimizerClassFactory()
        >>> gaussian_mech.set_mechanisms('Gaussian',
        >>>                              norm_bound=args.l2_norm_bound,
        >>>                              initial_noise_multiplier=args.initial_noise_multiplier)
        >>> net_opt = Momentum(params=net.trainable_params(), learning_rate=0.01, momentum=0.9)
        >>> mech = MechanismsFactory().create('Gaussian',
        >>>                                   norm_bound=args.l2_norm_bound,
        >>>                                   initial_noise_multiplier=args.initial_noise_multiplier)
        >>> model = DPModel(micro_batches=2,
        >>>                 norm_clip=1.0,
        >>>                 dp_mech=gaussian_mech.mech,
        >>>                 mech=mech,
        >>>                 network=net,
        >>>                 loss_fn=loss,
        >>>                 optimizer=optim,
        >>>                 optimizer=net_opt,
        >>>                 metrics=None)
        >>> dataset = get_dataset()
        >>> model.train(2, dataset)
    """

    def __init__(self, micro_batches=2, norm_clip=1.0, dp_mech=None, **kwargs):
    def __init__(self, micro_batches=2, norm_clip=1.0, mech=None, **kwargs):
        if micro_batches:
            self._micro_batches = check_int_positive('micro_batches', micro_batches)
        else:
            self._micro_batches = None
        norm_clip = check_param_type('norm_clip', norm_clip, float)
        self._norm_clip = check_value_positive('norm_clip', norm_clip)
        if isinstance(dp_mech, mechanisms.Mechanisms):
            self._dp_mech = dp_mech
        else:
            raise TypeError('dp mechanisms should be instance of class Mechansms, but got {}'.format(type(dp_mech)))
        float_norm_clip = check_param_type('l2_norm_clip', norm_clip, float)
        self._norm_clip = check_value_positive('l2_norm_clip', float_norm_clip)
        if mech is not None and "DPOptimizer" in kwargs['optimizer'].__class__.__name__:
            raise ValueError('DPOptimizer is not supported while mech is not None')
        if mech is None:
            if "DPOptimizer" in kwargs['optimizer'].__class__.__name__:
                if context.get_context('mode') != context.PYNATIVE_MODE:
                    raise ValueError('DPOptimizer just support pynative mode currently.')
            else:
                raise ValueError('DPModel should set mech or DPOptimizer configure, please refer to example.')
        self._mech = mech
        super(DPModel, self).__init__(**kwargs)

    def _amp_build_train_network(self, network, optimizer, loss_fn=None, level='O0', **kwargs):

@@ -179,14 +181,14 @@ class DPModel(Model):
                                                      scale_update_cell=update_cell,
                                                      micro_batches=self._micro_batches,
                                                      l2_norm_clip=self._norm_clip,
                                                      mech=self._dp_mech).set_train()
                                                      mech=self._mech).set_train()
                return network

            network = _TrainOneStepCell(network,
                                        optimizer,
                                        loss_scale,
                                        micro_batches=self._micro_batches,
                                        l2_norm_clip=self._norm_clip,
                                        mech=self._dp_mech).set_train()
                                        mech=self._mech).set_train()
            return network

    def _build_train_network(self):

@@ -244,6 +246,7 @@ class _ClipGradients(nn.Cell):
    Outputs:
        tuple[Tensor], clipped gradients.
    """

    def __init__(self):
        super(_ClipGradients, self).__init__()
        self.clip_by_norm = nn.ClipByNorm()

@@ -253,7 +256,8 @@ class _ClipGradients(nn.Cell):
        """
        construct a compute flow.
        """
        if clip_type not in (0, 1):
        # pylint: disable=consider-using-in
        if clip_type != 0 and clip_type != 1:
            return grads

        new_grads = ()

@@ -268,6 +272,18 @@ class _ClipGradients(nn.Cell):
        return new_grads


class _TupleAdd(nn.Cell):
    def __init__(self):
        super(_TupleAdd, self).__init__()
        self.add = P.TensorAdd()
        self.hyper_map = C.HyperMap()

    def construct(self, input1, input2):
        """Add two tuple of data."""
        out = self.hyper_map(self.add, input1, input2)
        return out


class _TrainOneStepWithLossScaleCell(Cell):
    r"""
    Network training with loss scaling.

@@ -347,6 +363,9 @@ class _TrainOneStepWithLossScaleCell(Cell):
        self._split = P.Split(0, self._micro_batches)
        self._clip_by_global_norm = _ClipGradients()
        self._mech = mech
        self._tuple_add = _TupleAdd()
        self._hyper_map = C.HyperMap()
        self._micro_float = Tensor(micro_batches, mstype.float32)

    def construct(self, data, label, sens=None):
        """

@@ -368,32 +387,28 @@ class _TrainOneStepWithLossScaleCell(Cell):
        weights = self.weights
        record_datas = self._split(data)
        record_labels = self._split(label)
        grads = ()
        # first index
        loss = self.network(record_datas[0], record_labels[0])
        scaling_sens_filled = C.ones_like(loss)*F.cast(scaling_sens, F.dtype(loss))
        record_grad = self.grad(self.network, weights)(record_datas[0], record_labels[0], scaling_sens_filled)
        record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm)
        grad_sum = list(record_grad)
        grad_len = len(record_grad)
        for i in range(grad_len):
            grad_sum[i] = grad_sum[i].asnumpy()
        grads = record_grad
        total_loss = loss
        for i in range(1, self._micro_batches):
            loss = self.network(record_datas[i], record_labels[i])
            scaling_sens_filled = C.ones_like(loss)*F.cast(scaling_sens, F.dtype(loss))
            record_grad = self.grad(self.network, weights)(record_datas[i], record_labels[i], scaling_sens_filled)
            record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm)
            for j in range(grad_len):
                grad_sum[j] = grad_sum[j] + record_grad[j].asnumpy()
            grads = self._tuple_add(grads, record_grad)
            total_loss = P.TensorAdd()(total_loss, loss)
        loss = P.Div()(total_loss, self._micro_float)
        for i in range(grad_len):
            grad_sum[i] = Tensor(grad_sum[i], ms.float32)
        grads = tuple(grad_sum)
        loss = self.network(data, label)
        if self._mech is not None:
            grad_noise = self._hyper_map(self._mech, grads)
            grads = self._tuple_add(grads, grad_noise)
            grads = self._hyper_map(F.partial(_grad_scale, self._micro_float), grads)
        grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
        grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads)
        # apply grad reducer on grads
        grads = self.grad_reducer(grads)
        # get the overflow buffer

@@ -474,6 +489,9 @@ class _TrainOneStepCell(Cell):
        self._split = P.Split(0, self._micro_batches)
        self._clip_by_global_norm = _ClipGradients()
        self._mech = mech
        self._tuple_add = _TupleAdd()
        self._hyper_map = C.HyperMap()
        self._micro_float = Tensor(micro_batches, mstype.float32)

    def construct(self, data, label):
        """

@@ -486,23 +504,21 @@ class _TrainOneStepCell(Cell):
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        record_grad = self.grad(self.network, weights)(record_datas[0], record_labels[0], sens)
        record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm)
        grad_sum = list(record_grad)
        grad_len = len(record_grad)
        for i in range(grad_len):
            grad_sum[i] = grad_sum[i].asnumpy()
        grads = record_grad
        total_loss = loss
        for i in range(1, self._micro_batches):
            loss = self.network(record_datas[i], record_labels[i])
            sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
            record_grad = self.grad(self.network, weights)(record_datas[i], record_labels[i], sens)
            record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm)
            for j in range(grad_len):
                grad_sum[j] = grad_sum[j] + record_grad[j].asnumpy()
        for i in range(grad_len):
            grad_sum[i] = Tensor(grad_sum[i], ms.float32)
        grads = tuple(grad_sum)
        loss = self.network(data, label)
            grads = self._tuple_add(grads, record_grad)
            total_loss = P.TensorAdd()(total_loss, loss)
        loss = P.Div()(total_loss, self._micro_float)
        if self._mech is not None:
            grad_noise = self._hyper_map(self._mech, grads)
            grads = self._tuple_add(grads, grad_noise)
            grads = self._hyper_map(F.partial(_grad_scale, self._micro_float), grads)

        if self.reducer_flag:
            # apply grad reducer on grads
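# The micro-batch training step implemented by the rewritten _TrainOneStepCell above, restated as a
# numpy sketch (illustrative only; `grad_fn` and the global-norm clipping below are stand-ins for
# the MindSpore graph pieces): split the batch, clip each micro-batch's gradients to the L2 bound,
# sum them, and, when a mechanism is attached, add noise and average before the optimizer update.
import numpy as np


def dp_train_one_step(data, label, grad_fn, micro_batches, l2_norm_clip, mech=None):
    """grad_fn(d, l) -> (loss, [gradient arrays]); mech(g) -> noise shaped like g."""
    datas = np.array_split(data, micro_batches)
    labels = np.array_split(label, micro_batches)
    grads, losses = None, []
    for d, l in zip(datas, labels):
        loss, g = grad_fn(d, l)                                 # per-micro-batch loss and gradients
        norm = np.sqrt(sum(np.sum(x ** 2) for x in g))
        scale = min(1.0, l2_norm_clip / (norm + 1e-12))
        g = [x * scale for x in g]                              # clip to the norm bound
        grads = g if grads is None else [a + b for a, b in zip(grads, g)]
        losses.append(loss)
    if mech is not None:
        grads = [(g + mech(g)) / micro_batches for g in grads]  # add DP noise, then average
    return float(np.mean(losses)), grads                        # grads then go to the optimizer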
@@ -18,7 +18,7 @@ from setuptools import setup
from setuptools.command.egg_info import egg_info
from setuptools.command.build_py import build_py

version = '0.3.0'
version = '0.5.0'

cur_dir = os.path.dirname(os.path.realpath(__file__))
pkg_dir = os.path.join(cur_dir, 'build')
@@ -17,6 +17,8 @@ different Privacy test.
import pytest

from mindspore import context
from mindspore import Tensor
from mindspore.common import dtype as mstype

from mindarmour.diff_privacy import GaussianRandom
from mindarmour.diff_privacy import AdaGaussianRandom
from mindarmour.diff_privacy import MechanismsFactory

@@ -26,13 +28,13 @@ from mindarmour.diff_privacy import MechanismsFactory
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_gaussian():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    shape = (3, 2, 4)
def test_graph_gaussian():
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    grad = Tensor([3, 2, 4], mstype.float32)
    norm_bound = 1.0
    initial_noise_multiplier = 0.1
    net = GaussianRandom(norm_bound, initial_noise_multiplier)
    res = net(shape)
    res = net(grad)
    print(res)

@@ -40,42 +42,99 @@ def test_gaussian():
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_ada_gaussian():
def test_pynative_gaussian():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    shape = (3, 2, 4)
    grad = Tensor([3, 2, 4], mstype.float32)
    norm_bound = 1.0
    initial_noise_multiplier = 0.1
    net = GaussianRandom(norm_bound, initial_noise_multiplier)
    res = net(grad)
    print(res)


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_graph_ada_gaussian():
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    grad = Tensor([3, 2, 4], mstype.float32)
    norm_bound = 1.0
    initial_noise_multiplier = 0.1
    noise_decay_rate = 0.5
    decay_policy = "Step"
    alpha = 0.5
    decay_policy = 'Step'
    net = AdaGaussianRandom(norm_bound, initial_noise_multiplier,
                            noise_decay_rate, decay_policy)
    res = net(shape)
                            noise_decay_rate=alpha, decay_policy=decay_policy)
    res = net(grad)
    print(res)


def test_factory():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    shape = (3, 2, 4)
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_graph_factory():
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    grad = Tensor([3, 2, 4], mstype.float32)
    norm_bound = 1.0
    initial_noise_multiplier = 0.1
    noise_decay_rate = 0.5
    decay_policy = "Step"
    alpha = 0.5
    decay_policy = 'Step'
    noise_mechanism = MechanismsFactory()
    noise_construct = noise_mechanism.create('Gaussian',
                                             norm_bound,
                                             initial_noise_multiplier)
    noise = noise_construct(shape)
    noise = noise_construct(grad)
    print('Gaussian noise: ', noise)
    ada_mechanism = MechanismsFactory()
    ada_noise_construct = ada_mechanism.create('AdaGaussian',
                                               norm_bound,
                                               initial_noise_multiplier,
                                               noise_decay_rate,
                                               decay_policy)
    ada_noise = ada_noise_construct(shape)
                                               noise_decay_rate=alpha,
                                               decay_policy=decay_policy)
    ada_noise = ada_noise_construct(grad)
    print('ada noise: ', ada_noise)


if __name__ == '__main__':
    # device_target can be "CPU", "GPU" or "Ascend"
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_pynative_ada_gaussian():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    grad = Tensor([3, 2, 4], mstype.float32)
    norm_bound = 1.0
    initial_noise_multiplier = 0.1
    alpha = 0.5
    decay_policy = 'Step'
    net = AdaGaussianRandom(norm_bound, initial_noise_multiplier,
                            noise_decay_rate=alpha, decay_policy=decay_policy)
    res = net(grad)
    print(res)


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_pynative_factory():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    grad = Tensor([3, 2, 4], mstype.float32)
    norm_bound = 1.0
    initial_noise_multiplier = 0.1
    alpha = 0.5
    decay_policy = 'Step'
    noise_mechanism = MechanismsFactory()
    noise_construct = noise_mechanism.create('Gaussian',
                                             norm_bound,
                                             initial_noise_multiplier)
    noise = noise_construct(grad)
    print('Gaussian noise: ', noise)
    ada_mechanism = MechanismsFactory()
    ada_noise_construct = ada_mechanism.create('AdaGaussian',
                                               norm_bound,
                                               initial_noise_multiplier,
                                               noise_decay_rate=alpha,
                                               decay_policy=decay_policy)
    ada_noise = ada_noise_construct(grad)
    print('ada noise: ', ada_noise)
@@ -21,13 +21,15 @@ from mindspore import nn
from mindspore import context
import mindspore.dataset as ds

from mindarmour.diff_privacy import DPOptimizerClassFactory
from mindarmour.diff_privacy import DPModel
from mindarmour.diff_privacy import MechanismsFactory
from mindarmour.diff_privacy import DPOptimizerClassFactory

from test_network import LeNet5


def dataset_generator(batch_size, batches):
    """mock training data."""
    data = np.random.random((batches * batch_size, 1, 32, 32)).astype(np.float32)
    label = np.random.randint(0, 10, batches * batch_size).astype(np.int32)
    for i in range(batches):

@@ -39,7 +41,7 @@ def dataset_generator(batch_size, batches):
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_dp_model():
def test_dp_model_pynative_mode():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    l2_norm_bound = 1.0
    initial_noise_multiplier = 0.01

@@ -47,21 +49,50 @@ def test_dp_model():
    batch_size = 32
    batches = 128
    epochs = 1
    micro_batches = 2
    loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
    factory_opt = DPOptimizerClassFactory(micro_batches=micro_batches)
    factory_opt.set_mechanisms('Gaussian',
                               norm_bound=l2_norm_bound,
                               initial_noise_multiplier=initial_noise_multiplier)
    net_opt = factory_opt.create('Momentum')(network.trainable_params(), learning_rate=0.1, momentum=0.9)
    model = DPModel(micro_batches=micro_batches,
                    norm_clip=l2_norm_bound,
                    mech=None,
                    network=network,
                    loss_fn=loss,
                    optimizer=net_opt,
                    metrics=None)
    ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches), ['data', 'label'])
    ms_ds.set_dataset_size(batch_size * batches)
    model.train(epochs, ms_ds, dataset_sink_mode=False)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_dp_model_with_graph_mode():
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    l2_norm_bound = 1.0
    initial_noise_multiplier = 0.01
    network = LeNet5()
    batch_size = 32
    batches = 128
    epochs = 1
    loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
    gaussian_mech = DPOptimizerClassFactory(micro_batches=2)
    gaussian_mech.set_mechanisms('Gaussian',
                                 norm_bound=l2_norm_bound,
                                 initial_noise_multiplier=initial_noise_multiplier)
    net_opt = gaussian_mech.create('SGD')(params=network.trainable_params(),
                                          learning_rate=0.1,
                                          momentum=0.9)
    mech = MechanismsFactory().create('Gaussian',
                                      norm_bound=l2_norm_bound,
                                      initial_noise_multiplier=initial_noise_multiplier)
    net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9)
    model = DPModel(micro_batches=2,
                    norm_clip=l2_norm_bound,
                    dp_mech=gaussian_mech.mech,
                    mech=mech,
                    network=network,
                    loss_fn=loss,
                    optimizer=net_opt,
                    metrics=None)
    ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches), ['data', 'label'])
    ms_ds.set_dataset_size(batch_size * batches)
    model.train(epochs, ms_ds)
    model.train(epochs, ms_ds, dataset_sink_mode=False)
@@ -21,6 +21,7 @@ from mindarmour.diff_privacy import DPOptimizerClassFactory
from test_network import LeNet5


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training