diff --git a/example/mnist_demo/dp_ada_gaussian_config.py b/example/mnist_demo/dp_ada_gaussian_config.py
index 9a0dd08..03d1414 100644
--- a/example/mnist_demo/dp_ada_gaussian_config.py
+++ b/example/mnist_demo/dp_ada_gaussian_config.py
@@ -22,7 +22,7 @@ mnist_cfg = edict({
     'num_classes': 10,  # the number of classes of model's output
     'lr': 0.01,  # the learning rate of model's optimizer
     'momentum': 0.9,  # the momentum value of model's optimizer
-    'epoch_size': 10,  # training epochs
+    'epoch_size': 5,  # training epochs
     'batch_size': 256,  # batch size for training
     'image_height': 32,  # the height of training samples
     'image_width': 32,  # the width of training samples
@@ -31,9 +31,9 @@ mnist_cfg = edict({
     'device_target': 'Ascend',  # device used
     'data_path': './MNIST_unzip',  # the path of training and testing data set
     'dataset_sink_mode': False,  # whether deliver all training data to device one time
-    'micro_batches': 16,  # the number of small batches split from an original batch
+    'micro_batches': 32,  # the number of small batches split from an original batch
     'norm_bound': 1.0,  # the clip bound of the gradients of model's training parameters
-    'initial_noise_multiplier': 1.0,  # the initial multiplication coefficient of the noise added to training
+    'initial_noise_multiplier': 0.05,  # the initial multiplication coefficient of the noise added to training
                                       # parameters' gradients
     'noise_mechanisms': 'AdaGaussian',  # the method of adding noise in gradients while training
     'optimizer': 'Momentum'  # the base optimizer used for Differential privacy training
diff --git a/example/mnist_demo/lenet5_config.py b/example/mnist_demo/lenet5_config.py
index f7dd03b..f1a2745 100644
--- a/example/mnist_demo/lenet5_config.py
+++ b/example/mnist_demo/lenet5_config.py
@@ -22,7 +22,7 @@ mnist_cfg = edict({
     'num_classes': 10,  # the number of classes of model's output
     'lr': 0.01,  # the learning rate of model's optimizer
     'momentum': 0.9,  # the momentum value of model's optimizer
-    'epoch_size': 10,  # training epochs
+    'epoch_size': 5,  # training epochs
     'batch_size': 256,  # batch size for training
     'image_height': 32,  # the height of training samples
     'image_width': 32,  # the width of training samples
@@ -31,9 +31,9 @@ mnist_cfg = edict({
     'device_target': 'Ascend',  # device used
     'data_path': './MNIST_unzip',  # the path of training and testing data set
     'dataset_sink_mode': False,  # whether deliver all training data to device one time
-    'micro_batches': 16,  # the number of small batches split from an original batch
+    'micro_batches': 32,  # the number of small batches split from an original batch
     'norm_bound': 1.0,  # the clip bound of the gradients of model's training parameters
-    'initial_noise_multiplier': 1.0,  # the initial multiplication coefficient of the noise added to training
+    'initial_noise_multiplier': 0.05,  # the initial multiplication coefficient of the noise added to training
                                       # parameters' gradients
     'noise_mechanisms': 'Gaussian',  # the method of adding noise in gradients while training
     'clip_mechanisms': 'Gaussian',  # the method of adaptive clipping gradients while training
diff --git a/example/mnist_demo/lenet5_dp.py b/example/mnist_demo/lenet5_dp.py
index 65aa63c..6468cd3 100644
--- a/example/mnist_demo/lenet5_dp.py
+++ b/example/mnist_demo/lenet5_dp.py
@@ -155,7 +155,7 @@ if __name__ == "__main__":
                 dataset_sink_mode=cfg.dataset_sink_mode)
 
     LOGGER.info(TAG, "============== Starting Testing ==============")
-    ckpt_file_name = 'trained_ckpt_file/checkpoint_lenet-10_234.ckpt'
+    ckpt_file_name = 'trained_ckpt_file/checkpoint_lenet-5_234.ckpt'
     param_dict = load_checkpoint(ckpt_file_name)
     load_param_into_net(network, param_dict)
     ds_eval = generate_mnist_dataset(os.path.join(cfg.data_path, 'test'),
diff --git a/example/mnist_demo/lenet5_dp_ada_gaussian.py b/example/mnist_demo/lenet5_dp_ada_gaussian.py
index 484c733..d2b84c4 100644
--- a/example/mnist_demo/lenet5_dp_ada_gaussian.py
+++ b/example/mnist_demo/lenet5_dp_ada_gaussian.py
@@ -141,7 +141,7 @@ if __name__ == "__main__":
                 dataset_sink_mode=cfg.dataset_sink_mode)
 
     LOGGER.info(TAG, "============== Starting Testing ==============")
-    ckpt_file_name = 'trained_ckpt_file/checkpoint_lenet-10_234.ckpt'
+    ckpt_file_name = 'trained_ckpt_file/checkpoint_lenet-5_234.ckpt'
    param_dict = load_checkpoint(ckpt_file_name)
     load_param_into_net(network, param_dict)
     ds_eval = generate_mnist_dataset(os.path.join(cfg.data_path, 'test'),
diff --git a/example/mnist_demo/lenet5_dp_pynative_model.py b/example/mnist_demo/lenet5_dp_pynative_model.py
index ab86fdb..d6906cf 100644
--- a/example/mnist_demo/lenet5_dp_pynative_model.py
+++ b/example/mnist_demo/lenet5_dp_pynative_model.py
@@ -111,7 +111,7 @@ if __name__ == "__main__":
     dp_opt.set_mechanisms(cfg.noise_mechanisms,
                           norm_bound=cfg.norm_bound,
                           initial_noise_multiplier=cfg.initial_noise_multiplier,
-                          decay_policy='Exp')
+                          decay_policy=None)
     # Create a factory class of clip mechanisms, this method is to adaptive clip
     # gradients while training, decay_policy support 'Linear' and 'Geometric',
     # learning_rate is the learning rate to update clip_norm,
@@ -147,7 +147,7 @@
                 dataset_sink_mode=cfg.dataset_sink_mode)
 
     LOGGER.info(TAG, "============== Starting Testing ==============")
-    ckpt_file_name = 'trained_ckpt_file/checkpoint_lenet-10_234.ckpt'
+    ckpt_file_name = 'trained_ckpt_file/checkpoint_lenet-5_234.ckpt'
     param_dict = load_checkpoint(ckpt_file_name)
     load_param_into_net(network, param_dict)
     ds_eval = generate_mnist_dataset(os.path.join(cfg.data_path, 'test'), batch_size=cfg.batch_size)
diff --git a/mindarmour/diff_privacy/train/model.py b/mindarmour/diff_privacy/train/model.py
index 875e9d9..c68253a 100644
--- a/mindarmour/diff_privacy/train/model.py
+++ b/mindarmour/diff_privacy/train/model.py
@@ -656,14 +656,16 @@ class _TrainOneStepCell(Cell):
         record_grad = self.grad(self.network, weights)(record_datas[0],
                                                        record_labels[0], sens)
         beta = self._zero
-        square_sum = self._zero
-        for grad in record_grad:
-            square_sum = self._add(square_sum,
-                                   self._reduce_sum(self._square_all(grad)))
-        norm_grad = self._sqrt(square_sum)
-        beta = self._add(beta,
-                         self._cast(self._less(norm_grad, self._norm_bound),
-                                    mstype.float32))
+        # calcu beta
+        if self._clip_mech is not None:
+            square_sum = self._zero
+            for grad in record_grad:
+                square_sum = self._add(square_sum,
+                                       self._reduce_sum(self._square_all(grad)))
+            norm_grad = self._sqrt(square_sum)
+            beta = self._add(beta,
+                             self._cast(self._less(norm_grad, self._norm_bound),
+                                        mstype.float32))
         record_grad = self._clip_by_global_norm(record_grad,
                                                 GRADIENT_CLIP_TYPE,
                                                 self._norm_bound)
@@ -675,14 +677,16 @@
             record_grad = self.grad(self.network,
                                     weights)(record_datas[i],
                                              record_labels[i],
                                              sens)
-            square_sum = self._zero
-            for grad in record_grad:
-                square_sum = self._add(square_sum,
-                                       self._reduce_sum(self._square_all(grad)))
-            norm_grad = self._sqrt(square_sum)
-            beta = self._add(beta,
-                             self._cast(self._less(norm_grad, self._norm_bound),
-                                        mstype.float32))
+            # calcu beta
+            if self._clip_mech is not None:
+                square_sum = self._zero
+                for grad in record_grad:
+                    square_sum = self._add(square_sum,
+                                           self._reduce_sum(self._square_all(grad)))
+                norm_grad = self._sqrt(square_sum)
+                beta = self._add(beta,
+                                 self._cast(self._less(norm_grad, self._norm_bound),
+                                            mstype.float32))
             record_grad = self._clip_by_global_norm(record_grad,
                                                     GRADIENT_CLIP_TYPE,
@@ -690,7 +694,6 @@
             grads = self._tuple_add(grads, record_grad)
             total_loss = P.TensorAdd()(total_loss, loss)
         loss = self._div(total_loss, self._micro_float)
-        beta = self._div(beta, self._micro_batches)
 
         if self._noise_mech is not None:
             grad_noise_tuple = ()
@@ -710,8 +713,9 @@
             grads = self.grad_reducer(grads)
 
         if self._clip_mech is not None:
+            beta = self._div(beta, self._micro_batches)
             next_norm_bound = self._clip_mech(beta, self._norm_bound)
             self._norm_bound = self._assign(self._norm_bound, next_norm_bound)
-            loss = F.depend(loss, next_norm_bound)
+            loss = F.depend(loss, self._norm_bound)
 
         return F.depend(loss, self.optimizer(grads))
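
With the new settings, each 256-sample batch is split into 32 microbatches of 8 records, and the initial noise multiplier drops from 1.0 to 0.05; the checkpoint names change from checkpoint_lenet-10_234 to checkpoint_lenet-5_234 to match the shorter 5-epoch run. The model.py hunks move the beta bookkeeping behind the self._clip_mech is not None check: beta is the fraction of microbatch gradients whose global norm is already below the current clip bound, and both the per-microbatch norm computation and the division by self._micro_batches now happen only when a clip mechanism will actually consume beta via self._clip_mech(beta, self._norm_bound). The following is a minimal illustrative sketch of that bookkeeping in plain NumPy; adaptive_clip_step and the toy clip_mech callable are made-up names standing in for the configured clip mechanism, not MindSpore/MindArmour APIs.

import numpy as np

def adaptive_clip_step(micro_batch_grads, norm_bound, clip_mech=None):
    """micro_batch_grads: list of per-microbatch gradient lists (one array per parameter)."""
    beta = 0.0
    for grads in micro_batch_grads:
        # global gradient norm over all parameters of one microbatch
        norm_grad = np.sqrt(sum(np.sum(np.square(g)) for g in grads))
        if clip_mech is not None:
            # count microbatches whose gradient is already below the clip bound
            beta += float(norm_grad < norm_bound)
    if clip_mech is not None:
        # average beta and update the bound only when a clip mechanism is configured,
        # mirroring the branch added in _TrainOneStepCell.construct
        beta /= len(micro_batch_grads)
        norm_bound = clip_mech(beta, norm_bound)
    return norm_bound

# toy usage: a stand-in clip mechanism that shrinks the bound when most
# gradients already fit under it, and grows it otherwise
toy_clip = lambda beta, bound: bound * (0.9 if beta > 0.5 else 1.1)
rng = np.random.default_rng(0)
grads = [[rng.normal(size=(3, 3)) * s] for s in (0.1, 0.5, 2.0, 3.0)]
print(adaptive_clip_step(grads, norm_bound=1.0, clip_mech=toy_clip))

Gating this work on the clip mechanism avoids the extra square/reduce/sqrt ops per microbatch in the plain Gaussian-noise training path, which never reads beta.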