@@ -20,7 +20,7 @@ from easydict import EasyDict as edict
 mnist_cfg = edict({
     'num_classes': 10,  # the number of classes of model's output
-    'lr': 0.01,  # the learning rate of model's optimizer
+    'lr': 0.1,  # the learning rate of model's optimizer
     'momentum': 0.9,  # the momentum value of model's optimizer
     'epoch_size': 10,  # training epochs
     'batch_size': 256,  # batch size for training
@@ -33,7 +33,7 @@ mnist_cfg = edict({
     'dataset_sink_mode': False,  # whether deliver all training data to device one time
     'micro_batches': 16,  # the number of small batches split from an original batch
     'l2_norm_bound': 1.0,  # the clip bound of the gradients of model's training parameters
-    'initial_noise_multiplier': 1.5,  # the initial multiplication coefficient of the noise added to training
+    'initial_noise_multiplier': 0.2,  # the initial multiplication coefficient of the noise added to training
                                       # parameters' gradients
     'mechanisms': 'AdaGaussian',  # the method of adding noise in gradients while training
     'optimizer': 'Momentum'  # the base optimizer used for Differential privacy training
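For reference, the two settings above combine into the noise scale used by the mechanisms further down in this diff (stddev = l2_norm_bound * initial_noise_multiplier), and micro_batches must evenly divide batch_size or the training scripts below raise a ValueError. A minimal sanity check of this config in plain Python, with the values copied from above:

    l2_norm_bound = 1.0
    initial_noise_multiplier = 0.2
    batch_size = 256
    micro_batches = 16

    # Standard deviation of the Gaussian noise added to the accumulated gradients.
    stddev = l2_norm_bound * initial_noise_multiplier
    print(stddev)  # 0.2

    # The training scripts below reject configs where this does not hold.
    assert batch_size % micro_batches == 0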
@@ -0,0 +1,140 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+python lenet5_dp.py --data_path /YourDataPath --micro_batches=2
+"""
+import os
+
+import mindspore.nn as nn
+from mindspore import context
+from mindspore.train.callback import ModelCheckpoint
+from mindspore.train.callback import CheckpointConfig
+from mindspore.train.callback import LossMonitor
+from mindspore.nn.metrics import Accuracy
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+import mindspore.dataset as ds
+import mindspore.dataset.transforms.vision.c_transforms as CV
+import mindspore.dataset.transforms.c_transforms as C
+from mindspore.dataset.transforms.vision import Inter
+import mindspore.common.dtype as mstype
+
+from mindarmour.diff_privacy import DPModel
+from mindarmour.diff_privacy import PrivacyMonitorFactory
+from mindarmour.diff_privacy import MechanismsFactory
+from mindarmour.utils.logger import LogUtil
+from lenet5_net import LeNet5
+from lenet5_config import mnist_cfg as cfg
+
+LOGGER = LogUtil.get_instance()
+LOGGER.set_level('INFO')
+TAG = 'Lenet5_train'
+
+
+def generate_mnist_dataset(data_path, batch_size=32, repeat_size=1,
+                           num_parallel_workers=1, sparse=True):
+    """
+    create dataset for training or testing
+    """
+    # define dataset
+    ds1 = ds.MnistDataset(data_path)
+
+    # define operation parameters
+    resize_height, resize_width = 32, 32
+    rescale = 1.0 / 255.0
+    shift = 0.0
+
+    # define map operations
+    resize_op = CV.Resize((resize_height, resize_width),
+                          interpolation=Inter.LINEAR)
+    rescale_op = CV.Rescale(rescale, shift)
+    hwc2chw_op = CV.HWC2CHW()
+    type_cast_op = C.TypeCast(mstype.int32)
+
+    # apply map operations on images
+    if not sparse:
+        one_hot_enco = C.OneHot(10)
+        ds1 = ds1.map(input_columns="label", operations=one_hot_enco,
+                      num_parallel_workers=num_parallel_workers)
+        type_cast_op = C.TypeCast(mstype.float32)
+    ds1 = ds1.map(input_columns="label", operations=type_cast_op,
+                  num_parallel_workers=num_parallel_workers)
+    ds1 = ds1.map(input_columns="image", operations=resize_op,
+                  num_parallel_workers=num_parallel_workers)
+    ds1 = ds1.map(input_columns="image", operations=rescale_op,
+                  num_parallel_workers=num_parallel_workers)
+    ds1 = ds1.map(input_columns="image", operations=hwc2chw_op,
+                  num_parallel_workers=num_parallel_workers)
+
+    # apply DatasetOps
+    buffer_size = 10000
+    ds1 = ds1.shuffle(buffer_size=buffer_size)
+    ds1 = ds1.batch(batch_size, drop_remainder=True)
+    ds1 = ds1.repeat(repeat_size)
+
+    return ds1
+
+
+if __name__ == "__main__":
+    # This configuration can run in both pynative mode and graph mode
+    context.set_context(mode=context.PYNATIVE_MODE, device_target=cfg.device_target)
+    network = LeNet5()
+    net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
+    config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
+                                 keep_checkpoint_max=cfg.keep_checkpoint_max)
+    ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",
+                                 directory='./trained_ckpt_file/',
+                                 config=config_ck)
+    # get training dataset
+    ds_train = generate_mnist_dataset(os.path.join(cfg.data_path, "train"),
+                                      cfg.batch_size,
+                                      cfg.epoch_size)
+
+    if cfg.micro_batches and cfg.batch_size % cfg.micro_batches != 0:
+        raise ValueError("Number of micro_batches should evenly divide batch_size")
+    # Create a factory class of DP mechanisms for adding noise to gradients while training.
+    # initial_noise_multiplier is suggested to be greater than 1.0; otherwise the privacy budget would be huge,
+    # which means the privacy protection is weak. Mechanisms can be 'Gaussian' or 'AdaGaussian': the noise
+    # decays over time with the 'AdaGaussian' mechanism and stays constant with the 'Gaussian' mechanism.
+    mech = MechanismsFactory().create(cfg.mechanisms,
+                                      norm_bound=cfg.l2_norm_bound,
+                                      initial_noise_multiplier=cfg.initial_noise_multiplier)
+    net_opt = nn.Momentum(params=network.trainable_params(), learning_rate=cfg.lr, momentum=cfg.momentum)
+    # Create a monitor for DP training. The function of the monitor is to compute and print the privacy budget
+    # (eps and delta) while training.
+    rdp_monitor = PrivacyMonitorFactory.create('rdp',
+                                               num_samples=60000,
+                                               batch_size=cfg.batch_size,
+                                               initial_noise_multiplier=cfg.initial_noise_multiplier*cfg.l2_norm_bound,
+                                               per_print_times=10)
+    # Create the DP model for training.
+    model = DPModel(micro_batches=cfg.micro_batches,
+                    norm_clip=cfg.l2_norm_bound,
+                    mech=mech,
+                    network=network,
+                    loss_fn=net_loss,
+                    optimizer=net_opt,
+                    metrics={"Accuracy": Accuracy()})
+
+    LOGGER.info(TAG, "============== Starting Training ==============")
+    model.train(cfg['epoch_size'], ds_train, callbacks=[ckpoint_cb, LossMonitor(), rdp_monitor],
+                dataset_sink_mode=cfg.dataset_sink_mode)
+
+    LOGGER.info(TAG, "============== Starting Testing ==============")
+    ckpt_file_name = 'trained_ckpt_file/checkpoint_lenet-10_234.ckpt'
+    param_dict = load_checkpoint(ckpt_file_name)
+    load_param_into_net(network, param_dict)
+    ds_eval = generate_mnist_dataset(os.path.join(cfg.data_path, 'test'), batch_size=cfg.batch_size)
+    acc = model.eval(ds_eval, dataset_sink_mode=False)
+    LOGGER.info(TAG, "============== Accuracy: %s ==============", acc)
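The hardcoded checkpoint name 'checkpoint_lenet-10_234.ckpt' follows MindSpore's <prefix>-<epoch>_<step>.ckpt convention: with 60000 MNIST training images, batch_size 256 and drop_remainder=True there are 234 steps per epoch, and training runs for 10 epochs. A quick check of that arithmetic in plain Python:

    num_train_samples = 60000  # size of the MNIST training set
    batch_size = 256
    epoch_size = 10

    steps_per_epoch = num_train_samples // batch_size  # drop_remainder=True discards the tail
    print('checkpoint_lenet-%d_%d.ckpt' % (epoch_size, steps_per_epoch))
    # checkpoint_lenet-10_234.ckpt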
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
-python lenet5_dp_model_train.py --data_path /YourDataPath --micro_batches=2
+python lenet5_dp_pynative_mode.py --data_path /YourDataPath --micro_batches=2
 """
 import os
@@ -30,8 +30,8 @@ from mindspore.dataset.transforms.vision import Inter
 import mindspore.common.dtype as mstype
 from mindarmour.diff_privacy import DPModel
-from mindarmour.diff_privacy import DPOptimizerClassFactory
 from mindarmour.diff_privacy import PrivacyMonitorFactory
+from mindarmour.diff_privacy import DPOptimizerClassFactory
 from mindarmour.utils.logger import LogUtil
 from lenet5_net import LeNet5
 from lenet5_config import mnist_cfg as cfg
@@ -86,8 +86,8 @@ def generate_mnist_dataset(data_path, batch_size=32, repeat_size=1,
 if __name__ == "__main__":
+    # This configuration can only run in pynative mode.
     context.set_context(mode=context.PYNATIVE_MODE, device_target=cfg.device_target)
     network = LeNet5()
     net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
     config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
@@ -103,34 +103,26 @@ if __name__ == "__main__":
     if cfg.micro_batches and cfg.batch_size % cfg.micro_batches != 0:
         raise ValueError("Number of micro_batches should evenly divide batch_size")
-    # Create a factory class of DP optimizer
-    gaussian_mech = DPOptimizerClassFactory(cfg.micro_batches)
-    # Set the method of adding noise in gradients while training. Initial_noise_multiplier is suggested to be greater
-    # than 1.0, otherwise the privacy budget would be huge, which means that the privacy protection effect is weak.
-    # mechanisms can be 'Gaussian' or 'AdaGaussian', in which noise would be decayed with 'AdaGaussian' mechanism while
-    # be constant with 'Gaussian' mechanism.
-    gaussian_mech.set_mechanisms(cfg.mechanisms,
-                                 norm_bound=cfg.l2_norm_bound,
-                                 initial_noise_multiplier=cfg.initial_noise_multiplier)
-    # Wrap the base optimizer for DP training. Momentum optimizer is suggested for LenNet5.
-    net_opt = gaussian_mech.create(cfg.optimizer)(params=network.trainable_params(),
-                                                  learning_rate=cfg.lr,
-                                                  momentum=cfg.momentum)
+    # Create a factory class of DP optimizers and set the method of adding noise to gradients while training.
+    # initial_noise_multiplier is suggested to be greater than 1.0; otherwise the privacy budget would be huge,
+    # which means the privacy protection is weak. Mechanisms can be 'Gaussian' or 'AdaGaussian': the noise
+    # decays over time with the 'AdaGaussian' mechanism and stays constant with the 'Gaussian' mechanism.
+    dp_opt = DPOptimizerClassFactory(micro_batches=cfg.micro_batches)
+    dp_opt.set_mechanisms(cfg.mechanisms,
+                          norm_bound=cfg.l2_norm_bound,
+                          initial_noise_multiplier=cfg.initial_noise_multiplier)
+    net_opt = dp_opt.create('Momentum')(params=network.trainable_params(), learning_rate=cfg.lr, momentum=cfg.momentum)
     # Create a monitor for DP training. The function of the monitor is to compute and print the privacy budget
     # (eps and delta) while training.
     rdp_monitor = PrivacyMonitorFactory.create('rdp',
                                                num_samples=60000,
                                                batch_size=cfg.batch_size,
-                                               initial_noise_multiplier=cfg.initial_noise_multiplier,
-                                               per_print_times=50)
+                                               initial_noise_multiplier=cfg.initial_noise_multiplier*cfg.l2_norm_bound,
+                                               per_print_times=10)
     # Create the DP model for training.
     model = DPModel(micro_batches=cfg.micro_batches,
                     norm_clip=cfg.l2_norm_bound,
-                    dp_mech=gaussian_mech.mech,
+                    mech=None,
                     network=network,
                     loss_fn=net_loss,
                     optimizer=net_opt,
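The renamed script exercises the second of the two supported wirings: either pass a noise mechanism via DPModel's mech argument together with a plain optimizer, or pass mech=None together with a DPOptimizer. The model.py hunk below adds validation enforcing that these are mutually exclusive and that DPOptimizer only runs in pynative mode. A standalone sketch of that decision logic, with check_dp_config as a hypothetical helper name:

    def check_dp_config(mech, optimizer_cls_name, mode):
        """Mirror of the mech/optimizer checks added to DPModel.__init__ below."""
        is_dp_opt = 'DPOptimizer' in optimizer_cls_name
        if mech is not None and is_dp_opt:
            raise ValueError('DPOptimizer is not supported while mech is not None')
        if mech is None:
            if is_dp_opt:
                if mode != 'pynative':
                    raise ValueError('DPOptimizer only supports pynative mode currently.')
            else:
                raise ValueError('DPModel requires either mech or a DPOptimizer.')

    check_dp_config(mech=None, optimizer_cls_name='DPOptimizerMomentum', mode='pynative')  # passes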
@@ -14,8 +14,6 @@
 """
 Noise Mechanisms.
 """
-import numpy as np
-
 from mindspore import Tensor
 from mindspore.nn import Cell
 from mindspore.ops import operations as P
@@ -24,6 +22,7 @@ from mindspore.common import dtype as mstype
 from mindarmour.utils._check_param import check_param_type
 from mindarmour.utils._check_param import check_value_positive
+from mindarmour.utils._check_param import check_value_non_negative
 from mindarmour.utils._check_param import check_param_in_range
@@ -62,7 +61,8 @@ class Mechanisms(Cell):
     """
     Basic class of noise generation mechanisms.
     """
-    def construct(self, shape):
+    def construct(self, gradients):
         """
         Construct function.
         """
@@ -78,41 +78,47 @@ class GaussianRandom(Mechanisms):
         initial_noise_multiplier(float): Ratio of the standard deviation of
             Gaussian noise divided by the norm_bound, which will be used to
             calculate privacy spent. Default: 1.5.
+        mean(float): Average value of random noise. Default: 0.0.
+        seed(int): Original random seed. Default: 0.

     Returns:
-        Tensor, generated noise.
+        Tensor, generated noise with shape like given gradients.

     Examples:
-        >>> shape = (3, 2, 4)
+        >>> gradients = Tensor([0.2, 0.9], mstype.float32)
         >>> norm_bound = 1.0
-        >>> initial_noise_multiplier = 1.5
+        >>> initial_noise_multiplier = 0.1
         >>> net = GaussianRandom(norm_bound, initial_noise_multiplier)
-        >>> res = net(shape)
+        >>> res = net(gradients)
         >>> print(res)
     """

-    def __init__(self, norm_bound=1.0, initial_noise_multiplier=1.5):
+    def __init__(self, norm_bound=1.0, initial_noise_multiplier=1.5, mean=0.0, seed=0):
         super(GaussianRandom, self).__init__()
         self._norm_bound = check_value_positive('norm_bound', norm_bound)
+        self._norm_bound = Tensor(norm_bound, mstype.float32)
         self._initial_noise_multiplier = check_value_positive('initial_noise_multiplier',
-                                                              initial_noise_multiplier,)
-        stddev = self._norm_bound*self._initial_noise_multiplier
-        self._stddev = stddev
-        self._mean = 0
-
-    def construct(self, shape):
+                                                              initial_noise_multiplier)
+        self._initial_noise_multiplier = Tensor(initial_noise_multiplier, mstype.float32)
+        mean = check_param_type('mean', mean, float)
+        mean = check_value_non_negative('mean', mean)
+        self._mean = Tensor(mean, mstype.float32)
+        self._normal = P.Normal(seed=seed)
+
+    def construct(self, gradients):
         """
         Generate Gaussian noise.

         Args:
-            shape(tuple): The shape of gradients.
+            gradients(Tensor): The gradients.

         Returns:
-            Tensor, generated noise.
+            Tensor, generated noise with shape like given gradients.
         """
-        shape = check_param_type('shape', shape, tuple)
-        noise = np.random.normal(self._mean, self._stddev, shape)
-        return Tensor(noise, mstype.float32)
+        shape = P.Shape()(gradients)
+        stddev = P.Mul()(self._norm_bound, self._initial_noise_multiplier)
+        noise = self._normal(shape, self._mean, stddev)
+        return noise


 class AdaGaussianRandom(Mechanisms):
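The reworked GaussianRandom now builds noise on-device: it takes the gradient tensor itself, reads its shape, and samples from P.Normal with stddev = norm_bound * initial_noise_multiplier. A NumPy mirror of the same computation, handy for checking the noise scale outside MindSpore (a sketch, not part of the library):

    import numpy as np

    def gaussian_noise(gradients, norm_bound=1.0, initial_noise_multiplier=1.5, mean=0.0):
        # Same shape as the gradients; stddev = norm_bound * initial_noise_multiplier.
        stddev = norm_bound * initial_noise_multiplier
        return np.random.normal(mean, stddev, np.shape(gradients)).astype(np.float32)

    grads = np.array([0.2, 0.9], np.float32)
    print(gaussian_noise(grads, initial_noise_multiplier=0.1))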
@@ -126,54 +132,60 @@ class AdaGaussianRandom(Mechanisms):
         initial_noise_multiplier(float): Ratio of the standard deviation of
             Gaussian noise divided by the norm_bound, which will be used to
             calculate privacy spent. Default: 5.0.
+        mean(float): Average value of random noise. Default: 0.0.
         noise_decay_rate(float): Hyperparameter for controlling the noise decay.
             Default: 6e-4.
         decay_policy(str): Noise decay strategy include 'Step' and 'Time'.
             Default: 'Time'.
+        seed(int): Original random seed. Default: 0.

     Returns:
-        Tensor, generated noise.
+        Tensor, generated noise with shape like given gradients.

     Examples:
-        >>> shape = (3, 2, 4)
+        >>> gradients = Tensor([0.2, 0.9], mstype.float32)
         >>> norm_bound = 1.0
-        >>> initial_noise_multiplier = 0.1
-        >>> noise_decay_rate = 0.5
+        >>> initial_noise_multiplier = 5.0
+        >>> mean = 0.0
+        >>> noise_decay_rate = 6e-4
         >>> decay_policy = "Time"
-        >>> net = AdaGaussianRandom(norm_bound, initial_noise_multiplier,
+        >>> net = AdaGaussianRandom(norm_bound, initial_noise_multiplier, mean,
         >>>                         noise_decay_rate, decay_policy)
-        >>> res = net(shape)
+        >>> res = net(gradients)
         >>> print(res)
     """

-    def __init__(self, norm_bound=1.5, initial_noise_multiplier=5.0,
-                 noise_decay_rate=6e-4, decay_policy='Time'):
+    def __init__(self, norm_bound=1.5, initial_noise_multiplier=5.0, mean=0.0,
+                 noise_decay_rate=6e-4, decay_policy='Time', seed=0):
         super(AdaGaussianRandom, self).__init__()
+        norm_bound = check_value_positive('norm_bound', norm_bound)
         initial_noise_multiplier = check_value_positive('initial_noise_multiplier',
                                                         initial_noise_multiplier)
-        initial_noise_multiplier = Tensor(np.array(initial_noise_multiplier, np.float32))
+        self._norm_bound = Tensor(norm_bound, mstype.float32)
+        initial_noise_multiplier = Tensor(initial_noise_multiplier, mstype.float32)
         self._initial_noise_multiplier = Parameter(initial_noise_multiplier,
                                                    name='initial_noise_multiplier')
+        self._stddev = P.Mul()(self._norm_bound, self._initial_noise_multiplier)
         self._noise_multiplier = Parameter(initial_noise_multiplier,
                                            name='noise_multiplier')
-        norm_bound = check_value_positive('norm_bound', norm_bound)
-        self._norm_bound = Tensor(np.array(norm_bound, np.float32))
+        mean = check_param_type('mean', mean, float)
+        mean = check_value_non_negative('mean', mean)
+        self._mean = Tensor(mean, mstype.float32)
         noise_decay_rate = check_param_type('noise_decay_rate', noise_decay_rate, float)
         check_param_in_range('noise_decay_rate', noise_decay_rate, 0.0, 1.0)
-        self._noise_decay_rate = Tensor(np.array(noise_decay_rate, np.float32))
+        self._noise_decay_rate = Tensor(noise_decay_rate, mstype.float32)
         if decay_policy not in ['Time', 'Step']:
             raise NameError("The decay_policy must be in ['Time', 'Step'], but "
                             "get {}".format(decay_policy))
         self._decay_policy = decay_policy
-        self._mean = 0.0
         self._sub = P.Sub()
         self._mul = P.Mul()
         self._add = P.TensorAdd()
         self._div = P.Div()
-        self._stddev = self._update_stddev()
         self._dtype = mstype.float32
+        self._normal = P.Normal(seed=seed)
+        self._assign = P.Assign()

     def _update_multiplier(self):
         """ Update multiplier. """
@@ -181,31 +193,32 @@
             temp = self._div(self._initial_noise_multiplier,
                              self._noise_multiplier)
             temp = self._add(temp, self._noise_decay_rate)
-            temp = self._div(self._initial_noise_multiplier, temp)
-            self._noise_multiplier = Parameter(temp, name='noise_multiplier')
+            self._noise_multiplier = self._assign(self._noise_multiplier,
+                                                  self._div(self._initial_noise_multiplier, temp))
         else:
             one = Tensor(1, self._dtype)
             temp = self._sub(one, self._noise_decay_rate)
-            temp = self._mul(temp, self._noise_multiplier)
-            self._noise_multiplier = Parameter(temp, name='noise_multiplier')
+            self._noise_multiplier = self._assign(self._noise_multiplier, self._mul(temp, self._noise_multiplier))
+        return self._noise_multiplier

     def _update_stddev(self):
-        self._stddev = self._mul(self._noise_multiplier, self._norm_bound)
+        self._stddev = self._assign(self._stddev, self._mul(self._noise_multiplier, self._norm_bound))
         return self._stddev

-    def construct(self, shape):
+    def construct(self, gradients):
         """
         Generate adaptive Gaussian noise.

         Args:
-            shape(tuple): The shape of gradients.
+            gradients(Tensor): The gradients.

         Returns:
-            Tensor, generated noise.
+            Tensor, generated noise with shape like given gradients.
         """
-        shape = check_param_type('shape', shape, tuple)
-        noise = np.random.normal(self._mean, self._stddev.asnumpy(),
-                                 shape)
-        self._update_multiplier()
-        self._update_stddev()
-        return Tensor(noise, mstype.float32)
+        shape = P.Shape()(gradients)
+        noise = self._normal(shape, self._mean, self._stddev)
+        # pylint: disable=unused-variable
+        mt = self._update_multiplier()
+        # pylint: disable=unused-variable
+        std = self._update_stddev()
+        return noise
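On every construct call the adaptive mechanism draws noise and then updates its multiplier: the 'Time' policy follows m <- m0 / (m0/m + k) and the 'Step' policy follows m <- (1 - k) * m, after which the stddev is refreshed as m * norm_bound. A plain-Python sketch of both schedules, where m0 is the initial multiplier and k is noise_decay_rate:

    def ada_multiplier(m0, k, steps, policy='Time'):
        # Sketch of AdaGaussianRandom._update_multiplier iterated over several steps.
        m = m0
        trace = []
        for _ in range(steps):
            if policy == 'Time':
                m = m0 / (m0 / m + k)   # harmonic-style decay
            else:                       # 'Step'
                m = (1.0 - k) * m       # geometric decay
            trace.append(m)
        return trace

    print(ada_multiplier(5.0, 6e-4, 3, policy='Time'))
    print(ada_multiplier(5.0, 6e-4, 3, policy='Step'))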
@@ -14,13 +14,37 @@
 """
 Differential privacy optimizer.
 """
-import mindspore as ms
 from mindspore import nn
 from mindspore import Tensor
+from mindspore.ops import composite as C
+from mindspore.ops import operations as P
+from mindspore.ops import functional as F
+from mindspore.common import dtype as mstype

 from mindarmour.diff_privacy.mechanisms.mechanisms import MechanismsFactory
 from mindarmour.utils._check_param import check_int_positive

+_grad_scale = C.MultitypeFuncGraph("grad_scale")
+_reciprocal = P.Reciprocal()
+
+
+@_grad_scale.register("Tensor", "Tensor")
+def tensor_grad_scale(scale, grad):
+    """ grad scaling """
+    return grad * _reciprocal(scale)
+
+
+class _TupleAdd(nn.Cell):
+    def __init__(self):
+        super(_TupleAdd, self).__init__()
+        self.add = P.TensorAdd()
+        self.hyper_map = C.HyperMap()
+
+    def construct(self, input1, input2):
+        """Add two tuple of data."""
+        out = self.hyper_map(self.add, input1, input2)
+        return out
+
+
 class DPOptimizerClassFactory:
     """
@@ -36,9 +60,10 @@ class DPOptimizerClassFactory:
         >>> GaussianSGD = DPOptimizerClassFactory(micro_batches=2)
         >>> GaussianSGD.set_mechanisms('Gaussian', norm_bound=1.0, initial_noise_multiplier=1.5)
         >>> net_opt = GaussianSGD.create('Momentum')(params=network.trainable_params(),
-        >>>                                          learning_rate=cfg.lr,
-        >>>                                          momentum=cfg.momentum)
+        >>>                                         learning_rate=cfg.lr,
+        >>>                                         momentum=cfg.momentum)
     """
+
     def __init__(self, micro_batches=2):
         self._mech_factory = MechanismsFactory()
         self.mech = None
@@ -78,6 +103,7 @@ class DPOptimizerClassFactory:
         """
         Wrap original mindspore optimizer with `self._mech`.
         """
+
         class DPOptimizer(cls):
             """
             Initialize the DPOptimizerClass.
@@ -85,23 +111,22 @@ class DPOptimizerClassFactory:
             Returns:
                 Optimizer, Optimizer class.
             """

             def __init__(self, *args, **kwargs):
                 super(DPOptimizer, self).__init__(*args, **kwargs)
                 self._mech = mech
+                self._tuple_add = _TupleAdd()
+                self._hyper_map = C.HyperMap()
+                self._micro_float = Tensor(micro_batches, mstype.float32)

             def construct(self, gradients):
                 """
                 construct a compute flow.
                 """
-                g_len = len(gradients)
-                gradient_noise = list(gradients)
-                for i in range(g_len):
-                    gradient_noise[i] = gradient_noise[i].asnumpy()
-                    gradient_noise[i] = self._mech(gradient_noise[i].shape).asnumpy() + gradient_noise[i]
-                    gradient_noise[i] = gradient_noise[i] / micro_batches
-                    gradient_noise[i] = Tensor(gradient_noise[i], ms.float32)
-                gradients = tuple(gradient_noise)
-                gradients = super(DPOptimizer, self).construct(gradients)
+                grad_noise = self._hyper_map(self._mech, gradients)
+                grads = self._tuple_add(gradients, grad_noise)
+                grads = self._hyper_map(F.partial(_grad_scale, self._micro_float), grads)
+                gradients = super(DPOptimizer, self).construct(grads)
                 return gradients

         return DPOptimizer
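Functionally, the rewritten DPOptimizer.construct computes grad <- (grad + noise) / micro_batches for every gradient tensor using graph-friendly HyperMap calls instead of the old asnumpy round-trips, then hands the result to the wrapped optimizer. A NumPy sketch of the per-step transform, with noise_fn standing in for the configured mechanism:

    import numpy as np

    def dp_transform(gradients, noise_fn, micro_batches):
        # Add one noise draw per tensor, then average over the summed micro-batches.
        return tuple((g + noise_fn(g)) / micro_batches for g in gradients)

    noise_fn = lambda g: np.random.normal(0.0, 0.1, g.shape).astype(np.float32)
    grads = (np.ones((2, 3), np.float32), np.ones(4, np.float32))
    for g in dp_transform(grads, noise_fn, micro_batches=2):
        print(g)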
@@ -16,7 +16,6 @@ Differential privacy model.
 """
 from easydict import EasyDict as edict

-import mindspore as ms
 from mindspore.train.model import Model
 from mindspore._checkparam import Validator as validator
 from mindspore._checkparam import Rel
@@ -48,21 +47,19 @@ from mindspore.nn.wrap.loss_scale import _grad_overflow
 from mindspore.nn import Cell
 from mindspore import ParameterTuple

-from mindarmour.diff_privacy.mechanisms import mechanisms
 from mindarmour.utils._check_param import check_param_type
 from mindarmour.utils._check_param import check_value_positive
 from mindarmour.utils._check_param import check_int_positive

 GRADIENT_CLIP_TYPE = 1
-grad_scale = C.MultitypeFuncGraph("grad_scale")
-reciprocal = P.Reciprocal()
+_grad_scale = C.MultitypeFuncGraph("grad_scale")
+_reciprocal = P.Reciprocal()


-@grad_scale.register("Tensor", "Tensor")
+@_grad_scale.register("Tensor", "Tensor")
 def tensor_grad_scale(scale, grad):
     """ grad scaling """
-    return grad*reciprocal(scale)
+    return grad * F.cast(_reciprocal(scale), F.dtype(grad))


 class DPModel(Model):
@@ -72,7 +69,7 @@ class DPModel(Model):
     Args:
         micro_batches (int): The number of small batches split from an original batch. Default: 2.
         norm_clip (float): Used to clip the bound; if set to 1, the original data is returned. Default: 1.0.
-        dp_mech (Mechanisms): The object can generate the different type of noise. Default: None.
+        mech (Mechanisms): The object that generates different types of noise. Default: None.

     Examples:
         >>> class Net(nn.Cell):
@@ -94,32 +91,37 @@ class DPModel(Model):
         >>>
         >>> net = Net()
         >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
-        >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.01, momentum=0.9)
-        >>> gaussian_mech = DPOptimizerClassFactory()
-        >>> gaussian_mech.set_mechanisms('Gaussian',
-        >>>                              norm_bound=args.l2_norm_bound,
-        >>>                              initial_noise_multiplier=args.initial_noise_multiplier)
+        >>> net_opt = Momentum(params=net.trainable_params(), learning_rate=0.01, momentum=0.9)
+        >>> mech = MechanismsFactory().create('Gaussian',
+        >>>                                   norm_bound=args.l2_norm_bound,
+        >>>                                   initial_noise_multiplier=args.initial_noise_multiplier)
         >>> model = DPModel(micro_batches=2,
         >>>                 norm_clip=1.0,
-        >>>                 dp_mech=gaussian_mech.mech,
+        >>>                 mech=mech,
         >>>                 network=net,
         >>>                 loss_fn=loss,
-        >>>                 optimizer=optim,
+        >>>                 optimizer=net_opt,
         >>>                 metrics=None)
         >>> dataset = get_dataset()
         >>> model.train(2, dataset)
     """

-    def __init__(self, micro_batches=2, norm_clip=1.0, dp_mech=None, **kwargs):
+    def __init__(self, micro_batches=2, norm_clip=1.0, mech=None, **kwargs):
         if micro_batches:
             self._micro_batches = check_int_positive('micro_batches', micro_batches)
         else:
             self._micro_batches = None
-        norm_clip = check_param_type('norm_clip', norm_clip, float)
-        self._norm_clip = check_value_positive('norm_clip', norm_clip)
-        if isinstance(dp_mech, mechanisms.Mechanisms):
-            self._dp_mech = dp_mech
-        else:
-            raise TypeError('dp mechanisms should be instance of class Mechansms, but got {}'.format(type(dp_mech)))
+        float_norm_clip = check_param_type('l2_norm_clip', norm_clip, float)
+        self._norm_clip = check_value_positive('l2_norm_clip', float_norm_clip)
+        if mech is not None and "DPOptimizer" in kwargs['optimizer'].__class__.__name__:
+            raise ValueError('DPOptimizer is not supported while mech is not None')
+        if mech is None:
+            if "DPOptimizer" in kwargs['optimizer'].__class__.__name__:
+                if context.get_context('mode') != context.PYNATIVE_MODE:
+                    raise ValueError('DPOptimizer only supports pynative mode currently.')
+            else:
+                raise ValueError('DPModel requires either mech or a DPOptimizer; please refer to the example.')
+        self._mech = mech
         super(DPModel, self).__init__(**kwargs)

     def _amp_build_train_network(self, network, optimizer, loss_fn=None, level='O0', **kwargs):
@@ -179,14 +181,14 @@
                                                     scale_update_cell=update_cell,
                                                     micro_batches=self._micro_batches,
                                                     l2_norm_clip=self._norm_clip,
-                                                    mech=self._dp_mech).set_train()
+                                                    mech=self._mech).set_train()
             return network

         network = _TrainOneStepCell(network,
                                     optimizer,
                                     loss_scale,
                                     micro_batches=self._micro_batches,
                                     l2_norm_clip=self._norm_clip,
-                                    mech=self._dp_mech).set_train()
+                                    mech=self._mech).set_train()
         return network

     def _build_train_network(self):
@@ -244,6 +246,7 @@ class _ClipGradients(nn.Cell):
     Outputs:
         tuple[Tensor], clipped gradients.
     """
+
     def __init__(self):
         super(_ClipGradients, self).__init__()
         self.clip_by_norm = nn.ClipByNorm()
@@ -253,7 +256,8 @@ class _ClipGradients(nn.Cell):
         """
         construct a compute flow.
         """
-        if clip_type not in (0, 1):
+        # pylint: disable=consider-using-in
+        if clip_type != 0 and clip_type != 1:
             return grads

         new_grads = ()
@@ -268,6 +272,18 @@
         return new_grads


+class _TupleAdd(nn.Cell):
+    def __init__(self):
+        super(_TupleAdd, self).__init__()
+        self.add = P.TensorAdd()
+        self.hyper_map = C.HyperMap()
+
+    def construct(self, input1, input2):
+        """Add two tuple of data."""
+        out = self.hyper_map(self.add, input1, input2)
+        return out
+
+
 class _TrainOneStepWithLossScaleCell(Cell):
     r"""
     Network training with loss scaling.
@@ -347,6 +363,9 @@ class _TrainOneStepWithLossScaleCell(Cell):
         self._split = P.Split(0, self._micro_batches)
         self._clip_by_global_norm = _ClipGradients()
         self._mech = mech
+        self._tuple_add = _TupleAdd()
+        self._hyper_map = C.HyperMap()
+        self._micro_float = Tensor(micro_batches, mstype.float32)

     def construct(self, data, label, sens=None):
         """
@@ -368,32 +387,28 @@
         weights = self.weights
         record_datas = self._split(data)
         record_labels = self._split(label)
-        grads = ()
         # first index
         loss = self.network(record_datas[0], record_labels[0])
         scaling_sens_filled = C.ones_like(loss)*F.cast(scaling_sens, F.dtype(loss))
         record_grad = self.grad(self.network, weights)(record_datas[0], record_labels[0], scaling_sens_filled)
         record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm)
-        grad_sum = list(record_grad)
-        grad_len = len(record_grad)
-        for i in range(grad_len):
-            grad_sum[i] = grad_sum[i].asnumpy()
+        grads = record_grad
+        total_loss = loss
         for i in range(1, self._micro_batches):
             loss = self.network(record_datas[i], record_labels[i])
             scaling_sens_filled = C.ones_like(loss)*F.cast(scaling_sens, F.dtype(loss))
             record_grad = self.grad(self.network, weights)(record_datas[i], record_labels[i], scaling_sens_filled)
             record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm)
-            for j in range(grad_len):
-                grad_sum[j] = grad_sum[j] + record_grad[j].asnumpy()
+            grads = self._tuple_add(grads, record_grad)
+            total_loss = P.TensorAdd()(total_loss, loss)
+        loss = P.Div()(total_loss, self._micro_float)
-        for i in range(grad_len):
-            grad_sum[i] = Tensor(grad_sum[i], ms.float32)
-        grads = tuple(grad_sum)
-        loss = self.network(data, label)
+
+        if self._mech is not None:
+            grad_noise = self._hyper_map(self._mech, grads)
+            grads = self._tuple_add(grads, grad_noise)
+            grads = self._hyper_map(F.partial(_grad_scale, self._micro_float), grads)

-        grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
+        grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads)
         # apply grad reducer on grads
         grads = self.grad_reducer(grads)
         # get the overflow buffer
@@ -474,6 +489,9 @@ class _TrainOneStepCell(Cell):
         self._split = P.Split(0, self._micro_batches)
         self._clip_by_global_norm = _ClipGradients()
         self._mech = mech
+        self._tuple_add = _TupleAdd()
+        self._hyper_map = C.HyperMap()
+        self._micro_float = Tensor(micro_batches, mstype.float32)

     def construct(self, data, label):
         """
@@ -486,23 +504,21 @@
         sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
         record_grad = self.grad(self.network, weights)(record_datas[0], record_labels[0], sens)
         record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm)
-        grad_sum = list(record_grad)
-        grad_len = len(record_grad)
-        for i in range(grad_len):
-            grad_sum[i] = grad_sum[i].asnumpy()
+        grads = record_grad
+        total_loss = loss
         for i in range(1, self._micro_batches):
             loss = self.network(record_datas[i], record_labels[i])
             sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
             record_grad = self.grad(self.network, weights)(record_datas[i], record_labels[i], sens)
             record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm)
-            for j in range(grad_len):
-                grad_sum[j] = grad_sum[j] + record_grad[j].asnumpy()
-        for i in range(grad_len):
-            grad_sum[i] = Tensor(grad_sum[i], ms.float32)
-        grads = tuple(grad_sum)
-        loss = self.network(data, label)
+            grads = self._tuple_add(grads, record_grad)
+            total_loss = P.TensorAdd()(total_loss, loss)
+        loss = P.Div()(total_loss, self._micro_float)
+
+        if self._mech is not None:
+            grad_noise = self._hyper_map(self._mech, grads)
+            grads = self._tuple_add(grads, grad_noise)
+            grads = self._hyper_map(F.partial(_grad_scale, self._micro_float), grads)

         if self.reducer_flag:
             # apply grad reducer on grads
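Both rewritten construct methods implement the same micro-batch pipeline: split the batch, clip each micro-batch's gradients (GRADIENT_CLIP_TYPE = 1 means per-tensor L2 clipping via ClipByNorm), sum them with _TupleAdd, add a single noise draw when mech is set, and divide gradients and loss by micro_batches. A NumPy sketch of the gradient path under those assumptions:

    import numpy as np

    def clip_by_norm(g, clip_norm):
        # Per-tensor L2 clipping, as _ClipGradients does with clip_type=1.
        norm = np.sqrt(np.sum(g ** 2))
        return g if norm <= clip_norm else g * (clip_norm / norm)

    def dp_one_step_grads(per_micro_grads, l2_norm_clip, noise_fn=None):
        micro_batches = len(per_micro_grads)
        acc = None
        for grads in per_micro_grads:
            clipped = [clip_by_norm(g, l2_norm_clip) for g in grads]
            acc = clipped if acc is None else [a + c for a, c in zip(acc, clipped)]
        if noise_fn is not None:  # corresponds to self._mech is not None
            acc = [a + noise_fn(a) for a in acc]
        return [a / micro_batches for a in acc]

    micro = [[np.ones((2, 2)), np.ones(3)] for _ in range(4)]
    print(dp_one_step_grads(micro, l2_norm_clip=1.0)[0])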
@@ -18,7 +18,7 @@ from setuptools import setup
 from setuptools.command.egg_info import egg_info
 from setuptools.command.build_py import build_py

-version = '0.3.0'
+version = '0.5.0'

 cur_dir = os.path.dirname(os.path.realpath(__file__))
 pkg_dir = os.path.join(cur_dir, 'build')
@@ -17,6 +17,8 @@ different Privacy test.
 import pytest

 from mindspore import context
+from mindspore import Tensor
+from mindspore.common import dtype as mstype

 from mindarmour.diff_privacy import GaussianRandom
 from mindarmour.diff_privacy import AdaGaussianRandom
 from mindarmour.diff_privacy import MechanismsFactory
@@ -26,13 +28,13 @@ from mindarmour.diff_privacy import MechanismsFactory
 @pytest.mark.platform_x86_ascend_training
 @pytest.mark.env_onecard
 @pytest.mark.component_mindarmour
-def test_gaussian():
-    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
-    shape = (3, 2, 4)
+def test_graph_gaussian():
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    grad = Tensor([3, 2, 4], mstype.float32)
     norm_bound = 1.0
     initial_noise_multiplier = 0.1
     net = GaussianRandom(norm_bound, initial_noise_multiplier)
-    res = net(shape)
+    res = net(grad)
     print(res)
@@ -40,42 +42,99 @@ def test_gaussian():
 @pytest.mark.platform_x86_ascend_training
 @pytest.mark.env_onecard
 @pytest.mark.component_mindarmour
-def test_ada_gaussian():
+def test_pynative_gaussian():
     context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
-    shape = (3, 2, 4)
+    grad = Tensor([3, 2, 4], mstype.float32)
+    norm_bound = 1.0
+    initial_noise_multiplier = 0.1
+    net = GaussianRandom(norm_bound, initial_noise_multiplier)
+    res = net(grad)
+    print(res)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+@pytest.mark.component_mindarmour
+def test_graph_ada_gaussian():
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    grad = Tensor([3, 2, 4], mstype.float32)
     norm_bound = 1.0
     initial_noise_multiplier = 0.1
-    noise_decay_rate = 0.5
-    decay_policy = "Step"
+    alpha = 0.5
+    decay_policy = 'Step'
     net = AdaGaussianRandom(norm_bound, initial_noise_multiplier,
-                            noise_decay_rate, decay_policy)
-    res = net(shape)
+                            noise_decay_rate=alpha, decay_policy=decay_policy)
+    res = net(grad)
     print(res)


-def test_factory():
-    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
-    shape = (3, 2, 4)
+@pytest.mark.level0
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+@pytest.mark.component_mindarmour
+def test_graph_factory():
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    grad = Tensor([3, 2, 4], mstype.float32)
     norm_bound = 1.0
     initial_noise_multiplier = 0.1
-    noise_decay_rate = 0.5
-    decay_policy = "Step"
+    alpha = 0.5
+    decay_policy = 'Step'
     noise_mechanism = MechanismsFactory()
     noise_construct = noise_mechanism.create('Gaussian',
                                              norm_bound,
                                              initial_noise_multiplier)
-    noise = noise_construct(shape)
+    noise = noise_construct(grad)
     print('Gaussian noise: ', noise)
     ada_mechanism = MechanismsFactory()
     ada_noise_construct = ada_mechanism.create('AdaGaussian',
                                                norm_bound,
                                                initial_noise_multiplier,
-                                               noise_decay_rate,
-                                               decay_policy)
-    ada_noise = ada_noise_construct(shape)
+                                               noise_decay_rate=alpha,
+                                               decay_policy=decay_policy)
+    ada_noise = ada_noise_construct(grad)
     print('ada noise: ', ada_noise)


-if __name__ == '__main__':
-    # device_target can be "CPU", "GPU" or "Ascend"
+@pytest.mark.level0
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+@pytest.mark.component_mindarmour
+def test_pynative_ada_gaussian():
     context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+    grad = Tensor([3, 2, 4], mstype.float32)
+    norm_bound = 1.0
+    initial_noise_multiplier = 0.1
+    alpha = 0.5
+    decay_policy = 'Step'
+    net = AdaGaussianRandom(norm_bound, initial_noise_multiplier,
+                            noise_decay_rate=alpha, decay_policy=decay_policy)
+    res = net(grad)
+    print(res)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+@pytest.mark.component_mindarmour
+def test_pynative_factory():
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+    grad = Tensor([3, 2, 4], mstype.float32)
+    norm_bound = 1.0
+    initial_noise_multiplier = 0.1
+    alpha = 0.5
+    decay_policy = 'Step'
+    noise_mechanism = MechanismsFactory()
+    noise_construct = noise_mechanism.create('Gaussian',
+                                             norm_bound,
+                                             initial_noise_multiplier)
+    noise = noise_construct(grad)
+    print('Gaussian noise: ', noise)
+    ada_mechanism = MechanismsFactory()
+    ada_noise_construct = ada_mechanism.create('AdaGaussian',
+                                               norm_bound,
+                                               initial_noise_multiplier,
+                                               noise_decay_rate=alpha,
+                                               decay_policy=decay_policy)
+    ada_noise = ada_noise_construct(grad)
+    print('ada noise: ', ada_noise)
@@ -21,13 +21,15 @@ from mindspore import nn
 from mindspore import context
 import mindspore.dataset as ds

-from mindarmour.diff_privacy import DPOptimizerClassFactory
 from mindarmour.diff_privacy import DPModel
+from mindarmour.diff_privacy import MechanismsFactory
+from mindarmour.diff_privacy import DPOptimizerClassFactory

 from test_network import LeNet5


 def dataset_generator(batch_size, batches):
+    """mock training data."""
     data = np.random.random((batches * batch_size, 1, 32, 32)).astype(np.float32)
     label = np.random.randint(0, 10, batches * batch_size).astype(np.int32)
     for i in range(batches):
@@ -39,7 +41,7 @@ def dataset_generator(batch_size, batches):
 @pytest.mark.platform_x86_ascend_training
 @pytest.mark.env_card
 @pytest.mark.component_mindarmour
-def test_dp_model():
+def test_dp_model_pynative_mode():
     context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
     l2_norm_bound = 1.0
     initial_noise_multiplier = 0.01
@@ -47,21 +49,50 @@ def test_dp_model():
     batch_size = 32
     batches = 128
     epochs = 1
+    micro_batches = 2
+    loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
+    factory_opt = DPOptimizerClassFactory(micro_batches=micro_batches)
+    factory_opt.set_mechanisms('Gaussian',
+                               norm_bound=l2_norm_bound,
+                               initial_noise_multiplier=initial_noise_multiplier)
+    net_opt = factory_opt.create('Momentum')(network.trainable_params(), learning_rate=0.1, momentum=0.9)
+    model = DPModel(micro_batches=micro_batches,
+                    norm_clip=l2_norm_bound,
+                    mech=None,
+                    network=network,
+                    loss_fn=loss,
+                    optimizer=net_opt,
+                    metrics=None)
+    ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches), ['data', 'label'])
+    ms_ds.set_dataset_size(batch_size * batches)
+    model.train(epochs, ms_ds, dataset_sink_mode=False)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_card
+@pytest.mark.component_mindarmour
+def test_dp_model_with_graph_mode():
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    l2_norm_bound = 1.0
+    initial_noise_multiplier = 0.01
+    network = LeNet5()
+    batch_size = 32
+    batches = 128
+    epochs = 1
     loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
-    gaussian_mech = DPOptimizerClassFactory(micro_batches=2)
-    gaussian_mech.set_mechanisms('Gaussian',
-                                 norm_bound=l2_norm_bound,
-                                 initial_noise_multiplier=initial_noise_multiplier)
-    net_opt = gaussian_mech.create('SGD')(params=network.trainable_params(),
-                                          learning_rate=0.1,
-                                          momentum=0.9)
+    mech = MechanismsFactory().create('Gaussian',
+                                      norm_bound=l2_norm_bound,
+                                      initial_noise_multiplier=initial_noise_multiplier)
+    net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9)
     model = DPModel(micro_batches=2,
                     norm_clip=l2_norm_bound,
-                    dp_mech=gaussian_mech.mech,
+                    mech=mech,
                     network=network,
                     loss_fn=loss,
                     optimizer=net_opt,
                     metrics=None)
     ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches), ['data', 'label'])
     ms_ds.set_dataset_size(batch_size * batches)
-    model.train(epochs, ms_ds)
+    model.train(epochs, ms_ds, dataset_sink_mode=False)
@@ -21,6 +21,7 @@ from mindarmour.diff_privacy import DPOptimizerClassFactory
 from test_network import LeNet5


 @pytest.mark.level0
 @pytest.mark.platform_arm_ascend_training
 @pytest.mark.platform_x86_ascend_training