Merge pull request !33 from jxlang910/master
tags/v0.5.0-beta
@@ -19,14 +19,22 @@ network config setting, will be used in train.py
 from easydict import EasyDict as edict
 mnist_cfg = edict({
-    'num_classes': 10,
-    'lr': 0.01,
-    'momentum': 0.9,
-    'epoch_size': 10,
-    'batch_size': 256,
-    'buffer_size': 1000,
-    'image_height': 32,
-    'image_width': 32,
-    'save_checkpoint_steps': 234,
-    'keep_checkpoint_max': 10,
+    'num_classes': 10,  # the number of classes in the model's output
+    'lr': 0.01,  # the learning rate of the model's optimizer
+    'momentum': 0.9,  # the momentum value of the model's optimizer
+    'epoch_size': 10,  # training epochs
+    'batch_size': 256,  # batch size for training
+    'image_height': 32,  # the height of training samples
+    'image_width': 32,  # the width of training samples
+    'save_checkpoint_steps': 234,  # the interval, in steps, at which the model's checkpoint file is saved
+    'keep_checkpoint_max': 10,  # the maximum number of checkpoint files to keep
+    'device_target': 'Ascend',  # device used for training
+    'data_path': './MNIST_unzip',  # the path of the training and testing datasets
+    'dataset_sink_mode': False,  # whether to deliver all training data to the device at once
+    'micro_batches': 32,  # the number of small batches an original batch is split into
+    'l2_norm_bound': 1.0,  # the clipping bound for the gradients of the model's training parameters
+    'initial_noise_multiplier': 1.5,  # the initial multiplication coefficient of the noise added to
+                                      # training parameters' gradients
+    'mechanisms': 'AdaGaussian',  # the method of adding noise to gradients while training
+    'optimizer': 'Momentum'  # the base optimizer used for differential privacy training
 })
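
These keys feed the DP pipeline directly. In particular, batch_size must be divisible by micro_batches, since every original batch is split into micro_batches smaller batches for gradient clipping. A minimal, MindSpore-free sketch of that constraint, using the values from the config above:

    from easydict import EasyDict as edict

    cfg = edict({'batch_size': 256, 'micro_batches': 32})

    # Each original batch is split into `micro_batches` smaller batches,
    # so the split must come out even; the training script below raises
    # a ValueError otherwise.
    assert cfg.batch_size % cfg.micro_batches == 0
    print(cfg.batch_size // cfg.micro_batches)  # 8 samples per micro-batch
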
@@ -15,7 +15,6 @@
 python lenet5_dp_model_train.py --data_path /YourDataPath --micro_batches=2
 """
 import os
-import argparse
 import mindspore.nn as nn
 from mindspore import context
@@ -87,21 +86,7 @@ def generate_mnist_dataset(data_path, batch_size=32, repeat_size=1,
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
-    parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
-                        help='device where the code will be implemented (default: Ascend)')
-    parser.add_argument('--data_path', type=str, default="./MNIST_unzip",
-                        help='path where the dataset is saved')
-    parser.add_argument('--dataset_sink_mode', type=bool, default=False, help='dataset_sink_mode is False or True')
-    parser.add_argument('--micro_batches', type=int, default=32,
-                        help='optional, if use differential privacy, need to set micro_batches')
-    parser.add_argument('--l2_norm_bound', type=float, default=1.0,
-                        help='optional, if use differential privacy, need to set l2_norm_bound')
-    parser.add_argument('--initial_noise_multiplier', type=float, default=1.5,
-                        help='optional, if use differential privacy, need to set initial_noise_multiplier')
-    args = parser.parse_args()
-    context.set_context(mode=context.PYNATIVE_MODE, device_target=args.device_target)
+    context.set_context(mode=context.PYNATIVE_MODE, device_target=cfg.device_target)
     network = LeNet5()
     net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
@@ -111,27 +96,41 @@ if __name__ == "__main__":
                                  directory='./trained_ckpt_file/',
                                  config=config_ck)
-    ds_train = generate_mnist_dataset(os.path.join(args.data_path, "train"),
+    # get the training dataset
+    ds_train = generate_mnist_dataset(os.path.join(cfg.data_path, "train"),
                                       cfg.batch_size,
                                       cfg.epoch_size)
-    if args.micro_batches and cfg.batch_size % args.micro_batches != 0:
+    if cfg.micro_batches and cfg.batch_size % cfg.micro_batches != 0:
         raise ValueError("Number of micro_batches should divide evenly batch_size")
-    gaussian_mech = DPOptimizerClassFactory(args.micro_batches)
-    gaussian_mech.set_mechanisms('Gaussian',
-                                 norm_bound=args.l2_norm_bound,
-                                 initial_noise_multiplier=args.initial_noise_multiplier)
-    net_opt = gaussian_mech.create('Momentum')(params=network.trainable_params(),
-                                               learning_rate=cfg.lr,
-                                               momentum=cfg.momentum)
+    # create a factory class of DP optimizers
+    gaussian_mech = DPOptimizerClassFactory(cfg.micro_batches)
+    # Set the method of adding noise to gradients while training. initial_noise_multiplier should be
+    # greater than 1.0; otherwise the privacy budget would be huge, which means the privacy protection
+    # effect is weak. mechanisms can be 'Gaussian' or 'AdaGaussian': noise decays during training with
+    # the 'AdaGaussian' mechanism, while it stays constant with the 'Gaussian' mechanism.
+    gaussian_mech.set_mechanisms(cfg.mechanisms,
+                                 norm_bound=cfg.l2_norm_bound,
+                                 initial_noise_multiplier=cfg.initial_noise_multiplier)
+    # Wrap the base optimizer for DP training. The Momentum optimizer is suggested for LeNet5.
+    net_opt = gaussian_mech.create(cfg.optimizer)(params=network.trainable_params(),
+                                                  learning_rate=cfg.lr,
+                                                  momentum=cfg.momentum)
+    # Create a monitor for DP training. The monitor computes and prints the privacy budget (eps
+    # and delta) while training.
     rdp_monitor = PrivacyMonitorFactory.create('rdp',
                                                num_samples=60000,
                                                batch_size=cfg.batch_size,
-                                               initial_noise_multiplier=args.initial_noise_multiplier*
-                                               args.l2_norm_bound,
-                                               per_print_times=10)
-    model = DPModel(micro_batches=args.micro_batches,
-                    norm_clip=args.l2_norm_bound,
+                                               initial_noise_multiplier=cfg.initial_noise_multiplier*
+                                               cfg.l2_norm_bound,
+                                               per_print_times=50)
+    # create the DP model for training
+    model = DPModel(micro_batches=cfg.micro_batches,
+                    norm_clip=cfg.l2_norm_bound,
                     dp_mech=gaussian_mech.mech,
                     network=network,
                     loss_fn=net_loss,
@@ -140,12 +139,12 @@ if __name__ == "__main__":
     LOGGER.info(TAG, "============== Starting Training ==============")
     model.train(cfg['epoch_size'], ds_train, callbacks=[ckpoint_cb, LossMonitor(), rdp_monitor],
-                dataset_sink_mode=args.dataset_sink_mode)
+                dataset_sink_mode=cfg.dataset_sink_mode)
     LOGGER.info(TAG, "============== Starting Testing ==============")
     ckpt_file_name = 'trained_ckpt_file/checkpoint_lenet-10_234.ckpt'
     param_dict = load_checkpoint(ckpt_file_name)
     load_param_into_net(network, param_dict)
-    ds_eval = generate_mnist_dataset(os.path.join(args.data_path, 'test'), batch_size=cfg.batch_size)
+    ds_eval = generate_mnist_dataset(os.path.join(cfg.data_path, 'test'), batch_size=cfg.batch_size)
     acc = model.eval(ds_eval, dataset_sink_mode=False)
     LOGGER.info(TAG, "============== Accuracy: %s ==============", acc)
@@ -24,6 +24,7 @@ from mindspore.common import dtype as mstype
 from mindarmour.utils._check_param import check_param_type
 from mindarmour.utils._check_param import check_value_positive
+from mindarmour.utils._check_param import check_param_in_range
 class MechanismsFactory:
@@ -37,7 +38,8 @@ class MechanismsFactory:
         """
         Args:
             policy(str): Noise generated strategy, could be 'Gaussian' or
-                'AdaGaussian'. Default: 'AdaGaussian'.
+                'AdaGaussian'. Noise decays during training with the 'AdaGaussian'
+                mechanism and stays constant with the 'Gaussian' mechanism. Default: 'AdaGaussian'.
             args(Union[float, str]): Parameters used for creating noise
                 mechanisms.
             kwargs(Union[float, str]): Parameters used for creating noise
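
For reference, the factory is used the way the updated tests at the end of this diff use it: extra arguments are passed through positionally to the chosen mechanism. A minimal sketch (the import path is an assumption, not shown in this diff):

    from mindarmour.diff_privacy import MechanismsFactory  # assumed import path

    factory = MechanismsFactory()
    # constant noise: std stays at initial_noise_multiplier * norm_bound
    gaussian = factory.create('Gaussian', 1.0, 0.1)
    # decaying noise: the multiplier shrinks every step, per decay_policy
    ada_gaussian = factory.create('AdaGaussian', 1.0, 0.1, 0.5, 'Step')
    noise = ada_gaussian((3, 2, 4))  # generate noise with the given shape
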
@@ -115,7 +117,8 @@ class GaussianRandom(Mechanisms):
 class AdaGaussianRandom(Mechanisms):
     """
-    Adaptive Gaussian noise generated mechanism.
+    Adaptive Gaussian noise generation mechanism. Noise decays during training;
+    the decay mode can be 'Time' or 'Step'.
     Args:
         norm_bound(float): Clipping bound for the l2 norm of the gradients.
@@ -123,7 +126,7 @@ class AdaGaussianRandom(Mechanisms):
         initial_noise_multiplier(float): Ratio of the standard deviation of
             Gaussian noise divided by the norm_bound, which will be used to
             calculate privacy spent. Default: 5.0.
-        alpha(float): Hyperparameter for controlling the noise decay.
+        noise_decay_rate(float): Hyperparameter for controlling the noise decay.
             Default: 6e-4.
         decay_policy(str): Noise decay strategy include 'Step' and 'Time'.
             Default: 'Time'.
@@ -135,16 +138,16 @@ class AdaGaussianRandom(Mechanisms):
         >>> shape = (3, 2, 4)
         >>> norm_bound = 1.0
         >>> initial_noise_multiplier = 0.1
-        >>> alpha = 0.5
+        >>> noise_decay_rate = 0.5
         >>> decay_policy = "Time"
         >>> net = AdaGaussianRandom(norm_bound, initial_noise_multiplier,
-        >>>                         alpha, decay_policy)
+        >>>                         noise_decay_rate, decay_policy)
         >>> res = net(shape)
         >>> print(res)
     """
     def __init__(self, norm_bound=1.5, initial_noise_multiplier=5.0,
-                 alpha=6e-4, decay_policy='Time'):
+                 noise_decay_rate=6e-4, decay_policy='Time'):
         super(AdaGaussianRandom, self).__init__()
         initial_noise_multiplier = check_value_positive('initial_noise_multiplier',
                                                         initial_noise_multiplier)
@@ -156,8 +159,9 @@ class AdaGaussianRandom(Mechanisms):
         norm_bound = check_value_positive('norm_bound', norm_bound)
         self._norm_bound = Tensor(np.array(norm_bound, np.float32))
-        alpha = check_param_type('alpha', alpha, float)
-        self._alpha = Tensor(np.array(alpha, np.float32))
+        noise_decay_rate = check_param_type('noise_decay_rate', noise_decay_rate, float)
+        check_param_in_range('noise_decay_rate', noise_decay_rate, 0.0, 1.0)
+        self._noise_decay_rate = Tensor(np.array(noise_decay_rate, np.float32))
         if decay_policy not in ['Time', 'Step']:
             raise NameError("The decay_policy must be in ['Time', 'Step'], but "
@@ -176,12 +180,12 @@ class AdaGaussianRandom(Mechanisms):
         if self._decay_policy == 'Time':
             temp = self._div(self._initial_noise_multiplier,
                              self._noise_multiplier)
-            temp = self._add(temp, self._alpha)
+            temp = self._add(temp, self._noise_decay_rate)
             temp = self._div(self._initial_noise_multiplier, temp)
             self._noise_multiplier = Parameter(temp, name='noise_multiplier')
         else:
             one = Tensor(1, self._dtype)
-            temp = self._sub(one, self._alpha)
+            temp = self._sub(one, self._noise_decay_rate)
             temp = self._mul(temp, self._noise_multiplier)
             self._noise_multiplier = Parameter(temp, name='noise_multiplier')
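
The renamed noise_decay_rate plays a different role in each branch of the update above. A plain-Python sketch of the two recurrences (mirroring the Tensor arithmetic, not the actual MindSpore ops):

    def next_multiplier(current, initial, noise_decay_rate, decay_policy):
        """One noise-multiplier update, mirroring AdaGaussianRandom above."""
        if decay_policy == 'Time':
            # sigma_t = sigma_0 / (sigma_0 / sigma_(t-1) + k)
            return initial / (initial / current + noise_decay_rate)
        # 'Step': sigma_t = (1 - k) * sigma_(t-1)
        return (1 - noise_decay_rate) * current

    sigma = initial = 5.0
    for _ in range(3):
        sigma = next_multiplier(sigma, initial, 6e-4, 'Time')
        print(sigma)  # decays slowly: ~4.997, ~4.994, ~4.991
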
@@ -20,7 +20,7 @@ from mindspore.train.callback import Callback
 from mindarmour.utils.logger import LogUtil
 from mindarmour.utils._check_param import check_int_positive, \
-    check_value_positive
+    check_value_positive, check_param_in_range, check_param_type
 LOGGER = LogUtil.get_instance()
 TAG = 'DP monitor'
@@ -40,7 +40,8 @@ class PrivacyMonitorFactory:
     Create a privacy monitor class.
     Args:
-        policy (str): Monitor policy, 'rdp' is supported by now.
+        policy (str): Monitor policy, only 'rdp' is supported for now. RDP means Rényi
+            differential privacy, which is computed based on the Rényi divergence.
         args (Union[int, float, numpy.ndarray, list, str]): Parameters
             used for creating a privacy monitor.
         kwargs (Union[int, float, numpy.ndarray, list, str]): Keyword
@@ -70,7 +71,7 @@ class RDPMonitor(Callback):
         num_samples (int): The total number of samples in training data sets.
         batch_size (int): The number of samples in a batch while training.
         initial_noise_multiplier (Union[float, int]): The initial
-            multiplier of added noise. Default: 1.5.
+            multiplier of the noise added to training parameters' gradients. Default: 1.5.
         max_eps (Union[float, int, None]): The maximum acceptable epsilon
             budget for DP training. Default: 10.0.
         target_delta (Union[float, int, None]): Target delta budget for DP
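
As in the training script earlier in this diff, the monitor is created through the factory and then passed to model.train() as a callback, printing eps and delta every per_print_times steps. A sketch with the example's values (the import path is an assumption):

    from mindarmour.diff_privacy import PrivacyMonitorFactory  # assumed import path

    rdp_monitor = PrivacyMonitorFactory.create('rdp',
                                               num_samples=60000,  # size of the MNIST training set
                                               batch_size=256,
                                               initial_noise_multiplier=1.5 * 1.0,  # multiplier * l2_norm_bound
                                               per_print_times=50)
    # model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(), rdp_monitor], ...)
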
@@ -137,11 +138,8 @@ class RDPMonitor(Callback):
             LOGGER.error(TAG, msg)
             raise ValueError(msg)
         if noise_decay_rate is not None:
-            check_value_positive('noise_decay_rate', noise_decay_rate)
-            if noise_decay_rate >= 1:
-                msg = 'Noise decay rate must be less than 1'
-                LOGGER.error(TAG, msg)
-                raise ValueError(msg)
+            noise_decay_rate = check_param_type('noise_decay_rate', noise_decay_rate, float)
+            check_param_in_range('noise_decay_rate', noise_decay_rate, 0.0, 1.0)
         check_int_positive('per_print_times', per_print_times)
         self._total_echo_privacy = None
@@ -27,7 +27,7 @@ class DPOptimizerClassFactory:
     Factory class of Optimizer.
     Args:
-        micro_batches (int): The number of small batches split from an origianl batch. Default: 2.
+        micro_batches (int): The number of small batches split from an original batch. Default: 2.
     Returns:
         Optimizer, Optimizer class
@@ -70,7 +70,7 @@ class DPModel(Model):
     This class is overload mindspore.train.model.Model.
     Args:
-        micro_batches (int): The number of small batches split from an origianl batch. Default: 2.
+        micro_batches (int): The number of small batches split from an original batch. Default: 2.
         norm_clip (float): Use to clip the bound, if set 1, will retun the original data. Default: 1.0.
         dp_mech (Mechanisms): The object can generate the different type of noise. Default: None.
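
The factory and DPModel are wired together as in the training script above. A condensed sketch (network and net_loss are assumed to be defined as in that script, and the optimizer keyword follows mindspore.train.Model; treat both as assumptions):

    gaussian_mech = DPOptimizerClassFactory(micro_batches=32)
    gaussian_mech.set_mechanisms('AdaGaussian',
                                 norm_bound=1.0,
                                 initial_noise_multiplier=1.5)
    net_opt = gaussian_mech.create('Momentum')(params=network.trainable_params(),
                                               learning_rate=0.01,
                                               momentum=0.9)
    model = DPModel(micro_batches=32,   # must divide batch_size evenly
                    norm_clip=1.0,      # gradient clipping bound
                    dp_mech=gaussian_mech.mech,
                    network=network,
                    loss_fn=net_loss,
                    optimizer=net_opt)  # keyword assumed from mindspore's Model API
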
@@ -45,10 +45,10 @@ def test_ada_gaussian():
     shape = (3, 2, 4)
     norm_bound = 1.0
     initial_noise_multiplier = 0.1
-    alpha = 0.5
+    noise_decay_rate = 0.5
     decay_policy = "Step"
     net = AdaGaussianRandom(norm_bound, initial_noise_multiplier,
-                            alpha, decay_policy)
+                            noise_decay_rate, decay_policy)
     res = net(shape)
     print(res)
@@ -58,7 +58,7 @@ def test_factory():
     shape = (3, 2, 4)
     norm_bound = 1.0
     initial_noise_multiplier = 0.1
-    alpha = 0.5
+    noise_decay_rate = 0.5
     decay_policy = "Step"
     noise_mechanism = MechanismsFactory()
     noise_construct = noise_mechanism.create('Gaussian',
@@ -70,7 +70,7 @@ def test_factory():
     ada_noise_construct = ada_mechanism.create('AdaGaussian',
                                                norm_bound,
                                                initial_noise_multiplier,
-                                               alpha,
+                                               noise_decay_rate,
                                                decay_policy)
     ada_noise = ada_noise_construct(shape)
     print('ada noise: ', ada_noise)
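
Conceptually, both mechanisms draw zero-mean Gaussian noise with standard deviation noise_multiplier * norm_bound; they differ only in how the multiplier evolves over steps. A numpy-only emulation for intuition (not the MindSpore implementation):

    import numpy as np

    def emulated_noise(shape, norm_bound, multiplier, rng=np.random.default_rng(0)):
        # both mechanisms sample N(0, (multiplier * norm_bound)^2) noise of the given shape
        return rng.normal(0.0, multiplier * norm_bound, size=shape)

    mult = 0.1
    for step in range(3):
        noise = emulated_noise((3, 2, 4), norm_bound=1.0, multiplier=mult)
        print(step, round(noise.std(), 4))  # std shrinks as the multiplier decays
        mult *= 1 - 0.5  # 'Step' decay with noise_decay_rate = 0.5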