
!50 add exponential noise decay

Merge pull request !50 from zheng-huanhuan/master
tags/v0.6.0-beta
mindspore-ci-bot · 5 years ago
commit ac39d193bb
4 changed files with 127 additions and 70 deletions:

  1. mindarmour/diff_privacy/mechanisms/mechanisms.py (+17 -12)
  2. mindarmour/diff_privacy/optimizer/optimizer.py (+3 -2)
  3. mindarmour/diff_privacy/train/model.py (+59 -50)
  4. tests/ut/python/diff_privacy/test_mechanisms.py (+48 -6)

mindarmour/diff_privacy/mechanisms/mechanisms.py (+17 -12)

@@ -214,8 +214,8 @@ class AdaGaussianRandom(Mechanisms):
         noise_decay_rate = check_param_type('noise_decay_rate', noise_decay_rate, float)
         check_param_in_range('noise_decay_rate', noise_decay_rate, 0.0, 1.0)
         self._noise_decay_rate = Tensor(noise_decay_rate, mstype.float32)
-        if decay_policy not in ['Time', 'Step']:
-            raise NameError("The decay_policy must be in ['Time', 'Step'], but "
+        if decay_policy not in ['Time', 'Step', 'Exp']:
+            raise NameError("The decay_policy must be in ['Time', 'Step', 'Exp'], but "
                             "get {}".format(decay_policy))
         self._decay_policy = decay_policy
         self._mul = P.Mul()
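
For orientation, the widened check means 'Exp' is now a valid decay_policy at the API surface. A minimal sketch of creating such a mechanism, modeled on the factory call used in the tests below (argument values are illustrative):

    # 'Exp' now passes validation alongside 'Time' and 'Step'; any other
    # value still raises the NameError shown above.
    from mindarmour.diff_privacy import MechanismsFactory

    mech = MechanismsFactory().create('AdaGaussian',
                                      1.0,   # norm_bound
                                      0.1,   # initial_noise_multiplier
                                      noise_decay_rate=0.5,
                                      decay_policy='Exp')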
@@ -243,18 +243,18 @@ class _MechanismsParamsUpdater(Cell):
     Args:
         policy(str): Pass in by the mechanisms class, mechanisms parameters update policy.
         decay_rate(Tensor): Pass in by the mechanisms class, hyper parameter for controlling the decay size.
-        cur_params(Parameter): Pass in by the mechanisms class, current params value in this time.
-        init_params(Parameter): Pass in by the mechanisms class, initial params value to be updated.
+        cur_noise_multiplier(Parameter): Pass in by the mechanisms class, current params value in this time.
+        init_noise_multiplier(Parameter): Pass in by the mechanisms class, initial params value to be updated.

     Returns:
         Tuple, next params value.
     """
-    def __init__(self, policy, decay_rate, cur_params, init_params):
+    def __init__(self, policy, decay_rate, cur_noise_multiplier, init_noise_multiplier):
         super(_MechanismsParamsUpdater, self).__init__()
         self._policy = policy
         self._decay_rate = decay_rate
-        self._cur_params = cur_params
-        self._init_params = init_params
+        self._cur_noise_multiplier = cur_noise_multiplier
+        self._init_noise_multiplier = init_noise_multiplier

         self._div = P.Div()
         self._add = P.TensorAdd()
@@ -262,6 +262,7 @@ class _MechanismsParamsUpdater(Cell):
         self._sub = P.Sub()
         self._one = Tensor(1, mstype.float32)
         self._mul = P.Mul()
+        self._exp = P.Exp()

     def construct(self):
         """
@@ -271,10 +272,14 @@ class _MechanismsParamsUpdater(Cell):
             Tuple, next step parameters value.
         """
         if self._policy == 'Time':
-            temp = self._div(self._init_params, self._cur_params)
+            temp = self._div(self._init_noise_multiplier, self._cur_noise_multiplier)
             temp = self._add(temp, self._decay_rate)
-            next_params = self._assign(self._cur_params, self._div(self._init_params, temp))
-        else:
+            next_noise_multiplier = self._assign(self._cur_noise_multiplier,
+                                                 self._div(self._init_noise_multiplier, temp))
+        elif self._policy == 'Step':
             temp = self._sub(self._one, self._decay_rate)
-            next_params = self._assign(self._cur_params, self._mul(temp, self._cur_params))
-        return next_params
+            next_noise_multiplier = self._assign(self._cur_noise_multiplier,
+                                                 self._mul(temp, self._cur_noise_multiplier))
+        else:
+            next_noise_multiplier = self._assign(self._cur_noise_multiplier, self._div(self._one, self._exp(self._one)))
+        return next_noise_multiplier
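
For reference, the three update branches condensed into plain Python, one step per call (a readability sketch only, not MindSpore graph code; k stands for noise_decay_rate, sigma0 for the initial noise multiplier, sigma for the current one):

    import math

    def next_noise_multiplier(policy, k, sigma0, sigma):
        # Mirrors _MechanismsParamsUpdater.construct above.
        if policy == 'Time':
            # sigma' = sigma0 / (sigma0 / sigma + k)
            return sigma0 / (sigma0 / sigma + k)
        if policy == 'Step':
            # sigma' = (1 - k) * sigma
            return (1.0 - k) * sigma
        # 'Exp': as written in this hunk, the multiplier is assigned the
        # constant 1 / e, independent of its current value.
        return 1.0 / math.e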

mindarmour/diff_privacy/optimizer/optimizer.py (+3 -2)

@@ -130,8 +130,9 @@ class DPOptimizerClassFactory:
         if self._mech is not None and self._mech._decay_policy is not None:
             self._mech_param_updater = _MechanismsParamsUpdater(policy=self._mech._decay_policy,
                                                                 decay_rate=self._mech._noise_decay_rate,
-                                                                cur_params=self._mech._noise_multiplier,
-                                                                init_params=
+                                                                cur_noise_multiplier=
+                                                                self._mech._noise_multiplier,
+                                                                init_noise_multiplier=
                                                                 self._mech._initial_noise_multiplier)

     def construct(self, gradients):


mindarmour/diff_privacy/train/model.py (+59 -50)

@@ -195,48 +195,47 @@ class DPModel(Model):
                                                  mech=self._mech).set_train()
         return network

-
-    def _build_train_network(self):
-        """Build train network"""
-        network = self._network
-        if self._micro_batches:
-            if self._optimizer:
-                if self._loss_scale_manager_set:
-                    network = self._amp_build_train_network(network,
-                                                            self._optimizer,
-                                                            self._loss_fn,
-                                                            level=self._amp_level,
-                                                            loss_scale_manager=self._loss_scale_manager,
-                                                            keep_batchnorm_fp32=self._keep_bn_fp32)
-                else:
-                    network = self._amp_build_train_network(network,
-                                                            self._optimizer,
-                                                            self._loss_fn,
-                                                            level=self._amp_level,
-                                                            keep_batchnorm_fp32=self._keep_bn_fp32)
-            elif self._loss_fn:
-                network = nn.WithLossCell(network, self._loss_fn)
-        else:
-            if self._optimizer:
-                if self._loss_scale_manager_set:
-                    network = amp.build_train_network(network,
-                                                      self._optimizer,
-                                                      self._loss_fn,
-                                                      level=self._amp_level,
-                                                      loss_scale_manager=self._loss_scale_manager,
-                                                      keep_batchnorm_fp32=self._keep_bn_fp32)
-                else:
-                    network = amp.build_train_network(network,
-                                                      self._optimizer,
-                                                      self._loss_fn,
-                                                      level=self._amp_level,
-                                                      keep_batchnorm_fp32=self._keep_bn_fp32)
-            elif self._loss_fn:
-                network = nn.WithLossCell(network, self._loss_fn)
-
-        if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
-            network.set_auto_parallel()
-        return network
+    def _build_train_network(self):
+        """Build train network"""
+        network = self._network
+        if self._micro_batches:
+            if self._optimizer:
+                if self._loss_scale_manager_set:
+                    network = self._amp_build_train_network(network,
+                                                            self._optimizer,
+                                                            self._loss_fn,
+                                                            level=self._amp_level,
+                                                            loss_scale_manager=self._loss_scale_manager,
+                                                            keep_batchnorm_fp32=self._keep_bn_fp32)
+                else:
+                    network = self._amp_build_train_network(network,
+                                                            self._optimizer,
+                                                            self._loss_fn,
+                                                            level=self._amp_level,
+                                                            keep_batchnorm_fp32=self._keep_bn_fp32)
+            elif self._loss_fn:
+                network = nn.WithLossCell(network, self._loss_fn)
+        else:
+            if self._optimizer:
+                if self._loss_scale_manager_set:
+                    network = amp.build_train_network(network,
+                                                      self._optimizer,
+                                                      self._loss_fn,
+                                                      level=self._amp_level,
+                                                      loss_scale_manager=self._loss_scale_manager,
+                                                      keep_batchnorm_fp32=self._keep_bn_fp32)
+                else:
+                    network = amp.build_train_network(network,
+                                                      self._optimizer,
+                                                      self._loss_fn,
+                                                      level=self._amp_level,
+                                                      keep_batchnorm_fp32=self._keep_bn_fp32)
+            elif self._loss_fn:
+                network = nn.WithLossCell(network, self._loss_fn)
+
+        if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
+            network.set_auto_parallel()
+        return network


class _ClipGradients(nn.Cell):
@@ -376,8 +375,10 @@ class _TrainOneStepWithLossScaleCell(Cell):
         if self._mech is not None and self._mech._decay_policy is not None:
             self._mech_param_updater = _MechanismsParamsUpdater(policy=self._mech._decay_policy,
                                                                 decay_rate=self._mech._noise_decay_rate,
-                                                                cur_params=self._mech._noise_multiplier,
-                                                                init_params=self._mech._initial_noise_multiplier)
+                                                                cur_noise_multiplier=
+                                                                self._mech._noise_multiplier,
+                                                                init_noise_multiplier=
+                                                                self._mech._initial_noise_multiplier)

     def construct(self, data, label, sens=None):
         """
@@ -416,8 +417,11 @@ class _TrainOneStepWithLossScaleCell(Cell):
             loss = P.Div()(total_loss, self._micro_float)

         if self._mech is not None:
-            grad_noise = self._hyper_map(self._mech, grads)
-            grads = self._tuple_add(grads, grad_noise)
+            grad_noise_tuple = ()
+            for grad_item in grads:
+                grad_noise = self._mech(grad_item)
+                grad_noise_tuple = grad_noise_tuple + (grad_noise,)
+            grads = self._tuple_add(grads, grad_noise_tuple)
             grads = self._hyper_map(F.partial(_grad_scale, self._micro_float), grads)
             # update mech parameters
             if self._mech_param_updater is not None:
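
The same change reduced to a toy sketch outside the graph: noise is now generated per gradient in an explicit loop instead of one hyper_map call, then added element-wise (the mech callable and names are illustrative):

    def add_noise(grads, mech):
        # Accumulate one noise tensor per gradient, as in the loop above.
        grad_noise_tuple = ()
        for grad_item in grads:
            grad_noise_tuple = grad_noise_tuple + (mech(grad_item),)
        # Element-wise addition, playing the role of _tuple_add.
        return tuple(g + n for g, n in zip(grads, grad_noise_tuple))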
@@ -517,8 +521,10 @@ class _TrainOneStepCell(Cell):
         if self._mech is not None and self._mech._decay_policy is not None:
             self._mech_param_updater = _MechanismsParamsUpdater(policy=self._mech._decay_policy,
                                                                 decay_rate=self._mech._noise_decay_rate,
-                                                                cur_params=self._mech._noise_multiplier,
-                                                                init_params=self._mech._initial_noise_multiplier)
+                                                                cur_noise_multiplier=
+                                                                self._mech._noise_multiplier,
+                                                                init_noise_multiplier=
+                                                                self._mech._initial_noise_multiplier)

     def construct(self, data, label):
         """
@@ -543,8 +549,11 @@ class _TrainOneStepCell(Cell):
             loss = P.Div()(total_loss, self._micro_float)

         if self._mech is not None:
-            grad_noise = self._hyper_map(self._mech, grads)
-            grads = self._tuple_add(grads, grad_noise)
+            grad_noise_tuple = ()
+            for grad_item in grads:
+                grad_noise = self._mech(grad_item)
+                grad_noise_tuple = grad_noise_tuple + (grad_noise,)
+            grads = self._tuple_add(grads, grad_noise_tuple)
             grads = self._hyper_map(F.partial(_grad_scale, self._micro_float), grads)
             # update mech parameters
             if self._mech_param_updater is not None:


tests/ut/python/diff_privacy/test_mechanisms.py (+48 -6)

@@ -30,7 +30,7 @@ from mindarmour.diff_privacy import MechanismsFactory
 @pytest.mark.component_mindarmour
 def test_graph_gaussian():
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
-    grad = Tensor([3, 2, 4], mstype.float32)
+    grad = Tensor([0.3, 0.2, 0.4], mstype.float32)
     norm_bound = 1.0
     initial_noise_multiplier = 0.1
     net = GaussianRandom(norm_bound, initial_noise_multiplier)
@@ -44,7 +44,7 @@ def test_graph_gaussian():
 @pytest.mark.component_mindarmour
 def test_pynative_gaussian():
     context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
-    grad = Tensor([3, 2, 4], mstype.float32)
+    grad = Tensor([0.3, 0.2, 0.4], mstype.float32)
     norm_bound = 1.0
     initial_noise_multiplier = 0.1
     net = GaussianRandom(norm_bound, initial_noise_multiplier)
@@ -58,7 +58,7 @@ def test_pynative_gaussian():
 @pytest.mark.component_mindarmour
 def test_graph_ada_gaussian():
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
-    grad = Tensor([3, 2, 4], mstype.float32)
+    grad = Tensor([0.3, 0.2, 0.4], mstype.float32)
     norm_bound = 1.0
     initial_noise_multiplier = 0.1
     alpha = 0.5
@@ -75,7 +75,7 @@ def test_graph_ada_gaussian():
 @pytest.mark.component_mindarmour
 def test_graph_factory():
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
-    grad = Tensor([3, 2, 4], mstype.float32)
+    grad = Tensor([0.3, 0.2, 0.4], mstype.float32)
     norm_bound = 1.0
     initial_noise_multiplier = 0.1
     alpha = 0.5
@@ -102,7 +102,7 @@ def test_graph_factory():
 @pytest.mark.component_mindarmour
 def test_pynative_ada_gaussian():
     context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
-    grad = Tensor([3, 2, 4], mstype.float32)
+    grad = Tensor([0.3, 0.2, 0.4], mstype.float32)
     norm_bound = 1.0
     initial_noise_multiplier = 0.1
     alpha = 0.5
@@ -119,7 +119,7 @@ def test_pynative_ada_gaussian():
 @pytest.mark.component_mindarmour
 def test_pynative_factory():
     context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
-    grad = Tensor([3, 2, 4], mstype.float32)
+    grad = Tensor([0.3, 0.2, 0.4], mstype.float32)
     norm_bound = 1.0
     initial_noise_multiplier = 0.1
     alpha = 0.5
@@ -138,3 +138,45 @@ def test_pynative_factory():
                                                decay_policy=decay_policy)
     ada_noise = ada_noise_construct(grad)
     print('ada noise: ', ada_noise)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+@pytest.mark.component_mindarmour
+def test_pynative_exponential():
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+    grad = Tensor([0.3, 0.2, 0.4], mstype.float32)
+    norm_bound = 1.0
+    initial_noise_multiplier = 0.1
+    alpha = 0.5
+    decay_policy = 'Exp'
+    ada_mechanism = MechanismsFactory()
+    ada_noise_construct = ada_mechanism.create('AdaGaussian',
+                                               norm_bound,
+                                               initial_noise_multiplier,
+                                               noise_decay_rate=alpha,
+                                               decay_policy=decay_policy)
+    ada_noise = ada_noise_construct(grad)
+    print('ada noise: ', ada_noise)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+@pytest.mark.component_mindarmour
+def test_graph_exponential():
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    grad = Tensor([0.3, 0.2, 0.4], mstype.float32)
+    norm_bound = 1.0
+    initial_noise_multiplier = 0.1
+    alpha = 0.5
+    decay_policy = 'Exp'
+    ada_mechanism = MechanismsFactory()
+    ada_noise_construct = ada_mechanism.create('AdaGaussian',
+                                               norm_bound,
+                                               initial_noise_multiplier,
+                                               noise_decay_rate=alpha,
+                                               decay_policy=decay_policy)
+    ada_noise = ada_noise_construct(grad)
+    print('ada noise: ', ada_noise)
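
As a quick plausibility check on the rescaled test gradients: assuming the Gaussian mechanism draws zero-mean noise with standard deviation norm_bound * initial_noise_multiplier (as the parameter names suggest), the noise is on the order of +/-0.1, comparable in scale to the new [0.3, 0.2, 0.4] gradients:

    import numpy as np

    norm_bound = 1.0
    initial_noise_multiplier = 0.1
    # Illustrative NumPy stand-in for the mechanism's noise draw.
    noise = np.random.normal(0.0, norm_bound * initial_noise_multiplier, size=3)
    print(noise)  # values roughly within +/-0.3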
