Browse Source

suppress based privacy model, 2021.2.9

tags/v1.2.0-rc1
itcomee 4 years ago
parent
commit
d0817b351a
4 changed files with 46 additions and 39 deletions
  1. +5
    -5
      examples/privacy/sup_privacy/sup_privacy.py
  2. +37
    -25
      mindarmour/privacy/sup_privacy/sup_ctrl/conctrl.py
  3. +3
    -8
      mindarmour/privacy/sup_privacy/train/model.py
  4. +1
    -1
      tests/ut/python/privacy/sup_privacy/test_model_train.py

+ 5
- 5
examples/privacy/sup_privacy/sup_privacy.py View File

@@ -141,11 +141,11 @@ if __name__ == "__main__":


masklayers_lenet5 = [] # determine which layer should be masked masklayers_lenet5 = [] # determine which layer should be masked


masklayers_lenet5.append(MaskLayerDes("conv1.weight", False, True, 10))
masklayers_lenet5.append(MaskLayerDes("conv2.weight", False, True, 150))
masklayers_lenet5.append(MaskLayerDes("fc1.weight", True, False, -1))
masklayers_lenet5.append(MaskLayerDes("fc2.weight", True, False, -1))
masklayers_lenet5.append(MaskLayerDes("fc3.weight", True, False, 50))
masklayers_lenet5.append(MaskLayerDes("conv1.weight", 0, False, True, 10))
masklayers_lenet5.append(MaskLayerDes("conv2.weight", 1, False, True, 150))
masklayers_lenet5.append(MaskLayerDes("fc1.weight", 2, True, False, -1))
masklayers_lenet5.append(MaskLayerDes("fc2.weight", 4, True, False, -1))
masklayers_lenet5.append(MaskLayerDes("fc3.weight", 6, True, False, 50))


# do suppreess privacy train, with stronger privacy protection and better performance than Differential Privacy # do suppreess privacy train, with stronger privacy protection and better performance than Differential Privacy
mnist_suppress_train(10, 3, 0.10, 60000, 1000, 0.95, 0.0, masklayers=masklayers_lenet5) # used mnist_suppress_train(10, 3, 0.10, 60000, 1000, 0.95, 0.0, masklayers=masklayers_lenet5) # used

+ 37
- 25
mindarmour/privacy/sup_privacy/sup_ctrl/conctrl.py View File

@@ -184,16 +184,32 @@ class SuppressCtrl(Cell):
layer_name = one_mask_layer.layer_name layer_name = one_mask_layer.layer_name
mask_layer_id2 = 0 mask_layer_id2 = 0
for one_mask_layer_2 in mask_layers: for one_mask_layer_2 in mask_layers:
if mask_layer_id != mask_layer_id2 and layer_name in one_mask_layer_2.layer_name:
msg = "mask_layers repeat item : {} in {} and {}".format(layer_name,
mask_layer_id,
mask_layer_id2)
if mask_layer_id != mask_layer_id2 and layer_name == one_mask_layer_2.layer_name:
msg = "Mask layer name should be unique, but got duplicate name: {} in mask_layer {} and {}".\
format(layer_name, mask_layer_id, mask_layer_id2)
LOGGER.error(TAG, msg)
raise ValueError(msg)
if mask_layer_id != mask_layer_id2 and one_mask_layer.grad_idx == one_mask_layer_2.grad_idx:
msg = "Grad_idx should be unique, but got duplicate idx: {} in mask_layer {} and {}".\
format(layer_name, one_mask_layer_2.layer_name, one_mask_layer.grad_idx)
LOGGER.error(TAG, msg) LOGGER.error(TAG, msg)
raise ValueError(msg) raise ValueError(msg)
mask_layer_id2 = mask_layer_id2 + 1 mask_layer_id2 = mask_layer_id2 + 1
mask_layer_id = mask_layer_id + 1 mask_layer_id = mask_layer_id + 1


if networks is not None: if networks is not None:
for layer in networks.get_parameters(expand=True):
shape = np.shape([1])
mul_mask_array = np.ones(shape, dtype=np.float32)
grad_mask_cell = GradMaskInCell(mul_mask_array, False, False, -1)
grad_mask_cell.mask_able = False
self.grads_mask_list.append(grad_mask_cell)

add_mask_array = np.zeros(shape, dtype=np.float32)
de_weight_cell = DeWeightInCell(add_mask_array)
de_weight_cell.mask_able = False
self.de_weight_mask_list.append(de_weight_cell)

m = 0 m = 0
for layer in networks.get_parameters(expand=True): for layer in networks.get_parameters(expand=True):
one_mask_layer = None one_mask_layer = None
@@ -209,29 +225,18 @@ class SuppressCtrl(Cell):
one_mask_layer.min_num, one_mask_layer.min_num,
one_mask_layer.upper_bound) one_mask_layer.upper_bound)
grad_mask_cell.mask_able = True grad_mask_cell.mask_able = True
self.grads_mask_list.append(grad_mask_cell)
add_mask_array = np.zeros(shape, dtype=np.float32)
self.grads_mask_list[one_mask_layer.grad_idx] = grad_mask_cell


add_mask_array = np.zeros(shape, dtype=np.float32)
de_weight_cell = DeWeightInCell(add_mask_array) de_weight_cell = DeWeightInCell(add_mask_array)
de_weight_cell.mask_able = True de_weight_cell.mask_able = True
self.de_weight_mask_list.append(de_weight_cell)
msg = "do mask {}, {}".format(m, one_mask_layer.layer_name)
self.de_weight_mask_list[one_mask_layer.grad_idx] = de_weight_cell
msg = "do mask {}, {}, {}".format(m, one_mask_layer.layer_name, one_mask_layer.grad_idx)
LOGGER.info(TAG, msg) LOGGER.info(TAG, msg)
elif one_mask_layer is not None and one_mask_layer.inited: elif one_mask_layer is not None and one_mask_layer.inited:
msg = "repeated match masked setting {}=>{}.".format(one_mask_layer.layer_name, layer.name) msg = "repeated match masked setting {}=>{}.".format(one_mask_layer.layer_name, layer.name)
LOGGER.error(TAG, msg) LOGGER.error(TAG, msg)
raise ValueError(msg) raise ValueError(msg)
else:
shape = np.shape([1])
mul_mask_array = np.ones(shape, dtype=np.float32)
grad_mask_cell = GradMaskInCell(mul_mask_array, False, False, -1)
grad_mask_cell.mask_able = False

self.grads_mask_list.append(grad_mask_cell)
add_mask_array = np.zeros(shape, dtype=np.float32)
de_weight_cell = DeWeightInCell(add_mask_array)
de_weight_cell.mask_able = False
self.de_weight_mask_list.append(de_weight_cell)
m += 1 m += 1
self.mask_initialized = True self.mask_initialized = True
msg = "init SuppressCtrl by networks" msg = "init SuppressCtrl by networks"
@@ -555,7 +560,7 @@ def get_one_mask_layer(mask_layers, layer_name):
Union[MaskLayerDes, None], the layer definitions that need to be suppressed. Union[MaskLayerDes, None], the layer definitions that need to be suppressed.
""" """
for each_mask_layer in mask_layers: for each_mask_layer in mask_layers:
if each_mask_layer.layer_name in layer_name:
if each_mask_layer.layer_name in layer_name and not each_mask_layer.inited:
return each_mask_layer return each_mask_layer
return None return None


@@ -567,15 +572,21 @@ class MaskLayerDes:
layer_name (str): Layer name, get the name of one layer as following: layer_name (str): Layer name, get the name of one layer as following:
for layer in networks.get_parameters(expand=True): for layer in networks.get_parameters(expand=True):
if layer.name == "conv": ... if layer.name == "conv": ...
grad_idx (int): Grad layer index, get mask layer's index in grad tuple.You can refer to the construct function
of TrainOneStepCell in mindarmour/privacy/sup_privacy/train/model.py to get the index of some specified
grad layers (print in PYNATIVE_MODE).
is_add_noise (bool): If True, the weight of this layer can add noise. is_add_noise (bool): If True, the weight of this layer can add noise.
If False, the weight of this layer can not add noise. If False, the weight of this layer can not add noise.
is_lower_clip (bool): If true, the weights of this layer would be clipped to greater than an lower bound value. is_lower_clip (bool): If true, the weights of this layer would be clipped to greater than an lower bound value.
If False, the weights of this layer won't be clipped. If False, the weights of this layer won't be clipped.
min_num (int): The number of weights left that not be suppressed, which need to be greater than 0.
upper_bound (float): max value of weight in this layer, default value is 1.20 .
min_num (int): The number of weights left that not be suppressed.
If min_num is smaller than (parameter num*SupperssCtrl.sparse_end), min_num has not effect.
upper_bound (Union[float, int]): max abs value of weight in this layer, default: 1.20.
""" """
def __init__(self, layer_name, is_add_noise, is_lower_clip, min_num, upper_bound=1.20):
def __init__(self, layer_name, grad_idx, is_add_noise, is_lower_clip, min_num, upper_bound=1.20):
self.layer_name = check_param_type('layer_name', layer_name, str) self.layer_name = check_param_type('layer_name', layer_name, str)
check_param_type('grad_idx', grad_idx, int)
self.grad_idx = check_value_non_negative('grad_idx', grad_idx)
self.is_add_noise = check_param_type('is_add_noise', is_add_noise, bool) self.is_add_noise = check_param_type('is_add_noise', is_add_noise, bool)
self.is_lower_clip = check_param_type('is_lower_clip', is_lower_clip, bool) self.is_lower_clip = check_param_type('is_lower_clip', is_lower_clip, bool)
self.min_num = check_param_type('min_num', min_num, int) self.min_num = check_param_type('min_num', min_num, int)
@@ -592,8 +603,9 @@ class GradMaskInCell(Cell):
If False, the weight of this layer can not add noise. If False, the weight of this layer can not add noise.
is_lower_clip (bool): If true, the weights of this layer would be clipped to greater than an lower bound value. is_lower_clip (bool): If true, the weights of this layer would be clipped to greater than an lower bound value.
If False, the weights of this layer won't be clipped. If False, the weights of this layer won't be clipped.
min_num (int): The number of weights left that not be suppressed, which need to be greater than 0.
upper_bound (float): max value of weight in this layer, default value is 1.20
min_num (int): The number of weights left that not be suppressed.
If min_num is smaller than (parameter num*SupperssCtrl.sparse_end), min_num has not effect.
upper_bound ([float, int]): max abs value of weight in this layer, default: 1.20.
""" """
def __init__(self, array, is_add_noise, is_lower_clip, min_num, upper_bound=1.20): def __init__(self, array, is_add_noise, is_lower_clip, min_num, upper_bound=1.20):
super(GradMaskInCell, self).__init__() super(GradMaskInCell, self).__init__()


+ 3
- 8
mindarmour/privacy/sup_privacy/train/model.py View File

@@ -113,11 +113,7 @@ class SuppressModel(Model):
Args: Args:
suppress_pri_ctrl (SuppressCtrl): SuppressCtrl instance. suppress_pri_ctrl (SuppressCtrl): SuppressCtrl instance.
""" """
check_param_type('suppress_pri_ctrl', suppress_pri_ctrl, Cell)
if not isinstance(suppress_pri_ctrl, SuppressCtrl):
msg = "SuppressCtrl instance error!"
LOGGER.error(TAG, msg)
raise ValueError(msg)
check_param_type('suppress_pri_ctrl', suppress_pri_ctrl, SuppressCtrl)


suppress_pri_ctrl.model = self suppress_pri_ctrl.model = self
if self._train_one_step is not None: if self._train_one_step is not None:
@@ -214,7 +210,6 @@ class _TupleAdd(nn.Cell):
out = self.hyper_map(self.add, input1, input2) out = self.hyper_map(self.add, input1, input2)
return out return out



class _TupleMul(nn.Cell): class _TupleMul(nn.Cell):
""" """
Mul two tuple of data. Mul two tuple of data.
@@ -258,7 +253,7 @@ class TrainOneStepCell(Cell):
self.network.add_flags(defer_inline=True) self.network.add_flags(defer_inline=True)
self.weights = optimizer.parameters self.weights = optimizer.parameters
self.optimizer = optimizer self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True, sens_param=True) # for mindspore 0.7x
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.sens = sens self.sens = sens
self.reducer_flag = False self.reducer_flag = False
self.grad_reducer = None self.grad_reducer = None
@@ -268,7 +263,7 @@ class TrainOneStepCell(Cell):
if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL): if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
self.reducer_flag = True self.reducer_flag = True
if self.reducer_flag: if self.reducer_flag:
mean = _get_gradients_mean() # for mindspore 0.7x
mean = _get_gradients_mean()
degree = _get_device_num() degree = _get_device_num()
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)




+ 1
- 1
tests/ut/python/privacy/sup_privacy/test_model_train.py View File

@@ -57,7 +57,7 @@ def test_suppress_model_with_pynative_mode():
mask_times = 10 mask_times = 10
lr = 0.01 lr = 0.01
masklayers_lenet5 = [] masklayers_lenet5 = []
masklayers_lenet5.append(MaskLayerDes("conv1.weight", False, False, -1))
masklayers_lenet5.append(MaskLayerDes("conv1.weight", 0, False, False, -1))
suppress_ctrl_instance = SuppressPrivacyFactory().create(networks_l5, suppress_ctrl_instance = SuppressPrivacyFactory().create(networks_l5,
masklayers_lenet5, masklayers_lenet5,
policy="local_train", policy="local_train",


Loading…
Cancel
Save