From b2e0934bf743f8e29778ed97a4bce0351e89bc00 Mon Sep 17 00:00:00 2001 From: zheng-huanhuan Date: Thu, 30 Apr 2020 21:43:38 +0800 Subject: [PATCH] 5-month feature: add differential privacy train and optimizer. --- example/mnist_demo/lenet5_config.py | 32 ++ example/mnist_demo/lenet5_dp_model_train.py | 151 +++++ mindarmour/diff_privacy/__init__.py | 8 +- .../diff_privacy/mechanisms/mechanisms.py | 4 - mindarmour/diff_privacy/monitor/monitor.py | 2 +- mindarmour/diff_privacy/optimizer/__init__.py | 0 .../diff_privacy/optimizer/optimizer.py | 116 ++++ mindarmour/diff_privacy/train/__init__.py | 0 mindarmour/diff_privacy/train/model.py | 515 ++++++++++++++++++ .../{mechanisms => }/test_mechanisms.py | 0 .../python/diff_privacy/test_model_train.py | 65 +++ tests/ut/python/diff_privacy/test_monitor.py | 2 +- .../ut/python/diff_privacy/test_optimizer.py | 76 +++ 13 files changed, 964 insertions(+), 7 deletions(-) create mode 100644 example/mnist_demo/lenet5_config.py create mode 100644 example/mnist_demo/lenet5_dp_model_train.py create mode 100644 mindarmour/diff_privacy/optimizer/__init__.py create mode 100644 mindarmour/diff_privacy/optimizer/optimizer.py create mode 100644 mindarmour/diff_privacy/train/__init__.py create mode 100644 mindarmour/diff_privacy/train/model.py rename tests/ut/python/diff_privacy/{mechanisms => }/test_mechanisms.py (100%) create mode 100644 tests/ut/python/diff_privacy/test_model_train.py create mode 100644 tests/ut/python/diff_privacy/test_optimizer.py diff --git a/example/mnist_demo/lenet5_config.py b/example/mnist_demo/lenet5_config.py new file mode 100644 index 0000000..405b5ff --- /dev/null +++ b/example/mnist_demo/lenet5_config.py @@ -0,0 +1,32 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +network config setting, will be used in train.py +""" + +from easydict import EasyDict as edict + +mnist_cfg = edict({ + 'num_classes': 10, + 'lr': 0.01, + 'momentum': 0.9, + 'epoch_size': 10, + 'batch_size': 32, + 'buffer_size': 1000, + 'image_height': 32, + 'image_width': 32, + 'save_checkpoint_steps': 1875, + 'keep_checkpoint_max': 10, +}) diff --git a/example/mnist_demo/lenet5_dp_model_train.py b/example/mnist_demo/lenet5_dp_model_train.py new file mode 100644 index 0000000..6765523 --- /dev/null +++ b/example/mnist_demo/lenet5_dp_model_train.py @@ -0,0 +1,151 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +""" +python lenet5_dp_model_train.py --data_path /YourDataPath --micro_batches=2 +""" +import os +import argparse + +import mindspore.nn as nn +from mindspore import context +from mindspore.train.callback import ModelCheckpoint +from mindspore.train.callback import CheckpointConfig +from mindspore.train.callback import LossMonitor +from mindspore.nn.metrics import Accuracy +from mindspore.train.serialization import load_checkpoint, load_param_into_net +import mindspore.dataset as ds +import mindspore.dataset.transforms.vision.c_transforms as CV +import mindspore.dataset.transforms.c_transforms as C +from mindspore.dataset.transforms.vision import Inter +import mindspore.common.dtype as mstype + +from mindarmour.diff_privacy import DPModel +from mindarmour.diff_privacy import DPOptimizerClassFactory +from mindarmour.diff_privacy import PrivacyMonitorFactory +from mindarmour.utils.logger import LogUtil +from lenet5_net import LeNet5 +from lenet5_config import mnist_cfg as cfg + + +LOGGER = LogUtil.get_instance() +TAG = 'Lenet5_train' + + +def generate_mnist_dataset(data_path, batch_size=32, repeat_size=1, + num_parallel_workers=1, sparse=True): + """ + create dataset for training or testing + """ + # define dataset + ds1 = ds.MnistDataset(data_path) + + # define operation parameters + resize_height, resize_width = 32, 32 + rescale = 1.0 / 255.0 + shift = 0.0 + + # define map operations + resize_op = CV.Resize((resize_height, resize_width), + interpolation=Inter.LINEAR) + rescale_op = CV.Rescale(rescale, shift) + hwc2chw_op = CV.HWC2CHW() + type_cast_op = C.TypeCast(mstype.int32) + + # apply map operations on images + if not sparse: + one_hot_enco = C.OneHot(10) + ds1 = ds1.map(input_columns="label", operations=one_hot_enco, + num_parallel_workers=num_parallel_workers) + type_cast_op = C.TypeCast(mstype.float32) + ds1 = ds1.map(input_columns="label", operations=type_cast_op, + num_parallel_workers=num_parallel_workers) + ds1 = ds1.map(input_columns="image", operations=resize_op, + num_parallel_workers=num_parallel_workers) + ds1 = ds1.map(input_columns="image", operations=rescale_op, + num_parallel_workers=num_parallel_workers) + ds1 = ds1.map(input_columns="image", operations=hwc2chw_op, + num_parallel_workers=num_parallel_workers) + + # apply DatasetOps + buffer_size = 10000 + ds1 = ds1.shuffle(buffer_size=buffer_size) + ds1 = ds1.batch(batch_size, drop_remainder=True) + ds1 = ds1.repeat(repeat_size) + + return ds1 + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='MindSpore MNIST Example') + parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'], + help='device where the code will be implemented (default: Ascend)') + parser.add_argument('--data_path', type=str, default="./MNIST_unzip", + help='path where the dataset is saved') + parser.add_argument('--dataset_sink_mode', type=bool, default=False, help='dataset_sink_mode is False or True') + parser.add_argument('--micro_batches', type=float, default=None, + help='optional, if use differential privacy, need to set micro_batches') + parser.add_argument('--l2_norm_bound', type=float, default=1, + help='optional, if use differential privacy, need to set l2_norm_bound') + parser.add_argument('--initial_noise_multiplier', type=float, default=0.001, + help='optional, if use differential privacy, need to set initial_noise_multiplier') + args = parser.parse_args() + + 
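+    # Note on the DP arguments (clarifying comment only, no behaviour change): micro_batches is
+    # the number of sub-batches each training batch is split into, so cfg.batch_size must be an
+    # integer multiple of it (checked below); l2_norm_bound is the gradient clipping bound and
+    # initial_noise_multiplier scales the Gaussian noise added by the DP optimizer.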
context.set_context(mode=context.PYNATIVE_MODE, device_target=args.device_target, enable_mem_reuse=False) + + network = LeNet5() + net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") + config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps, + keep_checkpoint_max=cfg.keep_checkpoint_max) + ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", + directory='./trained_ckpt_file/', + config=config_ck) + + ds_train = generate_mnist_dataset(os.path.join(args.data_path, "train"), + cfg.batch_size, + cfg.epoch_size) + + if args.micro_batches and cfg.batch_size % args.micro_batches != 0: + raise ValueError("Number of micro_batches should divide evenly batch_size") + gaussian_mech = DPOptimizerClassFactory(args.micro_batches) + gaussian_mech.set_mechanisms('Gaussian', + norm_bound=args.l2_norm_bound, + initial_noise_multiplier=args.initial_noise_multiplier) + net_opt = gaussian_mech.create('Momentum')(params=network.trainable_params(), + learning_rate=cfg.lr, + momentum=cfg.momentum) + micro_size = int(cfg.batch_size // args.micro_batches) + rdp_monitor = PrivacyMonitorFactory.create('rdp', + num_samples=60000, + batch_size=micro_size, + initial_noise_multiplier=args.initial_noise_multiplier, + per_print_times=10) + model = DPModel(micro_batches=args.micro_batches, + norm_clip=args.l2_norm_bound, + dp_mech=gaussian_mech.mech, + network=network, + loss_fn=net_loss, + optimizer=net_opt, + metrics={"Accuracy": Accuracy()}) + + LOGGER.info(TAG, "============== Starting Training ==============") + model.train(cfg['epoch_size'], ds_train, callbacks=[ckpoint_cb, LossMonitor(), rdp_monitor], + dataset_sink_mode=args.dataset_sink_mode) + + LOGGER.info(TAG, "============== Starting Testing ==============") + ckpt_file_name = 'trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' + param_dict = load_checkpoint(ckpt_file_name) + load_param_into_net(network, param_dict) + ds_eval = generate_mnist_dataset(os.path.join(args.data_path, 'test'), batch_size=cfg.batch_size) + acc = model.eval(ds_eval, dataset_sink_mode=False) + LOGGER.info(TAG, "============== Accuracy: %s ==============", acc) diff --git a/mindarmour/diff_privacy/__init__.py b/mindarmour/diff_privacy/__init__.py index 9c45184..bf8a149 100644 --- a/mindarmour/diff_privacy/__init__.py +++ b/mindarmour/diff_privacy/__init__.py @@ -4,7 +4,13 @@ This module provide Differential Privacy feature to protect user privacy. from .mechanisms.mechanisms import GaussianRandom from .mechanisms.mechanisms import AdaGaussianRandom from .mechanisms.mechanisms import MechanismsFactory +from .monitor.monitor import PrivacyMonitorFactory +from .optimizer.optimizer import DPOptimizerClassFactory +from .train.model import DPModel __all__ = ['GaussianRandom', 'AdaGaussianRandom', - 'MechanismsFactory'] + 'MechanismsFactory', + 'PrivacyMonitorFactory', + 'DPOptimizerClassFactory', + 'DPModel'] diff --git a/mindarmour/diff_privacy/mechanisms/mechanisms.py b/mindarmour/diff_privacy/mechanisms/mechanisms.py index 3e82a5f..1a8de17 100644 --- a/mindarmour/diff_privacy/mechanisms/mechanisms.py +++ b/mindarmour/diff_privacy/mechanisms/mechanisms.py @@ -60,10 +60,6 @@ class Mechanisms(Cell): """ Basic class of noise generated mechanism. """ - - def __init__(self): - pass - def construct(self, shape): """ Construct function. 
diff --git a/mindarmour/diff_privacy/monitor/monitor.py b/mindarmour/diff_privacy/monitor/monitor.py
index 4227862..ca7bdf7 100644
--- a/mindarmour/diff_privacy/monitor/monitor.py
+++ b/mindarmour/diff_privacy/monitor/monitor.py
@@ -47,7 +47,7 @@ class PrivacyMonitorFactory:
             parameters used for creating a privacy monitor.
 
         Returns:
-            PrivacyMonitor, a privacy monitor.
+            Callback, a privacy monitor.
 
         Examples:
             >>> rdp = PrivacyMonitorFactory.create(policy='rdp',
diff --git a/mindarmour/diff_privacy/optimizer/__init__.py b/mindarmour/diff_privacy/optimizer/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/mindarmour/diff_privacy/optimizer/optimizer.py b/mindarmour/diff_privacy/optimizer/optimizer.py
new file mode 100644
index 0000000..e16799f
--- /dev/null
+++ b/mindarmour/diff_privacy/optimizer/optimizer.py
@@ -0,0 +1,116 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Differential privacy optimizer.
+"""
+import mindspore as ms
+from mindspore import nn
+from mindspore import Tensor
+
+from mindarmour.diff_privacy.mechanisms.mechanisms import MechanismsFactory
+
+
+class DPOptimizerClassFactory:
+    """
+    Factory class for DP optimizers.
+
+    Args:
+        micro_batches (int): The number of small batches split from an original batch. Default: None.
+
+    Returns:
+        Optimizer, a DP-enabled optimizer class.
+
+    Examples:
+        >>> GaussianSGD = DPOptimizerClassFactory(micro_batches=2)
+        >>> GaussianSGD.set_mechanisms('Gaussian', norm_bound=1.5, initial_noise_multiplier=5.0)
+        >>> net_opt = GaussianSGD.create('SGD')(params=network.trainable_params(),
+        >>>                                     learning_rate=cfg.lr,
+        >>>                                     momentum=cfg.momentum)
+    """
+    def __init__(self, micro_batches=None):
+        self._mech_factory = MechanismsFactory()
+        self.mech = None
+        self._micro_batches = micro_batches
+
+    def set_mechanisms(self, policy, *args, **kwargs):
+        """
+        Set the noise mechanism object.
+
+        Args:
+            policy (str): Type of the noise mechanism, e.g. 'Gaussian'.
+        """
+        self.mech = self._mech_factory.create(policy, *args, **kwargs)
+
+    def create(self, policy, *args, **kwargs):
+        """
+        Create a DP optimizer class.
+
+        Args:
+            policy (str): Type of the original optimizer to wrap.
+
+        Returns:
+            Optimizer, an optimizer class with DP noise applied to its gradients.
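+
+        Examples:
+            >>> # `network` is assumed to be any defined nn.Cell (e.g. LeNet5).
+            >>> factory = DPOptimizerClassFactory(micro_batches=2)
+            >>> factory.set_mechanisms('Gaussian', norm_bound=1.0, initial_noise_multiplier=0.01)
+            >>> net_opt = factory.create('Momentum')(params=network.trainable_params(),
+            >>>                                      learning_rate=0.01,
+            >>>                                      momentum=0.9)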
+ """ + if policy == 'SGD': + cls = self._get_dp_optimizer_class(nn.SGD, self.mech, self._micro_batches, *args, **kwargs) + return cls + if policy == 'Momentum': + cls = self._get_dp_optimizer_class(nn.Momentum, self.mech, self._micro_batches, *args, **kwargs) + return cls + if policy == 'Adam': + cls = self._get_dp_optimizer_class(nn.Adam, self.mech, self._micro_batches, *args, **kwargs) + return cls + if policy == 'AdamWeightDecay': + cls = self._get_dp_optimizer_class(nn.AdamWeightDecay, self.mech, self._micro_batches, *args, **kwargs) + return cls + if policy == 'AdamWeightDecayDynamicLR': + cls = self._get_dp_optimizer_class(nn.AdamWeightDecayDynamicLR, + self.mech, + self._micro_batches, + *args, **kwargs) + return cls + raise NameError("The {} is not implement, please choose ['SGD', 'Momentum', 'AdamWeightDecay', " + "'Adam', 'AdamWeightDecayDynamicLR']".format(policy)) + + def _get_dp_optimizer_class(self, cls, mech, micro_batches): + """ + Wrap original mindspore optimizer with `self._mech`. + """ + class DPOptimizer(cls): + """ + Initialize the DPOptimizerClass. + + Returns: + Optimizer, Optimizer class. + """ + def __init__(self, *args, **kwargs): + super(DPOptimizer, self).__init__(*args, **kwargs) + self._mech = mech + + def construct(self, gradients): + """ + construct a compute flow. + """ + g_len = len(gradients) + gradient_noise = list(gradients) + for i in range(g_len): + gradient_noise[i] = gradient_noise[i].asnumpy() + gradient_noise[i] = self._mech(gradient_noise[i].shape).asnumpy() + gradient_noise[i] + gradient_noise[i] = gradient_noise[i] / micro_batches + gradient_noise[i] = Tensor(gradient_noise[i], ms.float32) + gradients = tuple(gradient_noise) + + gradients = super(DPOptimizer, self).construct(gradients) + return gradients + return DPOptimizer diff --git a/mindarmour/diff_privacy/train/__init__.py b/mindarmour/diff_privacy/train/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mindarmour/diff_privacy/train/model.py b/mindarmour/diff_privacy/train/model.py new file mode 100644 index 0000000..20e14a7 --- /dev/null +++ b/mindarmour/diff_privacy/train/model.py @@ -0,0 +1,515 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Differential privacy model. 
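+
+This module provides DPModel, an extension of mindspore.train.model.Model that splits each
+training batch into micro-batches, clips the gradient of every micro-batch by an l2-norm
+bound and accumulates the clipped gradients before the optimizer is applied. The batch size
+of the training dataset should therefore be an integer multiple of micro_batches.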
+""" +from easydict import EasyDict as edict + +import mindspore as ms +from mindspore.train.model import Model +from mindspore._checkparam import Validator as validator +from mindspore._checkparam import Rel +from mindspore.train import amp +from mindspore.train.amp import _config_level +from mindspore.common import dtype as mstype +from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell +from mindspore.parallel._utils import _get_parallel_mode +from mindspore.train.model import ParallelMode +from mindspore.train.amp import _do_keep_batchnorm_fp32 +from mindspore.train.amp import _add_loss_network +from mindspore import context +from mindspore import nn +from mindspore import Tensor +from mindspore.ops import composite as C +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore.ops.operations import NPUGetFloatStatus +from mindspore.ops.operations import NPUAllocFloatStatus +from mindspore.ops.operations import NPUClearFloatStatus +from mindspore.ops.operations import ReduceSum +from mindspore.ops.operations import LessEqual +from mindspore.ops.operations import ControlDepend +from mindspore.parallel._utils import _get_mirror_mean +from mindspore.parallel._utils import _get_device_num +from mindspore.nn.wrap.grad_reducer import DistributedGradReducer +from mindspore.common.parameter import Parameter +from mindspore.nn.wrap.loss_scale import _grad_overflow +from mindspore.nn import Cell +from mindspore import ParameterTuple + + +GRADIENT_CLIP_TYPE = 1 +grad_scale = C.MultitypeFuncGraph("grad_scale") +reciprocal = P.Reciprocal() + + +@grad_scale.register("Tensor", "Tensor") +def tensor_grad_scale(scale, grad): + return grad*reciprocal(scale) + + +class DPModel(Model): + """ + This class is overload mindspore.train.model.Model. + + Args: + micro_batches (int): The number of small batches split from an origianl batch. Default: None. + norm_clip (float): Use to clip the bound, if set 1, will retun the original data. Default: None. + dp_mech (Mechanisms): The object can generate the different type of noise. Default: None. 
+ + Examples: + >>> class Net(nn.Cell): + >>> def __init__(self): + >>> super(Net, self).__init__() + >>> self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal') + >>> self.bn = nn.BatchNorm2d(64) + >>> self.relu = nn.ReLU() + >>> self.flatten = nn.Flatten() + >>> self.fc = nn.Dense(64*224*224, 12) # padding=0 + >>> + >>> def construct(self, x): + >>> x = self.conv(x) + >>> x = self.bn(x) + >>> x = self.relu(x) + >>> x = self.flatten(x) + >>> out = self.fc(x) + >>> return out + >>> + >>> net = Net() + >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) + >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) + >>> gaussian_mech = DPOptimizerClassFactory() + >>> gaussian_mech.set_mechanisms('Gaussian', + >>> norm_bound=args.l2_norm_bound, + >>> initial_noise_multiplier=args.initial_noise_multiplier) + >>> model = DPModel(micro_batches=2, + >>> norm_clip=1, + >>> dp_mech=gaussian_mech.mech, + >>> network=net, + >>> loss_fn=loss, + >>> optimizer=optim, + >>> metrics=None) + >>> dataset = get_dataset() + >>> model.train(2, dataset) + """ + def __init__(self, micro_batches=None, norm_clip=None, dp_mech=None, **kwargs): + if micro_batches: + self._micro_batches = int(micro_batches) + else: + self._micro_batches = None + self._norm_clip = norm_clip + self._dp_mech = dp_mech + super(DPModel, self).__init__(**kwargs) + + def amp_build_train_network(self, network, optimizer, loss_fn=None, level='O0', **kwargs): + """ + Build the mixed precision training cell automatically. + + Args: + network (Cell): Definition of the network. + loss_fn (Union[None, Cell]): Definition of the loss_fn. If None, the `network` should have the loss inside. + Default: None. + optimizer (Optimizer): Optimizer to update the Parameter. + level (str): Supports [O0, O2]. Default: "O0". + + - O0: Do not change. + - O2: Cast network to float16, keep batchnorm and `loss_fn` (if set) run in float32, + using dynamic loss scale. + + cast_model_type (:class:`mindspore.dtype`): Supports `mstype.float16` or `mstype.float32`. + If set to `mstype.float16`, use `float16` mode to train. If set, overwrite the level setting. + keep_batchnorm_fp32 (bool): Keep Batchnorm run in `float32`. If set, overwrite the level setting. + loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else + scale the loss by LossScaleManager. If set, overwrite the level setting. + """ + validator.check_value_type('network', network, nn.Cell, None) + validator.check_value_type('optimizer', optimizer, nn.Optimizer, None) + validator.check('level', level, "", ['O0', 'O2'], Rel.IN, None) + self._check_kwargs(kwargs) + config = dict(_config_level[level], **kwargs) + config = edict(config) + + if config.cast_model_type == mstype.float16: + network.to_float(mstype.float16) + + if config.keep_batchnorm_fp32: + _do_keep_batchnorm_fp32(network) + + if loss_fn: + network = _add_loss_network(network, loss_fn, config.cast_model_type) + + if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): + network = _VirtualDatasetCell(network) + + loss_scale = 1.0 + if config.loss_scale_manager is not None: + loss_scale_manager = config.loss_scale_manager + loss_scale = loss_scale_manager.get_loss_scale() + update_cell = loss_scale_manager.get_update_cell() + if update_cell is not None: + # only cpu not support `TrainOneStepWithLossScaleCell` for control flow. 
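+                # Clarifying note: a dynamic loss-scale update cell requires
+                # _TrainOneStepWithLossScaleCell, which is not supported on CPU, so only
+                # `loss_scale_manager=None` or `FixedLossScaleManager(drop_overflow_update=False)`
+                # can be used on that target.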
+ if not context.get_context("enable_ge") and context.get_context("device_target") == "CPU": + raise ValueError("Only `loss_scale_manager=None` and " + "`loss_scale_manager=FixedLossScaleManager(drop_overflow_update=False)`" + "are supported in current version. If you use `O2` option, please" + "use `loss_scale_manager=None` or `FixedLossScaleManager`") + network = _TrainOneStepWithLossScaleCell(network, + optimizer, + scale_update_cell=update_cell, + micro_batches=self._micro_batches, + l2_norm_clip=self._norm_clip, + mech=self._dp_mech).set_train() + return network + network = _TrainOneStepCell(network, + optimizer, + loss_scale, + micro_batches=self._micro_batches, + l2_norm_clip=self._norm_clip, + mech=self._dp_mech).set_train() + return network + + def _build_train_network(self): + """Build train network""" + network = self._network + if self._micro_batches: + if self._optimizer: + if self._loss_scale_manager_set: + network = self.amp_build_train_network(network, + self._optimizer, + self._loss_fn, + level=self._amp_level, + loss_scale_manager=self._loss_scale_manager, + keep_batchnorm_fp32=self._keep_bn_fp32) + else: + network = self.amp_build_train_network(network, + self._optimizer, + self._loss_fn, + level=self._amp_level, + keep_batchnorm_fp32=self._keep_bn_fp32) + elif self._loss_fn: + network = nn.WithLossCell(network, self._loss_fn) + else: + if self._optimizer: + if self._loss_scale_manager_set: + network = amp.build_train_network(network, + self._optimizer, + self._loss_fn, + level=self._amp_level, + loss_scale_manager=self._loss_scale_manager, + keep_batchnorm_fp32=self._keep_bn_fp32) + else: + network = amp.build_train_network(network, + self._optimizer, + self._loss_fn, + level=self._amp_level, + keep_batchnorm_fp32=self._keep_bn_fp32) + elif self._loss_fn: + network = nn.WithLossCell(network, self._loss_fn) + + if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): + network.set_auto_parallel() + return network + + +class _ClipGradients(nn.Cell): + """ + Clip gradients. + + Inputs: + grads (tuple[Tensor]): Gradients. + clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'. + clip_value (float): Specifies how much to clip. + + Outputs: + tuple[Tensor], clipped gradients. + """ + def __init__(self): + super(_ClipGradients, self).__init__() + self.clip_by_norm = nn.ClipByNorm() + self.dtype = P.DType() + + def construct(self, grads, clip_type, clip_value): + """ + construct a compute flow. + """ + if clip_type not in (0, 1): + return grads + + new_grads = () + for grad in grads: + if clip_type == 0: + t = C.clip_by_value(grad, F.tuple_to_array((-clip_value,)), + F.tuple_to_array((clip_value,))) + else: + t = self.clip_by_norm(grad, F.tuple_to_array((clip_value,))) + new_grads = new_grads + (t,) + + return new_grads + + +class _TrainOneStepWithLossScaleCell(Cell): + r""" + Network training with loss scaling. + + This is a training step with loss scaling. It takes a network, an optimizer and possibly a scale update + Cell as args. The loss scale value can be updated in both host side or device side. The + TrainOneStepWithLossScaleCell will be compiled to be graph which takes `data`, `label`, `sens` as input + data. The `sens` is acting as loss scaling value. If you want to update it on host side, the value should + be provided. If `sens` is not given, the loss scale update logic should be provied by `scale_update_cell`. + If `scale_update_cell` is not None and `sens` is provided, the `scale_update_cell` will be ignored. 
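+
+    In this differentially private variant, the input batch is additionally split into
+    `micro_batches` sub-batches; the gradients of every sub-batch are clipped by
+    `l2_norm_clip` and accumulated before the overflow check and the optimizer step.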
+ + Args: + network (Cell): The training network. + optimizer (Cell): Optimizer for updating the weights. + scale_update_cell(Cell): The loss scaling update logic cell. Default: None. + micro_batches (int): The number of small batches split from an origianl batch. Default: None. + l2_norm_clip (float): Use to clip the bound, if set 1, will retun the original data. Default: None. + mech (Mechanisms): The object can generate the different type of noise. Default: None. + + Inputs: + - **inputs** (Tensor) - Tensor of shape :math:`(N, \ldots)`. + - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`. + - **scaling_sens** (Tensor) - Tensor of shape :math:`()`. + + Outputs: + Tuple of 3 Tensor, the loss, overflow flag and current loss scaling value. + + - **loss** (Tensor) - Tensor with shape :math:`()`. + - **overflow** (Tensor) - Tensor with shape :math:`()`, type is bool. + - **loss_scale** (Tensor) - Tensor with shape :math:`()`. + + Examples: + >>> net_with_loss = Net() + >>> optimizer = nn.Momentum(net_with_loss.trainable_params(), learning_rate=0.1, momentum=0.9) + >>> manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**12, scale_factor=2, scale_window=1000) + >>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_update_cell=manager) + >>> train_network.set_train() + >>> + >>> inputs = Tensor(np.ones([16, 16]).astype(np.float32)) + >>> label = Tensor(np.zeros([16, 16]).astype(np.float32)) + >>> scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mindspore.float32) + >>> output = train_network(inputs, label, scaling_sens) + """ + + def __init__(self, network, optimizer, scale_update_cell=None, micro_batches=None, l2_norm_clip=None, mech=None): + super(_TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False) + self.network = network + self.network.add_flags(defer_inline=True) + self.weights = ParameterTuple(network.trainable_params()) + self.optimizer = optimizer + self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) + self.hyper_map = C.HyperMap() + if context.get_context("device_target") == "GPU": + self.gpu_target = True + self.float_status = P.FloatStatus() + self.addn = P.AddN() + self.reshape = P.Reshape() + else: + self.gpu_target = False + self.alloc_status = NPUAllocFloatStatus() + self.get_status = NPUGetFloatStatus() + self.clear_status = NPUClearFloatStatus() + self.reduce_sum = ReduceSum(keep_dims=False) + self.base = Tensor(1, mstype.float32) + self.less_equal = LessEqual() + self.depend_parameter_use = ControlDepend(depend_mode=1) + self.allreduce = P.AllReduce() + self.parallel_mode = _get_parallel_mode() + self.grad_reducer = F.identity + self.reducer_flag = self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL] + if self.reducer_flag: + mean = _get_mirror_mean() + degree = _get_device_num() + self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) + self.is_distributed = self.parallel_mode != ParallelMode.STAND_ALONE + + self.loss_scale = None + self.loss_scaling_manager = scale_update_cell + if scale_update_cell: + self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32), + name="loss_scale") + self.add_flags(has_effect=True) + + # dp params + self._micro_batches = micro_batches + self._l2_norm = l2_norm_clip + self._split = P.Split(0, self._micro_batches) + self._clip_by_global_norm = _ClipGradients() + self._mech = mech + + def construct(self, data, label, sens=None): + """ + construct a compute flow. 
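+
+        The flow is roughly: clear the overflow status buffer, split `data` and `label` into
+        micro-batches, compute and clip the gradients of every micro-batch, accumulate them,
+        rescale the sum by the reciprocal of the loss-scaling value, check for overflow and,
+        if no overflow occurred, apply the optimizer.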
+ """ + init = False + if not self.gpu_target: + # init overflow buffer + init = self.alloc_status() + # clear overflow buffer + self.clear_status(init) + + if sens is None: + scaling_sens = self.loss_scale + else: + scaling_sens = sens + + # DP clip + weights = self.weights + record_datas = self._split(data) + record_labels = self._split(label) + grads = () + # first index + loss = self.network(record_datas[0], record_labels[0]) + scaling_sens_filled = C.ones_like(loss)*F.cast(scaling_sens, F.dtype(loss)) + record_grad = self.grad(self.network, weights)(record_datas[0], record_labels[0], scaling_sens_filled) + record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm) + + grad_sum = list(record_grad) + grad_len = len(record_grad) + for i in range(grad_len): + grad_sum[i] = grad_sum[i].asnumpy() + + for i in range(1, self._micro_batches): + loss = self.network(record_datas[i], record_labels[i]) + scaling_sens_filled = C.ones_like(loss)*F.cast(scaling_sens, F.dtype(loss)) + record_grad = self.grad(self.network, weights)(record_datas[i], record_labels[i], scaling_sens_filled) + record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm) + for j in range(grad_len): + grad_sum[j] = grad_sum[j] + record_grad[j].asnumpy() + + for i in range(grad_len): + grad_sum[i] = Tensor(grad_sum[i], ms.float32) + grads = tuple(grad_sum) + loss = self.network(data, label) + + grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads) + # apply grad reducer on grads + grads = self.grad_reducer(grads) + # get the overflow buffer + if not self.gpu_target: + self.get_status(init) + # sum overflow buffer elements, 0:not overflow , >0:overflow + flag_sum = self.reduce_sum(init, (0,)) + else: + flag_sum = self.hyper_map(F.partial(_grad_overflow), grads) + flag_sum = self.addn(flag_sum) + # convert flag_sum to scalar + flag_sum = self.reshape(flag_sum, (())) + if self.is_distributed: + # sum overflow flag over devices + flag_reduce = self.allreduce(flag_sum) + cond = self.less_equal(self.base, flag_reduce) + else: + cond = self.less_equal(self.base, flag_sum) + overflow = cond + if sens is None: + overflow = self.loss_scaling_manager(self.loss_scale, cond) + # if there is no overflow, do optimize + if overflow: + opt = False + else: + opt = self.optimizer(grads) + ret = (loss, cond, scaling_sens) + return F.depend(ret, opt) + + +class _TrainOneStepCell(Cell): + r""" + Network training package class. + + Wraps the network with an optimizer. The resulting Cell be trained with input data and label. + Backward graph will be created in the construct function to do parameter updating. Different + parallel modes are available to run the training. + + Args: + network (Cell): The training network. + optimizer (Cell): Optimizer for updating the weights. + sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0. + micro_batches (int): The number of small batches split from an origianl batch. Default: None. + l2_norm_clip (float): Use to clip the bound, if set 1, will retun the original data. Default: None. + mech (Mechanisms): The object can generate the different type of noise. Default: None. + + Inputs: + - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`. + - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`. + + Outputs: + Tensor, a scalar Tensor with shape :math:`()`. 
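+
+    Note:
+        When `micro_batches` is set, the input batch is split along axis 0 into `micro_batches`
+        sub-batches; the gradients of every sub-batch are clipped by `l2_norm_clip` and
+        accumulated before being passed to the optimizer.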
+ + Examples: + >>> net = Net() + >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits() + >>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) + >>> loss_net = nn.WithLossCell(net, loss_fn) + >>> train_net = nn.TrainOneStepCell(loss_net, optim) + """ + + def __init__(self, network, optimizer, sens=1.0, micro_batches=None, l2_norm_clip=None, mech=None): + super(_TrainOneStepCell, self).__init__(auto_prefix=False) + self.network = network + self.network.add_flags(defer_inline=True) + self.weights = optimizer.parameters + self.optimizer = optimizer + self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) + self.sens = sens + self.reducer_flag = False + self.grad_reducer = None + parallel_mode = _get_parallel_mode() + if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL): + self.reducer_flag = True + if self.reducer_flag: + mean = _get_mirror_mean() + degree = _get_device_num() + self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) + + # dp params + self._micro_batches = micro_batches + self._l2_norm = l2_norm_clip + self._split = P.Split(0, self._micro_batches) + self._clip_by_global_norm = _ClipGradients() + self._mech = mech + + def construct(self, data, label): + """ + construct a compute flow. + """ + weights = self.weights + record_datas = self._split(data) + record_labels = self._split(label) + loss = self.network(record_datas[0], record_labels[0]) + sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) + record_grad = self.grad(self.network, weights)(record_datas[0], record_labels[0], sens) + record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm) + grad_sum = list(record_grad) + grad_len = len(record_grad) + for i in range(grad_len): + grad_sum[i] = grad_sum[i].asnumpy() + + for i in range(1, self._micro_batches): + loss = self.network(record_datas[i], record_labels[i]) + sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) + record_grad = self.grad(self.network, weights)(record_datas[i], record_labels[i], sens) + record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm) + for j in range(grad_len): + grad_sum[j] = grad_sum[j] + record_grad[j].asnumpy() + + for i in range(grad_len): + grad_sum[i] = Tensor(grad_sum[i], ms.float32) + grads = tuple(grad_sum) + loss = self.network(data, label) + + if self.reducer_flag: + # apply grad reducer on grads + grads = self.grad_reducer(grads) + return F.depend(loss, self.optimizer(grads)) diff --git a/tests/ut/python/diff_privacy/mechanisms/test_mechanisms.py b/tests/ut/python/diff_privacy/test_mechanisms.py similarity index 100% rename from tests/ut/python/diff_privacy/mechanisms/test_mechanisms.py rename to tests/ut/python/diff_privacy/test_mechanisms.py diff --git a/tests/ut/python/diff_privacy/test_model_train.py b/tests/ut/python/diff_privacy/test_model_train.py new file mode 100644 index 0000000..c20bb84 --- /dev/null +++ b/tests/ut/python/diff_privacy/test_model_train.py @@ -0,0 +1,65 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +""" +DP-Model test. +""" +import pytest +import numpy as np + +from mindspore import nn +from mindspore.nn import SGD +from mindspore.model_zoo.lenet import LeNet5 +from mindspore import context +import mindspore.dataset as ds + +from mindarmour.diff_privacy import DPOptimizerClassFactory +from mindarmour.diff_privacy import DPModel + + +def dataset_generator(batch_size, batches): + data = np.random.random((batches * batch_size, 1, 32, 32)).astype(np.float32) + label = np.random.randint(0, 10, batches * batch_size).astype(np.int32) + for i in range(batches): + yield data[i * batch_size:(i + 1) * batch_size], label[i * batch_size:(i + 1) * batch_size] + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_card +@pytest.mark.component_mindarmour +def test_dp_model(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") + l2_norm_bound = 1.0 + initial_noise_multiplier = 0.01 + net = LeNet5() + batch_size = 32 + batches = 128 + epochs = 1 + loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) + optim = SGD(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) + gaussian_mech = DPOptimizerClassFactory() + gaussian_mech.set_mechanisms('Gaussian', + norm_bound=l2_norm_bound, + initial_noise_multiplier=initial_noise_multiplier) + model = DPModel(micro_batches=2, + norm_clip=l2_norm_bound, + dp_mech=gaussian_mech.mech, + network=net, + loss_fn=loss, + optimizer=optim, + metrics=None) + ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches), ['data', 'label']) + ms_ds.set_dataset_size(batch_size * batches) + model.train(epochs, ms_ds) diff --git a/tests/ut/python/diff_privacy/test_monitor.py b/tests/ut/python/diff_privacy/test_monitor.py index 1b76328..530d0b8 100644 --- a/tests/ut/python/diff_privacy/test_monitor.py +++ b/tests/ut/python/diff_privacy/test_monitor.py @@ -23,7 +23,7 @@ from mindspore.train import Model import mindspore.context as context from mindspore.model_zoo.lenet import LeNet5 -from mindarmour.diff_privacy.monitor.monitor import PrivacyMonitorFactory +from mindarmour.diff_privacy import PrivacyMonitorFactory from mindarmour.utils.logger import LogUtil LOGGER = LogUtil.get_instance() diff --git a/tests/ut/python/diff_privacy/test_optimizer.py b/tests/ut/python/diff_privacy/test_optimizer.py new file mode 100644 index 0000000..978e916 --- /dev/null +++ b/tests/ut/python/diff_privacy/test_optimizer.py @@ -0,0 +1,76 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
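+"""
+DP-Optimizer test.
+"""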
+import pytest + +from mindspore import nn +from mindspore import context +from mindspore.model_zoo.lenet import LeNet5 +from mindspore.train.model import Model + +from mindarmour.diff_privacy import DPOptimizerClassFactory + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_card +@pytest.mark.component_mindarmour +def test_optimizer(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") + network = LeNet5() + lr = 0.01 + momentum = 0.9 + micro_batches = 2 + loss = nn.SoftmaxCrossEntropyWithLogits() + gaussian_mech = DPOptimizerClassFactory(micro_batches) + gaussian_mech.set_mechanisms('Gaussian', norm_bound=1.5, initial_noise_multiplier=5.0) + net_opt = gaussian_mech.create('SGD')(params=network.trainable_params(), learning_rate=lr, + momentum=momentum) + _ = Model(network, loss_fn=loss, optimizer=net_opt, metrics=None) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_inference +@pytest.mark.env_card +@pytest.mark.component_mindarmour +def test_optimizer_gpu(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + network = LeNet5() + lr = 0.01 + momentum = 0.9 + micro_batches = 2 + loss = nn.SoftmaxCrossEntropyWithLogits() + gaussian_mech = DPOptimizerClassFactory(micro_batches) + gaussian_mech.set_mechanisms('Gaussian', norm_bound=1.5, initial_noise_multiplier=5.0) + net_opt = gaussian_mech.create('SGD')(params=network.trainable_params(), learning_rate=lr, + momentum=momentum) + _ = Model(network, loss_fn=loss, optimizer=net_opt, metrics=None) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_card +@pytest.mark.component_mindarmour +def test_optimizer_cpu(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU") + network = LeNet5() + lr = 0.01 + momentum = 0.9 + micro_batches = 2 + loss = nn.SoftmaxCrossEntropyWithLogits() + gaussian_mech = DPOptimizerClassFactory(micro_batches) + gaussian_mech.set_mechanisms('Gaussian', norm_bound=1.5, initial_noise_multiplier=5.0) + net_opt = gaussian_mech.create('SGD')(params=network.trainable_params(), learning_rate=lr, + momentum=momentum) + _ = Model(network, loss_fn=loss, optimizer=net_opt, metrics=None)
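
A minimal end-to-end sketch of the API added in this patch (condensed from
example/mnist_demo/lenet5_dp_model_train.py and tests/ut/python/diff_privacy/test_model_train.py;
the random GeneratorDataset and LeNet5 are stand-ins for a real dataset and network):

import numpy as np

import mindspore.dataset as ds
import mindspore.nn as nn
from mindspore import context
from mindspore.model_zoo.lenet import LeNet5

from mindarmour.diff_privacy import DPModel, DPOptimizerClassFactory

context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")

micro_batches = 2
l2_norm_bound = 1.0
batch_size = 32

net = LeNet5()
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)

# The factory wraps an ordinary optimizer; the DP wrapper adds Gaussian noise to the gradients.
factory = DPOptimizerClassFactory(micro_batches=micro_batches)
factory.set_mechanisms('Gaussian', norm_bound=l2_norm_bound, initial_noise_multiplier=0.01)
net_opt = factory.create('Momentum')(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)

# DPModel splits every batch into micro_batches parts and clips each part's gradients.
model = DPModel(micro_batches=micro_batches,
                norm_clip=l2_norm_bound,
                dp_mech=factory.mech,
                network=net,
                loss_fn=loss,
                optimizer=net_opt,
                metrics=None)


def dataset_generator(batch_size, batches):
    """Yield random image/label batches shaped like MNIST (1x32x32, 10 classes)."""
    data = np.random.random((batches * batch_size, 1, 32, 32)).astype(np.float32)
    label = np.random.randint(0, 10, batches * batch_size).astype(np.int32)
    for i in range(batches):
        yield data[i * batch_size:(i + 1) * batch_size], label[i * batch_size:(i + 1) * batch_size]


ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, 4), ['data', 'label'])
ms_ds.set_dataset_size(batch_size * 4)
model.train(1, ms_ds)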