@@ -0,0 +1,40 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" | |||
network config setting, will be used in train.py | |||
""" | |||
from easydict import EasyDict as edict | |||
mnist_cfg = edict({ | |||
'num_classes': 10, # the number of classes of model's output | |||
'lr': 0.01, # the learning rate of model's optimizer | |||
'momentum': 0.9, # the momentum value of model's optimizer | |||
'epoch_size': 10, # training epochs | |||
'batch_size': 256, # batch size for training | |||
'image_height': 32, # the height of training samples | |||
'image_width': 32, # the width of training samples | |||
'save_checkpoint_steps': 234, # the interval steps for saving checkpoint file of the model | |||
'keep_checkpoint_max': 10, # the maximum number of checkpoint files would be saved | |||
'device_target': 'Ascend', # device used | |||
'data_path': './MNIST_unzip', # the path of training and testing data set | |||
'dataset_sink_mode': False, # whether deliver all training data to device one time | |||
'micro_batches': 16, # the number of small batches split from an original batch | |||
'norm_bound': 1.0, # the clip bound of the gradients of model's training parameters | |||
'initial_noise_multiplier': 1.0, # the initial multiplication coefficient of the noise added to training | |||
# parameters' gradients | |||
'noise_mechanisms': 'AdaGaussian', # the method of adding noise in gradients while training | |||
'optimizer': 'Momentum' # the base optimizer used for Differential privacy training | |||
}) |
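For orientation, a minimal usage sketch (not part of the diff): the training scripts below import this dict, EasyDict allows both attribute and key access, and the divisibility assertion mirrors the check in lenet5_dp.py.

from lenet5_config import mnist_cfg as cfg

assert cfg.batch_size % cfg.micro_batches == 0  # DPModel splits each batch into micro_batches
print(cfg.num_classes, cfg['epoch_size'])  # -> 10 10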
@@ -33,9 +33,9 @@ mnist_cfg = edict({
    'dataset_sink_mode': False,  # whether to deliver all training data to the device at once
    'micro_batches': 16,  # the number of small batches an original batch is split into
    'norm_bound': 1.0,  # the clipping bound for the gradients of the model's training parameters
    'initial_noise_multiplier': 0.5,  # the initial multiplication coefficient of the noise
    'initial_noise_multiplier': 1.0,  # the initial multiplication coefficient of the noise
                                      # added to the training parameters' gradients
    'noise_mechanisms': 'AdaGaussian',  # the method of adding noise to gradients while training
    'noise_mechanisms': 'Gaussian',  # the method of adding noise to gradients while training
    'clip_mechanisms': 'Gaussian',  # the method of adaptively clipping gradients while training
    'clip_decay_policy': 'Linear',  # decay policy of adaptive clipping; must be in ['Linear', 'Geometric']
    'clip_learning_rate': 0.001,  # learning rate for updating the norm clip
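A hedged sketch of how the three new clip settings could be consumed (ClipMechanismsFactory and its keyword names are assumptions inferred from the clip-mechanism comment in lenet5_dp.py below):

from mindarmour.diff_privacy import ClipMechanismsFactory
from lenet5_config import mnist_cfg as cfg

clip_mech = ClipMechanismsFactory().create(cfg.clip_mechanisms,                # 'Gaussian'
                                           decay_policy=cfg.clip_decay_policy,  # 'Linear' or 'Geometric'
                                           learning_rate=cfg.clip_learning_rate)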
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
python lenet5_dp.py --data_path /YourDataPath --micro_batches=2
Training example of adaClip-mechanism differential privacy.
"""
import os
@@ -102,8 +102,7 @@ if __name__ == "__main__":
    # get training dataset
    ds_train = generate_mnist_dataset(os.path.join(cfg.data_path, "train"),
                                      cfg.batch_size,
                                      cfg.epoch_size)
                                      cfg.batch_size)
    if cfg.micro_batches and cfg.batch_size % cfg.micro_batches != 0:
        raise ValueError(
@@ -117,7 +116,7 @@ if __name__ == "__main__":
    noise_mech = NoiseMechanismsFactory().create(cfg.noise_mechanisms,
                                                 norm_bound=cfg.norm_bound,
                                                 initial_noise_multiplier=cfg.initial_noise_multiplier,
                                                 noise_update='Exp')
                                                 noise_update=None)
    # Create a factory class of clip mechanisms. This method adaptively clips
    # gradients while training; decay_policy supports 'Linear' and 'Geometric',
    # and learning_rate is the learning rate used to update clip_norm.
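For context, this change replaces exponentially decayed noise with constant noise; a minimal sketch of both variants, using only names that appear in this file:

# constant noise, matching the 'Gaussian' mechanism used after this change
noise_mech = NoiseMechanismsFactory().create('Gaussian',
                                             norm_bound=1.0,
                                             initial_noise_multiplier=1.0,
                                             noise_update=None)

# exponentially decayed noise, as configured before with 'AdaGaussian'
noise_mech_ada = NoiseMechanismsFactory().create('AdaGaussian',
                                                 norm_bound=1.0,
                                                 initial_noise_multiplier=1.0,
                                                 noise_update='Exp')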
@@ -137,8 +136,9 @@ if __name__ == "__main__":
    rdp_monitor = PrivacyMonitorFactory.create('rdp',
                                               num_samples=60000,
                                               batch_size=cfg.batch_size,
                                               initial_noise_multiplier=cfg.initial_noise_multiplier*cfg.norm_bound,
                                               per_print_times=10)
                                               initial_noise_multiplier=cfg.initial_noise_multiplier,
                                               per_print_times=234,
                                               noise_decay_mode=None)
    # Create the DP model for training.
    model = DPModel(micro_batches=cfg.micro_batches,
                    norm_bound=cfg.norm_bound,
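The new per_print_times is one epoch of steps: with 60,000 MNIST training samples and batch_size 256 (drop_remainder=True), an epoch is 60000 // 256 = 234 steps, so the privacy budget is printed once per epoch, the same interval used for save_checkpoint_steps in lenet5_config.py.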
@@ -0,0 +1,150 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Training example of adaGaussian-mechanism differential privacy.
"""
import os

import mindspore.nn as nn
from mindspore import context
from mindspore.train.callback import ModelCheckpoint
from mindspore.train.callback import CheckpointConfig
from mindspore.train.callback import LossMonitor
from mindspore.nn.metrics import Accuracy
from mindspore.train.serialization import load_checkpoint, load_param_into_net
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from mindspore.dataset.transforms.vision import Inter
import mindspore.common.dtype as mstype

from mindarmour.diff_privacy import DPModel
from mindarmour.diff_privacy import PrivacyMonitorFactory
from mindarmour.diff_privacy import NoiseMechanismsFactory
from mindarmour.utils.logger import LogUtil

from lenet5_net import LeNet5
from lenet5_config import mnist_cfg as cfg

LOGGER = LogUtil.get_instance()
LOGGER.set_level('INFO')
TAG = 'Lenet5_train'
def generate_mnist_dataset(data_path, batch_size=32, repeat_size=1,
                           num_parallel_workers=1, sparse=True):
    """
    Create dataset for training or testing.
    """
    # define dataset
    ds1 = ds.MnistDataset(data_path)

    # define operation parameters
    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0

    # define map operations
    resize_op = CV.Resize((resize_height, resize_width),
                          interpolation=Inter.LINEAR)
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    # apply map operations on images
    if not sparse:
        one_hot_enco = C.OneHot(10)
        ds1 = ds1.map(input_columns="label", operations=one_hot_enco,
                      num_parallel_workers=num_parallel_workers)
        type_cast_op = C.TypeCast(mstype.float32)
    ds1 = ds1.map(input_columns="label", operations=type_cast_op,
                  num_parallel_workers=num_parallel_workers)
    ds1 = ds1.map(input_columns="image", operations=resize_op,
                  num_parallel_workers=num_parallel_workers)
    ds1 = ds1.map(input_columns="image", operations=rescale_op,
                  num_parallel_workers=num_parallel_workers)
    ds1 = ds1.map(input_columns="image", operations=hwc2chw_op,
                  num_parallel_workers=num_parallel_workers)

    # apply DatasetOps
    buffer_size = 10000
    ds1 = ds1.shuffle(buffer_size=buffer_size)
    ds1 = ds1.batch(batch_size, drop_remainder=True)
    ds1 = ds1.repeat(repeat_size)

    return ds1
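# A hedged usage note (not in the original file): with the defaults above,
# each batch yields images of shape (batch_size, 1, 32, 32) scaled to [0, 1]
# (HWC2CHW moves the channel first) and int32 labels when sparse=True, e.g.
# demo_ds = generate_mnist_dataset('./MNIST_unzip/train', batch_size=256)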
if __name__ == "__main__":
    # This configuration can run in both PyNative mode and graph mode.
    context.set_context(mode=context.GRAPH_MODE,
                        device_target=cfg.device_target)
    network = LeNet5()
    net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True,
                                                reduction="mean")
    config_ck = CheckpointConfig(
        save_checkpoint_steps=cfg.save_checkpoint_steps,
        keep_checkpoint_max=cfg.keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",
                                 directory='./trained_ckpt_file/',
                                 config=config_ck)

    # get training dataset
    ds_train = generate_mnist_dataset(os.path.join(cfg.data_path, "train"),
                                      cfg.batch_size)

    if cfg.micro_batches and cfg.batch_size % cfg.micro_batches != 0:
        raise ValueError(
            "batch_size should be evenly divisible by micro_batches")

    # Create a factory class of DP noise mechanisms; this method adds noise to
    # gradients while training. initial_noise_multiplier is suggested to be
    # greater than 1.0, otherwise the privacy budget would be huge, which means
    # that the privacy protection effect is weak. Mechanisms can be 'Gaussian'
    # or 'AdaGaussian': noise decays with the 'AdaGaussian' mechanism and stays
    # constant with the 'Gaussian' mechanism.
    noise_mech = NoiseMechanismsFactory().create(cfg.noise_mechanisms,
                                                 norm_bound=cfg.norm_bound,
                                                 initial_noise_multiplier=cfg.initial_noise_multiplier,
                                                 noise_update='Exp')
    net_opt = nn.Momentum(params=network.trainable_params(),
                          learning_rate=cfg.lr, momentum=cfg.momentum)

    # Create a monitor for DP training. The function of the monitor is to
    # compute and print the privacy budget (eps and delta) while training.
    rdp_monitor = PrivacyMonitorFactory.create('rdp',
                                               num_samples=60000,
                                               batch_size=cfg.batch_size,
                                               initial_noise_multiplier=cfg.initial_noise_multiplier,
                                               per_print_times=234)

    # Create the DP model for training.
    model = DPModel(micro_batches=cfg.micro_batches,
                    norm_bound=cfg.norm_bound,
                    noise_mech=noise_mech,
                    network=network,
                    loss_fn=net_loss,
                    optimizer=net_opt,
                    metrics={"Accuracy": Accuracy()})

    LOGGER.info(TAG, "============== Starting Training ==============")
    model.train(cfg['epoch_size'], ds_train,
                callbacks=[ckpoint_cb, LossMonitor(), rdp_monitor],
                dataset_sink_mode=cfg.dataset_sink_mode)

    LOGGER.info(TAG, "============== Starting Testing ==============")
    ckpt_file_name = 'trained_ckpt_file/checkpoint_lenet-10_234.ckpt'
    param_dict = load_checkpoint(ckpt_file_name)
    load_param_into_net(network, param_dict)
    ds_eval = generate_mnist_dataset(os.path.join(cfg.data_path, 'test'),
                                     batch_size=cfg.batch_size)
    acc = model.eval(ds_eval, dataset_sink_mode=False)
    LOGGER.info(TAG, "============== Accuracy: %s ==============", acc)
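A hedged usage note: assuming this file is saved as lenet5_dp_ada_gaussian.py and the unzipped MNIST data sits at ./MNIST_unzip (the default in lenet5_config.py), the script runs with no arguments:

python lenet5_dp_ada_gaussian.py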
@@ -364,6 +364,9 @@ class ZCDPMonitor(Callback):
    .. math::
        (\rho + 2\sqrt{\rho \log(1/\delta)}, \delta)

    It should be noted that ZCDPMonitor is not suitable for subsampling
    noise mechanisms (such as NoiseAdaGaussianRandom and NoiseGaussianRandom).
    The matching noise mechanism of ZCDP will be developed in the future.

    Reference: `Concentrated Differentially Private Gradient Descent with
    Adaptive per-Iteration Privacy Budget <https://arxiv.org/abs/1808.09501>`_
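A hedged sketch of instantiating this monitor through the factory used in the training examples above (the 'zcdp' policy name and the shared keyword set are assumptions):

from mindarmour.diff_privacy import PrivacyMonitorFactory

zcdp_monitor = PrivacyMonitorFactory.create('zcdp',
                                            num_samples=60000,
                                            batch_size=256,
                                            initial_noise_multiplier=1.0,
                                            per_print_times=234)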
@@ -48,6 +48,9 @@ class ModelCoverageMetrics:
        train_dataset (numpy.ndarray): Training dataset used to determine
            the neurons' output boundaries.

    Raises:
        ValueError: If neuron_num is too big (for example, greater than 1e+9).

    Examples:
        >>> train_images = np.random.random((10000, 128)).astype(np.float32)
        >>> test_images = np.random.random((5000, 128)).astype(np.float32)
@@ -63,10 +66,11 @@ class ModelCoverageMetrics:
        self._model = check_model('model', model, Model)
        self._segmented_num = check_int_positive('segmented_num', segmented_num)
        self._neuron_num = check_int_positive('neuron_num', neuron_num)
        if self._neuron_num >= 1e+10:
        if self._neuron_num > 1e+9:
            msg = 'neuron_num should be no greater than 1e+9, otherwise a ' \
                  'MemoryError would occur'
            LOGGER.error(TAG, msg)
            raise ValueError(msg)
        train_dataset = check_numpy_param('train_dataset', train_dataset)
        self._lower_bounds = [np.inf]*self._neuron_num
        self._upper_bounds = [-np.inf]*self._neuron_num
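A hedged sketch of the new bound in action (the constructor argument order is assumed from the attribute checks above; model and train_images as in the Examples section):

>>> ModelCoverageMetrics(model, 1000, int(2e9), train_images)
ValueError: neuron_num should be no greater than 1e+9, otherwise a MemoryError would occur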