@@ -19,14 +19,22 @@ network config setting, will be used in train.py
 from easydict import EasyDict as edict

 mnist_cfg = edict({
-    'num_classes': 10,
-    'lr': 0.01,
-    'momentum': 0.9,
-    'epoch_size': 10,
-    'batch_size': 256,
-    'buffer_size': 1000,
-    'image_height': 32,
-    'image_width': 32,
-    'save_checkpoint_steps': 234,
-    'keep_checkpoint_max': 10,
+    'num_classes': 10,  # the number of classes in the model's output
+    'lr': 0.01,  # the learning rate of the model's optimizer
+    'momentum': 0.9,  # the momentum value of the model's optimizer
+    'epoch_size': 10,  # number of training epochs
+    'batch_size': 256,  # batch size for training
+    'image_height': 32,  # the height of training samples
+    'image_width': 32,  # the width of training samples
+    'save_checkpoint_steps': 234,  # the interval, in steps, at which model checkpoint files are saved
+    'keep_checkpoint_max': 10,  # the maximum number of checkpoint files to keep
+    'device_target': 'Ascend',  # the device used for training
+    'data_path': './MNIST_unzip',  # the path of the training and testing data sets
+    'dataset_sink_mode': False,  # whether to deliver all training data to the device at once
+    'micro_batches': 32,  # the number of small batches split from an original batch
+    'l2_norm_bound': 1.0,  # the clipping bound for the l2 norm of the gradients of the model's training parameters
+    'initial_noise_multiplier': 1.5,  # the initial multiplication coefficient of the noise added to the gradients
+                                      # of the training parameters
+    'mechanisms': 'AdaGaussian',  # the method of adding noise to gradients during training
+    'optimizer': 'Momentum'  # the base optimizer used for differential privacy training
 })
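
A quick sanity check of the new config-driven setup (a minimal sketch; the module name lenet5_config is an assumption, since the diff does not show the config file's name):

    # Sketch only: assumes the config above is importable from a module named lenet5_config.
    from lenet5_config import mnist_cfg as cfg

    # micro_batches must divide batch_size evenly, as enforced in the training script below.
    assert cfg.batch_size % cfg.micro_batches == 0
    # initial_noise_multiplier is suggested to be greater than 1.0 for meaningful privacy protection.
    assert cfg.initial_noise_multiplier > 1.0
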
@@ -15,7 +15,6 @@
 python lenet5_dp_model_train.py --data_path /YourDataPath --micro_batches=2
 """
 import os
-import argparse

 import mindspore.nn as nn
 from mindspore import context
@@ -87,21 +86,7 @@ def generate_mnist_dataset(data_path, batch_size=32, repeat_size=1,
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
-    parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
-                        help='device where the code will be implemented (default: Ascend)')
-    parser.add_argument('--data_path', type=str, default="./MNIST_unzip",
-                        help='path where the dataset is saved')
-    parser.add_argument('--dataset_sink_mode', type=bool, default=False, help='dataset_sink_mode is False or True')
-    parser.add_argument('--micro_batches', type=int, default=32,
-                        help='optional, if use differential privacy, need to set micro_batches')
-    parser.add_argument('--l2_norm_bound', type=float, default=1.0,
-                        help='optional, if use differential privacy, need to set l2_norm_bound')
-    parser.add_argument('--initial_noise_multiplier', type=float, default=1.5,
-                        help='optional, if use differential privacy, need to set initial_noise_multiplier')
-    args = parser.parse_args()
-    context.set_context(mode=context.PYNATIVE_MODE, device_target=args.device_target)
+    context.set_context(mode=context.PYNATIVE_MODE, device_target=cfg.device_target)
     network = LeNet5()
     net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
@@ -111,27 +96,41 @@ if __name__ == "__main__":
                                 directory='./trained_ckpt_file/',
                                 config=config_ck)
-    ds_train = generate_mnist_dataset(os.path.join(args.data_path, "train"),
+    # get the training dataset
+    ds_train = generate_mnist_dataset(os.path.join(cfg.data_path, "train"),
                                       cfg.batch_size,
                                       cfg.epoch_size)
-    if args.micro_batches and cfg.batch_size % args.micro_batches != 0:
+    if cfg.micro_batches and cfg.batch_size % cfg.micro_batches != 0:
         raise ValueError("Number of micro_batches should divide evenly batch_size")
-    gaussian_mech = DPOptimizerClassFactory(args.micro_batches)
-    gaussian_mech.set_mechanisms('Gaussian',
-                                 norm_bound=args.l2_norm_bound,
-                                 initial_noise_multiplier=args.initial_noise_multiplier)
-    net_opt = gaussian_mech.create('Momentum')(params=network.trainable_params(),
-                                               learning_rate=cfg.lr,
-                                               momentum=cfg.momentum)
+    # Create a factory class of the DP optimizer.
+    gaussian_mech = DPOptimizerClassFactory(cfg.micro_batches)
+    # Set the mechanism for adding noise to gradients during training. initial_noise_multiplier should be greater
+    # than 1.0; otherwise the privacy budget grows very large, which means the privacy protection is weak.
+    # mechanisms can be 'Gaussian' or 'AdaGaussian': the noise decays during training with the 'AdaGaussian'
+    # mechanism, while it stays constant with the 'Gaussian' mechanism.
+    gaussian_mech.set_mechanisms(cfg.mechanisms,
+                                 norm_bound=cfg.l2_norm_bound,
+                                 initial_noise_multiplier=cfg.initial_noise_multiplier)
+    # Wrap the base optimizer for DP training. The Momentum optimizer is suggested for LeNet5.
+    net_opt = gaussian_mech.create(cfg.optimizer)(params=network.trainable_params(),
                                                   learning_rate=cfg.lr,
                                                   momentum=cfg.momentum)
+    # Create a monitor for DP training. The monitor computes and prints the privacy budget (eps and delta)
+    # during training.
     rdp_monitor = PrivacyMonitorFactory.create('rdp',
                                                num_samples=60000,
                                                batch_size=cfg.batch_size,
-                                               initial_noise_multiplier=args.initial_noise_multiplier*
-                                               args.l2_norm_bound,
-                                               per_print_times=10)
-    model = DPModel(micro_batches=args.micro_batches,
-                    norm_clip=args.l2_norm_bound,
+                                               initial_noise_multiplier=cfg.initial_noise_multiplier*
+                                               cfg.l2_norm_bound,
+                                               per_print_times=50)
+    # Create the DP model for training.
+    model = DPModel(micro_batches=cfg.micro_batches,
+                    norm_clip=cfg.l2_norm_bound,
                     dp_mech=gaussian_mech.mech,
                     network=network,
                     loss_fn=net_loss,
@@ -140,12 +139,12 @@ if __name__ == "__main__":
     LOGGER.info(TAG, "============== Starting Training ==============")
     model.train(cfg['epoch_size'], ds_train, callbacks=[ckpoint_cb, LossMonitor(), rdp_monitor],
-                dataset_sink_mode=args.dataset_sink_mode)
+                dataset_sink_mode=cfg.dataset_sink_mode)
     LOGGER.info(TAG, "============== Starting Testing ==============")
     ckpt_file_name = 'trained_ckpt_file/checkpoint_lenet-10_234.ckpt'
     param_dict = load_checkpoint(ckpt_file_name)
     load_param_into_net(network, param_dict)
-    ds_eval = generate_mnist_dataset(os.path.join(args.data_path, 'test'), batch_size=cfg.batch_size)
+    ds_eval = generate_mnist_dataset(os.path.join(cfg.data_path, 'test'), batch_size=cfg.batch_size)
     acc = model.eval(ds_eval, dataset_sink_mode=False)
     LOGGER.info(TAG, "============== Accuracy: %s ==============", acc)
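
The hard-coded checkpoint name checkpoint_lenet-10_234.ckpt used for testing follows from the config above: with the 60000 MNIST training samples passed to the monitor and batch_size 256, one epoch is 234 steps, which matches save_checkpoint_steps. A short sketch of that arithmetic:

    # Values taken from mnist_cfg and from the num_samples passed to the RDP monitor.
    num_samples = 60000
    batch_size = 256
    epoch_size = 10

    steps_per_epoch = num_samples // batch_size      # 234
    print('checkpoint_lenet-%d_%d.ckpt' % (epoch_size, steps_per_epoch))
    # -> checkpoint_lenet-10_234.ckpt, the file loaded in the testing step above
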
@@ -24,6 +24,7 @@ from mindspore.common import dtype as mstype
 from mindarmour.utils._check_param import check_param_type
 from mindarmour.utils._check_param import check_value_positive
+from mindarmour.utils._check_param import check_param_in_range


 class MechanismsFactory:
@@ -37,7 +38,8 @@ class MechanismsFactory:
         """
         Args:
             policy(str): Noise generated strategy, could be 'Gaussian' or
-                'AdaGaussian'. Default: 'AdaGaussian'.
+                'AdaGaussian'. Noise decays during training with the 'AdaGaussian' mechanism,
+                while it stays constant with the 'Gaussian' mechanism. Default: 'AdaGaussian'.
             args(Union[float, str]): Parameters used for creating noise
                 mechanisms.
             kwargs(Union[float, str]): Parameters used for creating noise
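
For reference, creating a mechanism through the factory looks like the calls in the test hunks at the end of this diff (a sketch only; the import path is an assumption, and the positional arguments mirror the tests):

    # Hypothetical import path; the tests in this diff only show the class name.
    from mindarmour.diff_privacy import MechanismsFactory

    ada_mechanism = MechanismsFactory()
    # With 'AdaGaussian' the noise decays over training; with 'Gaussian' it stays constant.
    ada_noise_construct = ada_mechanism.create('AdaGaussian', 1.0, 1.5, 6e-4, 'Time')
    noise = ada_noise_construct((3, 2, 4))
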
@@ -115,7 +117,8 @@ class GaussianRandom(Mechanisms):
 class AdaGaussianRandom(Mechanisms):
     """
-    Adaptive Gaussian noise generated mechanism.
+    Adaptive Gaussian noise generated mechanism. The noise decays as training proceeds.
+    The decay mode can be 'Time' or 'Step'.

     Args:
         norm_bound(float): Clipping bound for the l2 norm of the gradients.
@@ -123,7 +126,7 @@ class AdaGaussianRandom(Mechanisms):
         initial_noise_multiplier(float): Ratio of the standard deviation of
             Gaussian noise divided by the norm_bound, which will be used to
             calculate privacy spent. Default: 5.0.
-        alpha(float): Hyperparameter for controlling the noise decay.
+        noise_decay_rate(float): Hyperparameter for controlling the noise decay.
             Default: 6e-4.
         decay_policy(str): Noise decay strategy include 'Step' and 'Time'.
             Default: 'Time'.
@@ -135,16 +138,16 @@ class AdaGaussianRandom(Mechanisms):
         >>> shape = (3, 2, 4)
         >>> norm_bound = 1.0
         >>> initial_noise_multiplier = 0.1
-        >>> alpha = 0.5
+        >>> noise_decay_rate = 0.5
         >>> decay_policy = "Time"
         >>> net = AdaGaussianRandom(norm_bound, initial_noise_multiplier,
-        >>>                         alpha, decay_policy)
+        >>>                         noise_decay_rate, decay_policy)
         >>> res = net(shape)
         >>> print(res)
         """

     def __init__(self, norm_bound=1.5, initial_noise_multiplier=5.0,
-                 alpha=6e-4, decay_policy='Time'):
+                 noise_decay_rate=6e-4, decay_policy='Time'):
         super(AdaGaussianRandom, self).__init__()
         initial_noise_multiplier = check_value_positive('initial_noise_multiplier',
                                                         initial_noise_multiplier)
@@ -156,8 +159,9 @@ class AdaGaussianRandom(Mechanisms):
         norm_bound = check_value_positive('norm_bound', norm_bound)
         self._norm_bound = Tensor(np.array(norm_bound, np.float32))
-        alpha = check_param_type('alpha', alpha, float)
-        self._alpha = Tensor(np.array(alpha, np.float32))
+        noise_decay_rate = check_param_type('noise_decay_rate', noise_decay_rate, float)
+        check_param_in_range('noise_decay_rate', noise_decay_rate, 0.0, 1.0)
+        self._noise_decay_rate = Tensor(np.array(noise_decay_rate, np.float32))
         if decay_policy not in ['Time', 'Step']:
             raise NameError("The decay_policy must be in ['Time', 'Step'], but "
@@ -176,12 +180,12 @@ class AdaGaussianRandom(Mechanisms):
         if self._decay_policy == 'Time':
             temp = self._div(self._initial_noise_multiplier,
                              self._noise_multiplier)
-            temp = self._add(temp, self._alpha)
+            temp = self._add(temp, self._noise_decay_rate)
             temp = self._div(self._initial_noise_multiplier, temp)
             self._noise_multiplier = Parameter(temp, name='noise_multiplier')
         else:
             one = Tensor(1, self._dtype)
-            temp = self._sub(one, self._alpha)
+            temp = self._sub(one, self._noise_decay_rate)
             temp = self._mul(temp, self._noise_multiplier)
             self._noise_multiplier = Parameter(temp, name='noise_multiplier')
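
Read as scalar recurrences, the two branches above are: 'Time' mode sets sigma_t = sigma_0 / (sigma_0 / sigma_{t-1} + k), and 'Step' mode sets sigma_t = (1 - k) * sigma_{t-1}, where sigma_0 is initial_noise_multiplier and k is noise_decay_rate. A plain-Python restatement for clarity (not MindArmour's implementation):

    def next_noise_multiplier(sigma_0, sigma_prev, noise_decay_rate, decay_policy='Time'):
        """Scalar sketch of the AdaGaussianRandom noise-multiplier update shown above."""
        if decay_policy == 'Time':
            return sigma_0 / (sigma_0 / sigma_prev + noise_decay_rate)
        # 'Step' mode
        return (1.0 - noise_decay_rate) * sigma_prev

    # Example: starting from initial_noise_multiplier = 1.5, the multiplier shrinks each update.
    sigma = 1.5
    for _ in range(3):
        sigma = next_noise_multiplier(1.5, sigma, 6e-4)
        print(sigma)
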
@@ -20,7 +20,7 @@ from mindspore.train.callback import Callback
 from mindarmour.utils.logger import LogUtil
 from mindarmour.utils._check_param import check_int_positive, \
-    check_value_positive
+    check_value_positive, check_param_in_range, check_param_type

 LOGGER = LogUtil.get_instance()
 TAG = 'DP monitor'
@@ -40,7 +40,8 @@ class PrivacyMonitorFactory:
         Create a privacy monitor class.

         Args:
-            policy (str): Monitor policy, 'rdp' is supported by now.
+            policy (str): Monitor policy, only 'rdp' is supported for now. RDP means Rényi
+                differential privacy, which is computed based on the Rényi divergence.
             args (Union[int, float, numpy.ndarray, list, str]): Parameters
                 used for creating a privacy monitor.
             kwargs (Union[int, float, numpy.ndarray, list, str]): Keyword
@@ -70,7 +71,7 @@ class RDPMonitor(Callback):
         num_samples (int): The total number of samples in training data sets.
         batch_size (int): The number of samples in a batch while training.
         initial_noise_multiplier (Union[float, int]): The initial
-            multiplier of added noise. Default: 1.5.
+            multiplier of the noise added to training parameters' gradients. Default: 1.5.
         max_eps (Union[float, int, None]): The maximum acceptable epsilon
             budget for DP training. Default: 10.0.
         target_delta (Union[float, int, None]): Target delta budget for DP
@@ -137,11 +138,8 @@ class RDPMonitor(Callback):
             LOGGER.error(TAG, msg)
             raise ValueError(msg)
         if noise_decay_rate is not None:
-            check_value_positive('noise_decay_rate', noise_decay_rate)
-            if noise_decay_rate >= 1:
-                msg = 'Noise decay rate must be less than 1'
-                LOGGER.error(TAG, msg)
-                raise ValueError(msg)
+            noise_decay_rate = check_param_type('noise_decay_rate', noise_decay_rate, float)
+            check_param_in_range('noise_decay_rate', noise_decay_rate, 0.0, 1.0)
         check_int_positive('per_print_times', per_print_times)

         self._total_echo_privacy = None
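
In plain terms, the new check requires noise_decay_rate to be a float lying in the range [0.0, 1.0]; a rough standalone equivalent (the exact bound handling of MindArmour's check_param_in_range is not asserted here):

    def validate_noise_decay_rate(noise_decay_rate):
        # Rough equivalent of check_param_type + check_param_in_range as used above.
        if noise_decay_rate is None:
            return None
        if not isinstance(noise_decay_rate, float):
            raise TypeError('noise_decay_rate must be a float')
        if not 0.0 <= noise_decay_rate <= 1.0:
            raise ValueError('noise_decay_rate must be within [0.0, 1.0]')
        return noise_decay_rate
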
@@ -27,7 +27,7 @@ class DPOptimizerClassFactory:
     Factory class of Optimizer.

     Args:
-        micro_batches (int): The number of small batches split from an origianl batch. Default: 2.
+        micro_batches (int): The number of small batches split from an original batch. Default: 2.

     Returns:
         Optimizer, Optimizer class
@@ -70,7 +70,7 @@ class DPModel(Model):
     This class is overload mindspore.train.model.Model.

     Args:
-        micro_batches (int): The number of small batches split from an origianl batch. Default: 2.
+        micro_batches (int): The number of small batches split from an original batch. Default: 2.
         norm_clip (float): Use to clip the bound, if set 1, will retun the original data. Default: 1.0.
         dp_mech (Mechanisms): The object can generate the different type of noise. Default: None.
@@ -45,10 +45,10 @@ def test_ada_gaussian():
     shape = (3, 2, 4)
     norm_bound = 1.0
     initial_noise_multiplier = 0.1
-    alpha = 0.5
+    noise_decay_rate = 0.5
     decay_policy = "Step"
     net = AdaGaussianRandom(norm_bound, initial_noise_multiplier,
-                            alpha, decay_policy)
+                            noise_decay_rate, decay_policy)
     res = net(shape)
     print(res)
@@ -58,7 +58,7 @@ def test_factory():
     shape = (3, 2, 4)
     norm_bound = 1.0
     initial_noise_multiplier = 0.1
-    alpha = 0.5
+    noise_decay_rate = 0.5
     decay_policy = "Step"
     noise_mechanism = MechanismsFactory()
     noise_construct = noise_mechanism.create('Gaussian',
@@ -70,7 +70,7 @@ def test_factory():
     ada_noise_construct = ada_mechanism.create('AdaGaussian',
                                                norm_bound,
                                                initial_noise_multiplier,
-                                               alpha,
+                                               noise_decay_rate,
                                                decay_policy)
     ada_noise = ada_noise_construct(shape)
     print('ada noise: ', ada_noise)