|
|
@@ -134,6 +134,7 @@ class SuppressCtrl(Cell): |
|
|
|
self.mask_all_steps = (end_epoch - start_epoch + 1)*batch_num # total number of steps covered by all suppress operations
|
|
|
self.mask_step_interval = self.mask_all_steps/mask_times # number of steps contained in one suppress operation
|
|
|
self.mask_initialized = False # flag indicating whether initialization is done
|
|
|
self.grad_idx_map = [] |
|
|
|
|
|
|
|
if self.lr > 0.5: |
|
|
|
msg = "learning rate should not be greater than 0.5, but got {}".format(self.lr) |
|
|
@@ -210,6 +211,8 @@ class SuppressCtrl(Cell): |
|
|
|
de_weight_cell.mask_able = False |
|
|
|
self.de_weight_mask_list.append(de_weight_cell) |
|
|
|
|
|
|
|
self.grad_idx_map.append(-1) |
|
|
|
|
|
|
|
m = 0 |
|
|
|
for layer in networks.get_parameters(expand=True): |
|
|
|
one_mask_layer = None |
|
|
@@ -231,6 +234,7 @@ class SuppressCtrl(Cell): |
|
|
|
de_weight_cell = DeWeightInCell(add_mask_array) |
|
|
|
de_weight_cell.mask_able = True |
|
|
|
self.de_weight_mask_list[one_mask_layer.grad_idx] = de_weight_cell |
|
|
|
self.grad_idx_map[m] = one_mask_layer.grad_idx |
|
|
|
msg = "do mask {}, {}, {}".format(m, one_mask_layer.layer_name, one_mask_layer.grad_idx) |
|
|
|
LOGGER.info(TAG, msg) |
|
|
|
elif one_mask_layer is not None and one_mask_layer.inited: |
|
|
@@ -294,7 +298,10 @@ class SuppressCtrl(Cell): |
|
|
|
math.pow((1.0 - (cur_step + 0.0 - self.mask_start_step) / self.mask_all_steps), 3) |
|
|
|
m = 0 |
|
|
|
for layer in networks.get_parameters(expand=True): |
|
|
|
if self.grads_mask_list[m].mask_able: |
|
|
|
grad_idx = self.grad_idx_map[m] |
|
|
|
if grad_idx < 0: |
|
|
|
continue |
|
|
|
if self.grads_mask_list[grad_idx].mask_able: |
|
|
|
weight_array = layer.data.asnumpy() |
|
|
|
weight_avg = np.mean(weight_array) |
|
|
|
weight_array_flat = weight_array.flatten() |
|
|
@@ -307,14 +314,14 @@ class SuppressCtrl(Cell): |
|
|
|
msg = "give up this masking .." |
|
|
|
LOGGER.info(TAG, msg) |
|
|
|
return |
|
|
|
if self.grads_mask_list[m].min_num > 0: |
|
|
|
if self.grads_mask_list[grad_idx].min_num > 0: |
|
|
|
sparse_weight_thd, _, actual_stop_pos = self.calc_sparse_thd(weight_array_flat_abs, |
|
|
|
self.cur_sparse, m) |
|
|
|
self.cur_sparse, grad_idx) |
|
|
|
else: |
|
|
|
actual_stop_pos = int(len_array * self.cur_sparse) |
|
|
|
sparse_weight_thd = weight_array_flat_abs[actual_stop_pos] |
|
|
|
|
|
|
|
self.update_mask_layer(weight_array_flat, sparse_weight_thd, actual_stop_pos, weight_abs_max, m) |
|
|
|
self.update_mask_layer(weight_array_flat, sparse_weight_thd, actual_stop_pos, weight_abs_max, grad_idx) |
|
|
|
|
|
|
|
msg = "{} len={}, sparse={}, current sparse thd={}, max={}, avg={}, avg_abs={} \n".format( |
|
|
|
layer.name, len_array, actual_stop_pos/len_array, sparse_weight_thd, |
|
|
@@ -570,8 +577,12 @@ class MaskLayerDes: |
|
|
|
|
|
|
|
Args: |
|
|
|
layer_name (str): Layer name, get the name of one layer as following: |
|
|
|
for layer in networks.get_parameters(expand=True): |
|
|
|
if layer.name == "conv": ... |
|
|
|
|
|
|
|
.. code-block:: |
|
|
|
|
|
|
|
for layer in networks.get_parameters(expand=True): |
|
|
|
if layer.name == "conv": ... |
|
|
|
|
|
|
|
grad_idx (int): Grad layer index, get mask layer's index in grad tuple. You can refer to the construct function
|
|
|
of TrainOneStepCell in mindarmour/privacy/sup_privacy/train/model.py to get the index of some specified |
|
|
|
grad layers (print in PYNATIVE_MODE). |
|
|
|