From c63c72afb150fbbaefdfbcd679f7c6fae0a2470b Mon Sep 17 00:00:00 2001 From: itcomee Date: Mon, 22 Feb 2021 11:58:59 +0800 Subject: [PATCH] suppress based privacy model, 2021.2.22 --- .../privacy/sup_privacy/sup_ctrl/conctrl.py | 23 ++++++++++++++----- 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/mindarmour/privacy/sup_privacy/sup_ctrl/conctrl.py b/mindarmour/privacy/sup_privacy/sup_ctrl/conctrl.py index 6b66164..8a45dd9 100644 --- a/mindarmour/privacy/sup_privacy/sup_ctrl/conctrl.py +++ b/mindarmour/privacy/sup_privacy/sup_ctrl/conctrl.py @@ -134,6 +134,7 @@ class SuppressCtrl(Cell): self.mask_all_steps = (end_epoch - start_epoch + 1)*batch_num # the amount of step contained in all suppress operation self.mask_step_interval = self.mask_all_steps/mask_times # the amount of step contaied in one suppress operation self.mask_initialized = False # flag means the initialization is done + self.grad_idx_map = [] if self.lr > 0.5: msg = "learning rate should not be greater than 0.5, but got {}".format(self.lr) @@ -210,6 +211,8 @@ class SuppressCtrl(Cell): de_weight_cell.mask_able = False self.de_weight_mask_list.append(de_weight_cell) + self.grad_idx_map.append(-1) + m = 0 for layer in networks.get_parameters(expand=True): one_mask_layer = None @@ -231,6 +234,7 @@ class SuppressCtrl(Cell): de_weight_cell = DeWeightInCell(add_mask_array) de_weight_cell.mask_able = True self.de_weight_mask_list[one_mask_layer.grad_idx] = de_weight_cell + self.grad_idx_map[m] = one_mask_layer.grad_idx msg = "do mask {}, {}, {}".format(m, one_mask_layer.layer_name, one_mask_layer.grad_idx) LOGGER.info(TAG, msg) elif one_mask_layer is not None and one_mask_layer.inited: @@ -294,7 +298,10 @@ class SuppressCtrl(Cell): math.pow((1.0 - (cur_step + 0.0 - self.mask_start_step) / self.mask_all_steps), 3) m = 0 for layer in networks.get_parameters(expand=True): - if self.grads_mask_list[m].mask_able: + grad_idx = self.grad_idx_map[m] + if grad_idx < 0: + continue + if self.grads_mask_list[grad_idx].mask_able: weight_array = layer.data.asnumpy() weight_avg = np.mean(weight_array) weight_array_flat = weight_array.flatten() @@ -307,14 +314,14 @@ class SuppressCtrl(Cell): msg = "give up this masking .." LOGGER.info(TAG, msg) return - if self.grads_mask_list[m].min_num > 0: + if self.grads_mask_list[grad_idx].min_num > 0: sparse_weight_thd, _, actual_stop_pos = self.calc_sparse_thd(weight_array_flat_abs, - self.cur_sparse, m) + self.cur_sparse, grad_idx) else: actual_stop_pos = int(len_array * self.cur_sparse) sparse_weight_thd = weight_array_flat_abs[actual_stop_pos] - self.update_mask_layer(weight_array_flat, sparse_weight_thd, actual_stop_pos, weight_abs_max, m) + self.update_mask_layer(weight_array_flat, sparse_weight_thd, actual_stop_pos, weight_abs_max, grad_idx) msg = "{} len={}, sparse={}, current sparse thd={}, max={}, avg={}, avg_abs={} \n".format( layer.name, len_array, actual_stop_pos/len_array, sparse_weight_thd, @@ -570,8 +577,12 @@ class MaskLayerDes: Args: layer_name (str): Layer name, get the name of one layer as following: - for layer in networks.get_parameters(expand=True): - if layer.name == "conv": ... + + .. code-block:: + + for layer in networks.get_parameters(expand=True): + if layer.name == "conv": ... + grad_idx (int): Grad layer index, get mask layer's index in grad tuple.You can refer to the construct function of TrainOneStepCell in mindarmour/privacy/sup_privacy/train/model.py to get the index of some specified grad layers (print in PYNATIVE_MODE).