
suppress-based privacy model, 2021.3.16

tags/v1.2.0-rc1
itcomee committed 4 years ago
commit b783b9702e
3 changed files with 199 additions and 46 deletions
  1. examples/privacy/sup_privacy/sup_privacy.py (+1, -1)
  2. mindarmour/privacy/sup_privacy/sup_ctrl/conctrl.py (+193, -44)
  3. mindarmour/privacy/sup_privacy/train/model.py (+5, -1)

examples/privacy/sup_privacy/sup_privacy.py (+1, -1)

@@ -148,4 +148,4 @@ if __name__ == "__main__":
     masklayers_lenet5.append(MaskLayerDes("fc3.weight", 6, True, False, 50))

     # do suppress privacy train, with stronger privacy protection and better performance than Differential Privacy
-    mnist_suppress_train(10, 3, 0.10, 60000, 1000, 0.95, 0.0, masklayers=masklayers_lenet5)  # used
+    mnist_suppress_train(10, 3, 0.05, 60000, 1000, 0.95, 0.0, masklayers=masklayers_lenet5)  # used
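A hedged annotation of the positional arguments above: mnist_suppress_train's own signature is not part of this diff, so the meanings below are inferred from the SuppressPrivacyFactory.create() parameters changed later in this commit.

    # Inferred meanings; only the call itself appears in this diff.
    mnist_suppress_train(10,      # end_epoch, should match model.train()'s epoch count
                         3,       # start_epoch of suppress operations
                         0.05,    # lr, must equal the SGD learning_rate (changed from 0.10)
                         60000,   # number of training samples (MNIST)
                         1000,    # mask_times: number of suppress operations
                         0.95,    # sparse_end: target sparsity
                         0.0,     # sparse_start: initial sparsity
                         masklayers=masklayers_lenet5)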

mindarmour/privacy/sup_privacy/sup_ctrl/conctrl.py (+193, -44)

@@ -15,6 +15,7 @@
 control function of suppress-based privacy.
 """
 import math
+import gc
 import numpy as np

 from mindspore import Tensor
@@ -35,18 +36,20 @@ class SuppressPrivacyFactory:

     @staticmethod
     def create(networks, mask_layers, policy="local_train", end_epoch=10, batch_num=20, start_epoch=3,
-               mask_times=1000, lr=0.10, sparse_end=0.90, sparse_start=0.0):
+               mask_times=1000, lr=0.05, sparse_end=0.90, sparse_start=0.0):
         """
         Args:
             networks (Cell): The training network.
             mask_layers (list): Description of the training network layers that need to be suppressed.
             policy (str): Training policy for suppress privacy training. Default: "local_train", means local training.
             end_epoch (int): The last epoch in suppress operations, 0<start_epoch<=end_epoch<=100. Default: 10.
+                This end_epoch parameter should be the same as the 'epoch' parameter of mindspore.train.model.train().
             batch_num (int): The num of batches in an epoch, should be equal to num_samples/batch_size. Default: 20.
             start_epoch (int): The first epoch in suppress operations, 0<start_epoch<=end_epoch<=100. Default: 3.
             mask_times (int): The num of suppress operations. Default: 1000.
-            lr (Union[float, int]): Learning rate, 0 < lr <= 0.5. Default: 0.10.
-            sparse_end (Union[float, int]): The sparsity to reach, 0.0<=sparse_start<sparse_end<1.0. Default: 0.90.
+            lr (Union[float, int]): Learning rate, should be unchanged during training. 0<lr<=0.50. Default: 0.05.
+                This lr parameter should be the same as the 'learning_rate' parameter of mindspore.nn.SGD().
+            sparse_end (float): The sparsity to reach, 0.0<=sparse_start<sparse_end<1.0. Default: 0.90.
             sparse_start (Union[float, int]): The sparsity to start, 0.0<=sparse_start<sparse_end<1.0. Default: 0.0.

         Returns:
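
The two docstring additions encode a hard coupling between the factory and the training loop. A minimal sketch of keeping all three places consistent, where networks, mask_layers, model and ds_train are placeholders rather than names taken from this diff:

    import mindspore.nn as nn

    lr, end_epoch = 0.05, 10  # one source of truth for both values
    suppress_ctrl = SuppressPrivacyFactory.create(networks=networks, mask_layers=mask_layers,
                                                  policy="local_train", end_epoch=end_epoch,
                                                  batch_num=600, start_epoch=3, mask_times=1000,
                                                  lr=lr, sparse_end=0.90, sparse_start=0.0)
    net_opt = nn.SGD(networks.trainable_params(), learning_rate=lr)  # same lr
    model.train(end_epoch, ds_train)                                 # same epoch count
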
@@ -101,7 +104,7 @@ class SuppressCtrl(Cell):
         start_epoch (int): The first epoch in suppress operations.
         mask_times (int): The num of suppress operations.
         lr (Union[float, int]): Learning rate.
-        sparse_end (Union[float, int]): The sparsity to reach.
+        sparse_end (float): The sparsity to reach.
         sparse_start (Union[float, int]): The sparsity to start.
     """
     def __init__(self, networks, mask_layers, end_epoch, batch_num, start_epoch, mask_times, lr,
@@ -114,12 +117,12 @@ class SuppressCtrl(Cell):
         self.mask_start_epoch = check_int_positive('start_epoch', start_epoch)
         self.mask_times = check_int_positive('mask_times', mask_times)
         self.lr = check_value_positive('lr', lr)
-        self.sparse_end = check_value_non_negative('sparse_end', sparse_end)
+        self.sparse_end = check_param_type('sparse_end', sparse_end, float)
         self.sparse_start = check_value_non_negative('sparse_start', sparse_start)

         self.weight_lower_bound = 0.005  # all network weights will be larger than this value
         self.sparse_vibra = 0.02  # the sparsity may have a certain range of variation
-        self.sparse_valid_max_weight = 0.20  # if the max network weight is less than this value, suppression stops temporarily
+        self.sparse_valid_max_weight = 0.02  # if the max network weight is less than this value, suppression stops temporarily
         self.add_noise_thd = 0.50  # if a network weight is more than this value, noise is forced
         self.noise_volume = 0.1  # noise volume 0.1
         self.base_ground_thd = 0.0000001  # if a network weight is less than this value, it will be considered as 0
@@ -253,6 +256,10 @@ class SuppressCtrl(Cell):
msg = "can't match this mask layer: {} ".format(one_mask_layer.layer_name) msg = "can't match this mask layer: {} ".format(one_mask_layer.layer_name)
LOGGER.error(TAG, msg) LOGGER.error(TAG, msg)
raise ValueError(msg) raise ValueError(msg)
msg = "this lr parameter should be the same as 'learning_rate' parameter of mindspore.nn.SGD()\n"
msg += "this end_epoch parameter should be the same as 'epoch' parameter of mindspore.train.model.train()\n"
msg += "sup_privacy only support SGD optimizer"
LOGGER.info(TAG, msg)


def update_status(self, cur_epoch, cur_step, cur_step_in_epoch): def update_status(self, cur_epoch, cur_step, cur_step_in_epoch):
""" """
@@ -296,6 +303,7 @@ class SuppressCtrl(Cell):
self.cur_sparse = self.sparse_end +\ self.cur_sparse = self.sparse_end +\
(self.sparse_start - self.sparse_end)*\ (self.sparse_start - self.sparse_end)*\
math.pow((1.0 - (cur_step + 0.0 - self.mask_start_step) / self.mask_all_steps), 3) math.pow((1.0 - (cur_step + 0.0 - self.mask_start_step) / self.mask_all_steps), 3)
self.cur_sparse = min(self.cur_sparse, self.sparse_end)
m = 0 m = 0
for layer in networks.get_parameters(expand=True): for layer in networks.get_parameters(expand=True):
grad_idx = self.grad_idx_map[m] grad_idx = self.grad_idx_map[m]
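
The added min() clamps the cubic schedule. A standalone sketch of the schedule this hunk computes, with the member names flattened into arguments:

    import math

    def scheduled_sparsity(cur_step, mask_start_step, mask_all_steps,
                           sparse_start=0.0, sparse_end=0.90):
        # Cubic ramp from sparse_start toward sparse_end, as in update_status.
        progress = (cur_step - mask_start_step) / mask_all_steps
        cur_sparse = sparse_end + (sparse_start - sparse_end) * math.pow(1.0 - progress, 3)
        # The new clamp: once progress exceeds 1.0, the correction term turns
        # positive and would push cur_sparse above sparse_end without it.
        return min(cur_sparse, sparse_end)
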
@@ -303,31 +311,58 @@ class SuppressCtrl(Cell):
m = m + 1 m = m + 1
continue continue
if self.grads_mask_list[grad_idx].mask_able: if self.grads_mask_list[grad_idx].mask_able:
len_array = self.grads_mask_list[grad_idx].para_num
min_num = self.grads_mask_list[grad_idx].min_num
sparse_min_thd = 1.0 - min(min_num, len_array) / len_array
actual_stop_pos = int(len_array * min(sparse_min_thd, self.cur_sparse))

grad_mask_cell = self.grads_mask_list[grad_idx]
last_sparse_pos = grad_mask_cell.sparse_pos_list[-1]
if actual_stop_pos <= 0 or \
(actual_stop_pos < last_sparse_pos + grad_mask_cell.part_num and \
grad_mask_cell.is_approximity and m > 0):
sparse_weight_thd = 0
msg = "{} len={}, sparse={}, current sparse thd={}, [idle] \n" \
.format(layer.name, len_array, actual_stop_pos / len_array, sparse_weight_thd)
LOGGER.info(TAG, msg)
m = m + 1
continue

weight_array = layer.data.asnumpy() weight_array = layer.data.asnumpy()
weight_avg = np.mean(weight_array) weight_avg = np.mean(weight_array)
weight_array_flat = weight_array.flatten() weight_array_flat = weight_array.flatten()
weight_array_flat_abs = np.abs(weight_array_flat) weight_array_flat_abs = np.abs(weight_array_flat)
weight_abs_avg = np.mean(weight_array_flat_abs) weight_abs_avg = np.mean(weight_array_flat_abs)
weight_array_flat_abs.sort()
len_array = weight_array.size
weight_abs_max = np.max(weight_array_flat_abs) weight_abs_max = np.max(weight_array_flat_abs)
weight_abs_min = np.min(weight_array_flat_abs)

if m == 0 and weight_abs_max < self.sparse_valid_max_weight: if m == 0 and weight_abs_max < self.sparse_valid_max_weight:
msg = "give up this masking .."
msg = "layer 0 weight_abs_max = {}, give up this masking ... ".format(weight_abs_max)
LOGGER.info(TAG, msg) LOGGER.info(TAG, msg)
del weight_array_flat_abs
del weight_array_flat
del weight_array
gc.collect()
return return
if self.grads_mask_list[grad_idx].min_num > 0:
sparse_weight_thd, _, actual_stop_pos = self.calc_sparse_thd(weight_array_flat_abs,
self.cur_sparse, grad_idx)
else:
actual_stop_pos = int(len_array * self.cur_sparse)
sparse_weight_thd = weight_array_flat_abs[actual_stop_pos]


self.update_mask_layer(weight_array_flat, sparse_weight_thd, actual_stop_pos, weight_abs_max, grad_idx)
if grad_mask_cell.is_approximity and m > 0:
sparse_weight_thd = self.update_mask_layer_approximity(weight_array_flat, weight_array_flat_abs,
actual_stop_pos, grad_idx)
else:
partition = np.partition(weight_array_flat_abs, actual_stop_pos - 1)
sparse_weight_thd = partition[actual_stop_pos - 1]
self.update_mask_layer(weight_array_flat, sparse_weight_thd, actual_stop_pos,
weight_abs_max, grad_idx)
del partition


msg = "{} len={}, sparse={}, current sparse thd={}, max={}, avg={}, avg_abs={} \n".format(
msg = "{} len={}, sparse={}, current sparse thd={}, max={}, min={}, avg={}, avg_abs={} \n".format(
layer.name, len_array, actual_stop_pos/len_array, sparse_weight_thd, layer.name, len_array, actual_stop_pos/len_array, sparse_weight_thd,
weight_abs_max, weight_avg, weight_abs_avg)
weight_abs_max, weight_abs_min, weight_avg, weight_abs_avg)
LOGGER.info(TAG, msg) LOGGER.info(TAG, msg)
del weight_array_flat_abs
del weight_array_flat
del weight_array
gc.collect()
m = m + 1 m = m + 1
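
This hunk replaces the full weight_array_flat_abs.sort() with np.partition and frees temporaries explicitly. A minimal sketch of why the partial sort suffices, on synthetic data rather than real layer weights:

    import numpy as np

    weights_abs = np.abs(np.random.randn(1_000_000).astype(np.float32))
    actual_stop_pos = int(weights_abs.size * 0.9)  # e.g. a 90% sparsity target

    # np.partition places the value a full sort would put at index k at that
    # same index in O(n) instead of O(n log n); the masking step only needs
    # that one threshold value, not a fully sorted array.
    partition = np.partition(weights_abs, actual_stop_pos - 1)
    sparse_weight_thd = partition[actual_stop_pos - 1]
    assert sparse_weight_thd == np.sort(weights_abs)[actual_stop_pos - 1]
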


@@ -335,7 +370,7 @@ class SuppressCtrl(Cell):
     def update_mask_layer(self, weight_array_flat, sparse_weight_thd, sparse_stop_pos, weight_abs_max, layer_index):
         """
         Update add mask arrays and multiply mask arrays of one single layer.

         Args:
-            weight_array (numpy.ndarray): The weight array of layer's parameters.
+            weight_array_flat (numpy.ndarray): The weight array of layer's parameters.
             sparse_weight_thd (float): The weight threshold of sparse operation.
             sparse_stop_pos (int): The maximum number of elements to be suppressed.
             weight_abs_max (float): The maximum absolute value of weights.
@@ -358,9 +393,13 @@ class SuppressCtrl(Cell):
q = 0 q = 0
# add noise on weights if not masking or clipping. # add noise on weights if not masking or clipping.
weight_noise_bound = min(self.add_noise_thd, max(self.noise_volume*10, weight_abs_max*0.75)) weight_noise_bound = min(self.add_noise_thd, max(self.noise_volume*10, weight_abs_max*0.75))
for i in range(0, weight_array_flat.size):
if abs(weight_array_flat[i]) <= sparse_weight_thd:
if m < weight_array_flat.size - min_num and m < sparse_stop_pos:
size = self.grads_mask_list[layer_index].para_num
for i in range(0, size):
if mul_mask_array_flat[i] <= 0.0:
add_mask_array_flat[i] = weight_array_flat[i] / self.lr
m = m + 1
elif abs(weight_array_flat[i]) <= sparse_weight_thd:
if m < size - min_num and m < sparse_stop_pos:
# to mask # to mask
mul_mask_array_flat[i] = 0.0 mul_mask_array_flat[i] = 0.0
add_mask_array_flat[i] = weight_array_flat[i] / self.lr add_mask_array_flat[i] = weight_array_flat[i] / self.lr
@@ -368,9 +407,11 @@ class SuppressCtrl(Cell):
else: else:
# not mask # not mask
if weight_array_flat[i] > 0.0: if weight_array_flat[i] > 0.0:
add_mask_array_flat[i] = (weight_array_flat[i] - self.weight_lower_bound) / self.lr
add_mask_array_flat[i] = (weight_array_flat[i] \
- min(self.weight_lower_bound, sparse_weight_thd)) / self.lr
else: else:
add_mask_array_flat[i] = (weight_array_flat[i] + self.weight_lower_bound) / self.lr
add_mask_array_flat[i] = (weight_array_flat[i]
+ min(self.weight_lower_bound, sparse_weight_thd)) / self.lr
p = p + 1 p = p + 1
elif is_lower_clip and abs(weight_array_flat[i]) <= \ elif is_lower_clip and abs(weight_array_flat[i]) <= \
self.weight_lower_bound and sparse_weight_thd > self.weight_lower_bound*0.5: self.weight_lower_bound and sparse_weight_thd > self.weight_lower_bound*0.5:
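
The weight / lr terms are what make suppression exact rather than approximate. A worked sketch of the arithmetic, assuming the masks enter a plain SGD step as w - lr * (grad * mul_mask + add_mask); the mask-application path itself is outside this diff:

    lr = 0.05        # must match mindspore.nn.SGD's learning_rate
    w_i = 0.0123     # a weight element selected for suppression

    masked_grad = 0.0 * 1.0 + w_i / lr    # mul_mask zeroes the true gradient
    w_next = w_i - lr * masked_grad
    assert abs(w_next) < 1e-12            # the weight lands exactly on zero

This is also why the docstrings insist that lr equal the optimizer's learning_rate: with any other value, the de-weighting would overshoot or undershoot zero.
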
@@ -404,28 +445,99 @@ class SuppressCtrl(Cell):
"suppressed elements, max-clip elements, min-clip elements and noised elements are {}, {}, {}, {}"\ "suppressed elements, max-clip elements, min-clip elements and noised elements are {}, {}, {}, {}"\
.format(len(grad_mask_cell.mul_mask_array_shape), layer_index, m, n, p, q) .format(len(grad_mask_cell.mul_mask_array_shape), layer_index, m, n, p, q)
LOGGER.info(TAG, msg) LOGGER.info(TAG, msg)
grad_mask_cell.sparse_pos_list.append(m)


def calc_sparse_thd(self, array_flat, sparse_value, layer_index):
def update_mask_layer_approximity(self, weight_array_flat, weight_array_flat_abs, actual_stop_pos, layer_index):
""" """
Calculate the suppression threshold of one weight array.
Update add mask arrays and multiply mask arrays of one single layer with many parameter.
disable clipping loweer, clipping, adding noise operation


Args: Args:
array_flat (numpy.ndarray): The flattened weight array.
sparse_value (float): The target sparse value of weight array.
weight_array_flat (numpy.ndarray): The weight array of layer's parameters.
weight_array_flat_abs (numpy.ndarray): The abs weight array of layer's parameters.
actual_stop_pos (int): The actually para num should be suppressed.
layer_index (int): The index of target layer.
"""
grad_mask_cell = self.grads_mask_list[layer_index]
mul_mask_array_flat = grad_mask_cell.mul_mask_array_flat
de_weight_cell = self.de_weight_mask_list[layer_index]
add_mask_array_flat = de_weight_cell.add_mask_array_flat


Returns:
- float, the sparse threshold of this array.
part_size = grad_mask_cell.part_size
part_num = grad_mask_cell.part_num
para_num = grad_mask_cell.para_num
init_batch_suppress = False


- int, the number of weight elements to be suppressed.
if not self.grads_mask_list[layer_index].mask_able:
return 0.0
real_part_num = 0
sparse_thd = 0.0
last_sparse_pos = grad_mask_cell.sparse_pos_list[-1]
split_k_num = max(0, int((actual_stop_pos - last_sparse_pos) / part_num))
if last_sparse_pos <= 0:
init_batch_suppress = True
for i in range(0, part_num):
array_row_mul_mask = mul_mask_array_flat[i * part_size : (i + 1) * part_size]
array_row_flat_abs = weight_array_flat_abs[i * part_size : (i + 1) * part_size]
if not init_batch_suppress:
array_row_flat_abs_masked = np.where(array_row_mul_mask <= 0.0, -1.0, array_row_flat_abs)
set_abs = set(array_row_flat_abs_masked)
set_abs.remove(-1.0)
list2 = list(set_abs)
val_array_align = np.array(list2)
del array_row_flat_abs_masked
del set_abs
del list2
gc.collect()
else:
val_array_align = array_row_flat_abs

real_split_k_num = min(split_k_num, len(val_array_align) - 1)
if real_split_k_num <= 0:
del array_row_flat_abs
del array_row_mul_mask
del val_array_align
gc.collect()
continue


- int, the larger number of weight elements to be suppressed.
"""
size = len(array_flat)
sparse_max_thd = 1.0 - min(self.grads_mask_list[layer_index].min_num, size) / size
pos = int(size*min(sparse_max_thd, sparse_value))
thd = array_flat[pos]
farther_stop_pos = int(size*min(sparse_max_thd, max(0, sparse_value + self.sparse_vibra / 2.0)))
return thd, pos, farther_stop_pos
partition = np.partition(val_array_align, real_split_k_num - 1)
sparse_k_thd = partition[real_split_k_num - 1]
if sparse_k_thd > 0 or init_batch_suppress:
real_part_num = real_part_num + 1
sparse_thd = sparse_thd + sparse_k_thd
del array_row_flat_abs
del array_row_mul_mask
del val_array_align
del partition
gc.collect()

if real_part_num > 0:
sparse_thd = sparse_thd / real_part_num
new_mul_mask_array_flat = np.where(weight_array_flat_abs <= sparse_thd, 0.0, 1.0)
grad_mask_cell.mul_mask_array_flat = new_mul_mask_array_flat
new_add_mask_array_flat = np.where(new_mul_mask_array_flat <= 0.0, weight_array_flat / self.lr, 0.0)
de_weight_cell.add_mask_array_flat = new_add_mask_array_flat
grad_mask_cell.update()
de_weight_cell.update()
del mul_mask_array_flat
del add_mask_array_flat
gc.collect()
real_suppress_num = para_num - int(np.sum(grad_mask_cell.mul_mask_array_flat))
grad_mask_cell.sparse_pos_list.append(real_suppress_num)
else:
real_suppress_num = 0

msg = "Dimension of mask tensor is {}D, which located in the {}-th layer of the network. " \
"\n The ideal number of suppressed elements is {}/{}/{}, real suppress elements is {}" \
.format(len(grad_mask_cell.mul_mask_array_shape), layer_index,
split_k_num, (actual_stop_pos - last_sparse_pos), actual_stop_pos, real_suppress_num)
LOGGER.info(TAG, msg)
if init_batch_suppress:
init_sparse_actual = real_suppress_num/para_num
print("init batch suppresss, actual sparse = {}".format(init_sparse_actual))

gc.collect()
return sparse_thd


def reset_zeros(self): def reset_zeros(self):
""" """
@@ -452,7 +564,6 @@ class SuppressCtrl(Cell):
             if array_mul_mask_flat_conv1[i] <= 0.0:
                 sparse += 1.0
                 sparse_value_1 += 1.0
-
         for i in range(0, array_mul_mask_flat_conv2.size):
             full = full + 1.0
             full_conv2 = full_conv2 + 1.0
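
The counting loops in this function amount to a per-layer zero fraction; a vectorized equivalent for reference:

    import numpy as np

    def mask_sparsity(mul_mask_flat):
        # Fraction of suppressed (zeroed) slots in one layer's multiply mask.
        return float(np.sum(mul_mask_flat <= 0.0)) / mul_mask_flat.size
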
@@ -483,10 +594,13 @@ class SuppressCtrl(Cell):
         array_cur_conv1 = np.ones(np.shape([1]), dtype=np.float32)
         array_cur_conv2 = np.ones(np.shape([1]), dtype=np.float32)
         for layer in networks.get_parameters(expand=True):
-            if "conv1.weight" in layer.name:
-                array_cur_conv1 = layer.data.asnumpy()
-            if "conv2.weight" in layer.name:
-                array_cur_conv2 = layer.data.asnumpy()
+            if "networks.conv1.weight" in layer.name or "networks.layers.0.weight" in layer.name:  # lenet5/res50 vgg16
+                array_cur_conv1 = layer.data.asnumpy()
+                print("calc_actual_sparse, match conv1")
+            if "networks.conv2.weight" in layer.name or "networks.layers.3.weight" in layer.name \
+                    or "networks.layer1.0.conv1.weight" in layer.name:  # res50
+                array_cur_conv2 = layer.data.asnumpy()
+                print("calc_actual_sparse, match conv2")

         array_mul_mask_flat_conv1 = array_cur_conv1.flatten()
         array_mul_mask_flat_conv2 = array_cur_conv2.flatten()
@@ -510,10 +624,15 @@ class SuppressCtrl(Cell):
         sparse_value_2 = sparse_value_2 / full_conv2
         msg = "conv sparse fact={}, sparse_1={}, sparse_2={}".format(sparse, sparse_value_1, sparse_value_2)
         LOGGER.info(TAG, msg)
+        del array_mul_mask_flat_conv1
+        del array_mul_mask_flat_conv2
+        del array_cur_conv1
+        del array_cur_conv2
+        gc.collect()
         return sparse, sparse_value_1, sparse_value_2

     def calc_actual_sparse_for_fc1(self, networks):
-        self.calc_actual_sparse_for_layer(networks, "fc1.weight")
+        return self.calc_actual_sparse_for_layer(networks, "fc1.weight")

     def calc_actual_sparse_for_layer(self, networks, layer_name):
         """
@@ -533,11 +652,12 @@ class SuppressCtrl(Cell):
         for layer in networks.get_parameters(expand=True):
             if layer_name in layer.name:
                 array_cur = layer.data.asnumpy()
+                break

         if array_cur is None:
             msg = "no such layer to calc sparse: {} ".format(layer_name)
             LOGGER.info(TAG, msg)
-            return
+            return 0.0

         array_cur_flat = array_cur.flatten()
@@ -549,6 +669,10 @@ class SuppressCtrl(Cell):
         sparse = sparse / full
         msg = "{} sparse fact={} ".format(layer_name, sparse)
         LOGGER.info(TAG, msg)
+        del array_cur_flat
+        del array_cur
+        gc.collect()
+        return sparse

     def print_paras(self):
         msg = "paras: start_epoch:{}, end_epoch:{}, batch_num:{}, interval:{}, lr:{}, sparse_end:{}, sparse_start:{}" \
@@ -631,6 +755,31 @@ class GradMaskInCell(Cell):
         self.min_num = min_num
         self.upper_bound = check_value_positive('upper_bound', upper_bound)

+        self.para_num = array.size
+        self.is_approximity = False
+        self.sparse_pos_list = [0]
+        self.part_num = 1
+        self.part_size = self.para_num
+        self.part_num_max = 16
+        self.para_many_num = 10000
+        self.para_huge_num = 10*10000*10000
+
+        if self.para_num > self.para_many_num:
+            self.is_approximity = True
+            self.is_add_noise = False
+            self.is_lower_clip = False
+
+            ratio = 2
+            if self.part_size > self.para_huge_num:
+                while self.part_size % ratio == 0 and self.part_size > self.para_huge_num \
+                        and self.part_num < self.part_num_max:
+                    self.part_num = self.part_num * ratio
+                    self.part_size = int(self.part_size / ratio)
+            msg = "this layer has {} para, disable the operation of clipping lower, clipping upper_bound, " \
+                  "adding noise. \n part_num={}, part_size={}" \
+                .format(self.para_num, self.part_num, self.part_size)
+            LOGGER.info(TAG, msg)
+
     def construct(self):
         """
         Return the mask matrix for optimization.
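
The new fields switch any layer above para_many_num elements onto the approximate path, and the halving loop bounds per-part work for truly huge layers. A trace of that loop under an assumed 2**30-element layer (the size is hypothetical, not from this diff):

    para_num = 2 ** 30                  # hypothetical huge layer
    part_num, part_size = 1, para_num
    PART_NUM_MAX = 16
    PARA_HUGE_NUM = 10 * 10000 * 10000  # 1e9, as above

    while part_size % 2 == 0 and part_size > PARA_HUGE_NUM and part_num < PART_NUM_MAX:
        part_num *= 2
        part_size //= 2

    print(part_num, part_size)          # 2 536870912: one halving already fits
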


mindarmour/privacy/sup_privacy/train/model.py (+5, -1)

@@ -35,6 +35,7 @@ from mindspore.parallel._utils import _get_gradients_mean
 from mindspore.parallel._utils import _get_device_num
 from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
 from mindspore.nn import Cell
+from mindspore.nn.optim import SGD
 from mindarmour.utils._check_param import check_param_type
 from mindarmour.utils.logger import LogUtil
 from mindarmour.privacy.sup_privacy.sup_ctrl.conctrl import SuppressCtrl
@@ -97,14 +98,17 @@ class SuppressModel(Model):

     def __init__(self,
                  network,
+                 loss_fn,
+                 optimizer,
                  **kwargs):

         check_param_type('network', network, Cell)
+        check_param_type('optimizer', optimizer, SGD)

         self.network_end = None
         self._train_one_step = None

-        super(SuppressModel, self).__init__(network, **kwargs)
+        super(SuppressModel, self).__init__(network, loss_fn, optimizer, **kwargs)

     def link_suppress_ctrl(self, suppress_pri_ctrl):
         """

