Browse Source

!165 Fix some issues for Suppress Privacy - modified on 2021.2.3

From: @itcomee
Reviewed-by: 
Signed-off-by:
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
feb920c1b8
5 changed files with 38 additions and 37 deletions
  1. +4
    -8
      examples/privacy/sup_privacy/sup_privacy.py
  2. +1
    -1
      examples/privacy/sup_privacy/sup_privacy_config.py
  3. +28
    -23
      mindarmour/privacy/sup_privacy/sup_ctrl/conctrl.py
  4. +1
    -1
      mindarmour/privacy/sup_privacy/train/model.py
  5. +4
    -4
      tests/ut/python/privacy/sup_privacy/test_model_train.py

+ 4
- 8
examples/privacy/sup_privacy/sup_privacy.py View File

@@ -21,7 +21,6 @@ from mindspore.train.callback import ModelCheckpoint
from mindspore.train.callback import CheckpointConfig from mindspore.train.callback import CheckpointConfig
from mindspore.train.callback import LossMonitor from mindspore.train.callback import LossMonitor
from mindspore.nn.metrics import Accuracy from mindspore.nn.metrics import Accuracy
from mindspore.train.serialization import load_checkpoint, load_param_into_net
import mindspore.dataset as ds import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as CV import mindspore.dataset.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C import mindspore.dataset.transforms.c_transforms as C
@@ -91,16 +90,16 @@ def mnist_suppress_train(epoch_size=10, start_epoch=3, lr=0.05, samples=10000, m
""" """


networks_l5 = LeNet5() networks_l5 = LeNet5()
suppress_ctrl_instance = SuppressPrivacyFactory().create(policy="local_train",
suppress_ctrl_instance = SuppressPrivacyFactory().create(networks_l5,
masklayers,
policy="local_train",
end_epoch=epoch_size, end_epoch=epoch_size,
batch_num=(int)(samples/cfg.batch_size), batch_num=(int)(samples/cfg.batch_size),
start_epoch=start_epoch, start_epoch=start_epoch,
mask_times=mask_times, mask_times=mask_times,
networks=networks_l5,
lr=lr, lr=lr,
sparse_end=sparse_thd, sparse_end=sparse_thd,
sparse_start=sparse_start,
mask_layers=masklayers)
sparse_start=sparse_start)
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
net_opt = nn.SGD(networks_l5.trainable_params(), lr) net_opt = nn.SGD(networks_l5.trainable_params(), lr)
config_ck = CheckpointConfig(save_checkpoint_steps=(int)(samples/cfg.batch_size), config_ck = CheckpointConfig(save_checkpoint_steps=(int)(samples/cfg.batch_size),
@@ -130,9 +129,6 @@ def mnist_suppress_train(epoch_size=10, start_epoch=3, lr=0.05, samples=10000, m
dataset_sink_mode=False) dataset_sink_mode=False)


print("============== Starting SUPP Testing ==============") print("============== Starting SUPP Testing ==============")
ckpt_file_name = 'trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
param_dict = load_checkpoint(ckpt_file_name)
load_param_into_net(networks_l5, param_dict)
ds_eval = generate_mnist_dataset(os.path.join(mnist_path, 'test'), ds_eval = generate_mnist_dataset(os.path.join(mnist_path, 'test'),
batch_size=cfg.batch_size) batch_size=cfg.batch_size)
acc = model_instance.eval(ds_eval, dataset_sink_mode=False) acc = model_instance.eval(ds_eval, dataset_sink_mode=False)


+ 1
- 1
examples/privacy/sup_privacy/sup_privacy_config.py View File

@@ -20,7 +20,7 @@ from easydict import EasyDict as edict


mnist_cfg = edict({ mnist_cfg = edict({
'num_classes': 10, # the number of classes of model's output 'num_classes': 10, # the number of classes of model's output
'epoch_size': 1, # training epochs
'epoch_size': 10, # training epochs
'batch_size': 32, # batch size for training 'batch_size': 32, # batch size for training
'image_height': 32, # the height of training samples 'image_height': 32, # the height of training samples
'image_width': 32, # the width of training samples 'image_width': 32, # the width of training samples


+ 28
- 23
mindarmour/privacy/sup_privacy/sup_ctrl/conctrl.py View File

@@ -35,20 +35,20 @@ class SuppressPrivacyFactory:
pass pass


@staticmethod @staticmethod
def create(policy="local_train", end_epoch=10, batch_num=2, start_epoch=3, mask_times=100, networks=None,
lr=0.05, sparse_end=0.60, sparse_start=0.0, mask_layers=None):
def create(networks, mask_layers, policy="local_train", end_epoch=10, batch_num=20, start_epoch=3,
mask_times=500, lr=0.10, sparse_end=0.90, sparse_start=0.0):
""" """
Args: Args:
policy (str): Training policy for suppress privacy training. "local_train" means local training.
end_epoch (int): The last epoch in suppress operations, 0 < start_epoch <= end_epoch <= 100 .
batch_num (int): The num of batch in an epoch, should be equal to num_samples/batch_size .
start_epoch (int): The first epoch in suppress operations, 0 < start_epoch <= end_epoch <= 100 .
mask_times (int): The num of suppress operations.
networks (Cell): The training network. networks (Cell): The training network.
lr (float): Learning rate.
sparse_end (float): The sparsity to reach, 0.0 <= sparse_start < sparse_end < 1.0 .
sparse_start (float): The sparsity to start, 0.0 <= sparse_start < sparse_end < 1.0 .
mask_layers (list): Description of the training network layers that need to be suppressed. mask_layers (list): Description of the training network layers that need to be suppressed.
policy (str): Training policy for suppress privacy training. Default: "local_train", means local training.
end_epoch (int): The last epoch in suppress operations, 0<start_epoch<=end_epoch<=100. Default: 10.
batch_num (int): The num of batch in an epoch, should be equal to num_samples/batch_size. Default: 20.
start_epoch (int): The first epoch in suppress operations, 0<start_epoch<=end_epoch<=100. Default: 3.
mask_times (int): The num of suppress operations. Default: 500.
lr (Union[float, int]): Learning rate, 0 < lr <= 0.5. Default: 0.10.
sparse_end (float): The sparsity to reach, 0.0<=sparse_start<sparse_end<1.0. Default: 0.90.
sparse_start (float): The sparsity to start, 0.0<=sparse_start<sparse_end<1.0. Default: 0.0.


Returns: Returns:
SuppressCtrl, class of Suppress Privavy Mechanism. SuppressCtrl, class of Suppress Privavy Mechanism.
@@ -84,8 +84,8 @@ class SuppressPrivacyFactory:
dataset_sink_mode=False) dataset_sink_mode=False)
""" """
if policy == "local_train": if policy == "local_train":
return SuppressCtrl(networks, end_epoch, batch_num, start_epoch, mask_times, lr, sparse_end,
sparse_start, mask_layers)
return SuppressCtrl(networks, mask_layers, end_epoch, batch_num, start_epoch, mask_times, lr,
sparse_end, sparse_start)
msg = "Only local training is supported now, federal training will be supported " \ msg = "Only local training is supported now, federal training will be supported " \
"in the future. But got {}.".format(policy) "in the future. But got {}.".format(policy)
LOGGER.error(TAG, msg) LOGGER.error(TAG, msg)
@@ -95,6 +95,7 @@ class SuppressCtrl(Cell):
""" """
Args: Args:
networks (Cell): The training network. networks (Cell): The training network.
mask_layers (list): Description of those layers that need to be suppressed.
end_epoch (int): The last epoch in suppress operations. end_epoch (int): The last epoch in suppress operations.
batch_num (int): The num of grad operation in an epoch. batch_num (int): The num of grad operation in an epoch.
mask_start_epoch (int): The first epoch in suppress operations. mask_start_epoch (int): The first epoch in suppress operations.
@@ -102,14 +103,12 @@ class SuppressCtrl(Cell):
lr (Union[float, int]): Learning rate. lr (Union[float, int]): Learning rate.
sparse_end (Union[float, int]): The sparsity to reach. sparse_end (Union[float, int]): The sparsity to reach.
sparse_start (float): The sparsity to start. sparse_start (float): The sparsity to start.
mask_layers (list): Description of those layers that need to be suppressed.
""" """
def __init__(self, networks, end_epoch, batch_num, mask_start_epoch=3, mask_times=500, lr=0.05,
sparse_end=0.60,
sparse_start=0.0,
mask_layers=None):
def __init__(self, networks, mask_layers, end_epoch, batch_num, mask_start_epoch, mask_times, lr,
sparse_end, sparse_start):
super(SuppressCtrl, self).__init__() super(SuppressCtrl, self).__init__()
self.networks = check_param_type('networks', networks, Cell) self.networks = check_param_type('networks', networks, Cell)
self.mask_layers = check_param_type('mask_layers', mask_layers, list)
self.mask_end_epoch = check_int_positive('end_epoch', end_epoch) self.mask_end_epoch = check_int_positive('end_epoch', end_epoch)
self.batch_num = check_int_positive('batch_num', batch_num) self.batch_num = check_int_positive('batch_num', batch_num)
self.mask_start_epoch = check_int_positive('mask_start_epoch', mask_start_epoch) self.mask_start_epoch = check_int_positive('mask_start_epoch', mask_start_epoch)
@@ -117,7 +116,6 @@ class SuppressCtrl(Cell):
self.lr = check_value_positive('lr', lr) self.lr = check_value_positive('lr', lr)
self.sparse_end = check_value_non_negative('sparse_end', sparse_end) self.sparse_end = check_value_non_negative('sparse_end', sparse_end)
self.sparse_start = check_value_non_negative('sparse_start', sparse_start) self.sparse_start = check_value_non_negative('sparse_start', sparse_start)
self.mask_layers = check_param_type('mask_layers', mask_layers, list)


self.weight_lower_bound = 0.005 # all network weight will be larger than this value self.weight_lower_bound = 0.005 # all network weight will be larger than this value
self.sparse_vibra = 0.02 # the sparsity may have certain range of variations self.sparse_vibra = 0.02 # the sparsity may have certain range of variations
@@ -137,13 +135,19 @@ class SuppressCtrl(Cell):
self.mask_step_interval = self.mask_all_steps/mask_times # the amount of step contaied in one suppress operation self.mask_step_interval = self.mask_all_steps/mask_times # the amount of step contaied in one suppress operation
self.mask_initialized = False # flag means the initialization is done self.mask_initialized = False # flag means the initialization is done


if self.lr > 0.5:
msg = "learning rate should be smaller than 0.5, but got {}".format(self.lr)
LOGGER.error(TAG, msg)
raise ValueError(msg)

if self.mask_start_epoch > self.mask_end_epoch: if self.mask_start_epoch > self.mask_end_epoch:
msg = "start_epoch error: {}".format(self.mask_start_epoch)
msg = "start_epoch should not be greater than end_epoch, but got start_epoch and end_epoch are: " \
"{}, {}".format(self.mask_start_epoch, self.mask_end_epoch)
LOGGER.error(TAG, msg) LOGGER.error(TAG, msg)
raise ValueError(msg) raise ValueError(msg)


if self.mask_end_epoch > 100: if self.mask_end_epoch > 100:
msg = "end_epoch error: {}".format(self.mask_end_epoch)
msg = "The end_epoch should be smaller than 100, but got {}".format(self.mask_end_epoch)
LOGGER.error(TAG, msg) LOGGER.error(TAG, msg)
raise ValueError(msg) raise ValueError(msg)


@@ -152,13 +156,14 @@ class SuppressCtrl(Cell):
LOGGER.error(TAG, msg) LOGGER.error(TAG, msg)
raise ValueError(msg) raise ValueError(msg)


if self.sparse_end > 1.00 or self.sparse_end <= 0:
msg = "sparse_end error: {}".format(self.sparse_end)
if self.sparse_end >= 1.00 or self.sparse_end <= 0:
msg = "sparse_end should be in range (0, 1), but got {}".format(self.sparse_end)
LOGGER.error(TAG, msg) LOGGER.error(TAG, msg)
raise ValueError(msg) raise ValueError(msg)


if self.sparse_start >= self.sparse_end: if self.sparse_start >= self.sparse_end:
msg = "sparse_start error: {}".format(self.sparse_start)
msg = "sparse_start should be smaller than sparse_end, but got sparse_start and sparse_end are: " \
"{}, {}".format(self.sparse_start, self.sparse_end)
LOGGER.error(TAG, msg) LOGGER.error(TAG, msg)
raise ValueError(msg) raise ValueError(msg)




+ 1
- 1
mindarmour/privacy/sup_privacy/train/model.py View File

@@ -96,7 +96,7 @@ class SuppressModel(Model):
""" """


def __init__(self, def __init__(self,
network=None,
network,
**kwargs): **kwargs):


check_param_type('networks', network, Cell) check_param_type('networks', network, Cell)


+ 4
- 4
tests/ut/python/privacy/sup_privacy/test_model_train.py View File

@@ -56,16 +56,16 @@ def test_suppress_model_with_pynative_mode():
lr = 0.01 lr = 0.01
masklayers_lenet5 = [] masklayers_lenet5 = []
masklayers_lenet5.append(MaskLayerDes("conv1.weight", False, False, -1)) masklayers_lenet5.append(MaskLayerDes("conv1.weight", False, False, -1))
suppress_ctrl_instance = SuppressPrivacyFactory().create(policy="local_train",
suppress_ctrl_instance = SuppressPrivacyFactory().create(networks_l5,
masklayers_lenet5,
policy="local_train",
end_epoch=epochs, end_epoch=epochs,
batch_num=batch_num, batch_num=batch_num,
start_epoch=1, start_epoch=1,
mask_times=mask_times, mask_times=mask_times,
networks=networks_l5,
lr=lr, lr=lr,
sparse_end=0.50, sparse_end=0.50,
sparse_start=0.0,
mask_layers=masklayers_lenet5)
sparse_start=0.0)
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
net_opt = nn.SGD(networks_l5.trainable_params(), lr) net_opt = nn.SGD(networks_l5.trainable_params(), lr)
model_instance = SuppressModel( model_instance = SuppressModel(


Loading…
Cancel
Save