https://arxiv.org/abs/1906.08935tags/v1.2.0-rc1
@@ -1,33 +1,54 @@ | |||||
# Application demos of privacy stealing and privacy protection | # Application demos of privacy stealing and privacy protection | ||||
## Introduction | ## Introduction | ||||
Although machine learning could obtain a generic model based on training data, it has been proved that the trained | Although machine learning could obtain a generic model based on training data, it has been proved that the trained | ||||
model may disclose the information of training data (such as the membership inference attack). Differential | |||||
privacy training | |||||
is an effective | |||||
method proposed | |||||
to overcome this problem, in which Gaussian noise is added while training. There are mainly three parts for | |||||
differential privacy(DP) training: noise-generating mechanism, DP optimizer and DP monitor. We have implemented | |||||
a novel noise-generating mechanisms: adaptive decay noise mechanism. DP | |||||
monitor is used to compute the privacy budget while training. | |||||
model may disclose the information of training data (such as the membership inference attack). | |||||
Differential privacy training is an effective method proposed to overcome this problem, in which Gaussian noise is | |||||
added while training. There are mainly three parts for differential privacy(DP) training: noise-generating | |||||
mechanism, DP optimizer and DP monitor. We have implemented a novel noise-generating mechanisms: adaptive decay | |||||
noise mechanism. DP monitor is used to compute the privacy budget while training. | |||||
Suppress Privacy training is a novel method to protect privacy distinct from the noise addition method | |||||
(such as DP), in which the negligible model parameter is removed gradually to achieve a better balance between | |||||
accuracy and privacy. | |||||
## 1. Adaptive decay DP training | ## 1. Adaptive decay DP training | ||||
With adaptive decay mechanism, the magnitude of the Gaussian noise would be decayed as the training step grows, which | With adaptive decay mechanism, the magnitude of the Gaussian noise would be decayed as the training step grows, which | ||||
resulting a stable convergence. | resulting a stable convergence. | ||||
```sh | ```sh | ||||
$ cd examples/privacy/diff_privacy | |||||
$ python lenet5_dp_ada_gaussian.py | |||||
cd examples/privacy/diff_privacy | |||||
python lenet5_dp_ada_gaussian.py | |||||
``` | ``` | ||||
## 2. Adaptive norm clip training | ## 2. Adaptive norm clip training | ||||
With adaptive norm clip mechanism, the norm clip of the gradients would be changed according to the norm values of | With adaptive norm clip mechanism, the norm clip of the gradients would be changed according to the norm values of | ||||
them, which can adjust the ratio of noise and original gradients. | them, which can adjust the ratio of noise and original gradients. | ||||
```sh | ```sh | ||||
$ cd examples/privacy/diff_privacy | |||||
$ python lenet5_dp.py | |||||
cd examples/privacy/diff_privacy | |||||
python lenet5_dp.py | |||||
``` | ``` | ||||
## 3. Membership inference evaluation | ## 3. Membership inference evaluation | ||||
By this evaluation method, we could judge whether a sample is belongs to training dataset or not. | By this evaluation method, we could judge whether a sample is belongs to training dataset or not. | ||||
```sh | |||||
cd examples/privacy/membership_inference_attack | |||||
python train.py --data_path home_path_to_cifar100 --ckpt_path ./ | |||||
python example_vgg_cifar.py --data_path home_path_to_cifar100 --pre_trained 0-100_781.ckpt | |||||
``` | |||||
## 4. suppress privacy training | |||||
With suppress privacy mechanism, the values of some trainable parameters (such as conv layers and fully connected | |||||
layers) are set to zero as the training step grows, which can | |||||
achieve a better balance between accuracy and privacy | |||||
```sh | ```sh | ||||
$ cd examples/privacy/membership_inference_attack | |||||
$ python train.py --data_path home_path_to_cifar100 --ckpt_path ./ | |||||
$ python example_vgg_cifar.py --data_path home_path_to_cifar100 --pre_trained 0-100_781.ckpt | |||||
cd examples/privacy/sup_privacy | |||||
python sup_privacy.py | |||||
``` | ``` |
@@ -0,0 +1,154 @@ | |||||
# Copyright 2021 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
""" | |||||
Training example of suppress-based privacy. | |||||
""" | |||||
import os | |||||
import mindspore.nn as nn | |||||
from mindspore import context | |||||
from mindspore.train.callback import ModelCheckpoint | |||||
from mindspore.train.callback import CheckpointConfig | |||||
from mindspore.train.callback import LossMonitor | |||||
from mindspore.nn.metrics import Accuracy | |||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||||
import mindspore.dataset as ds | |||||
import mindspore.dataset.vision.c_transforms as CV | |||||
import mindspore.dataset.transforms.c_transforms as C | |||||
from mindspore.dataset.vision.utils import Inter | |||||
import mindspore.common.dtype as mstype | |||||
from examples.common.networks.lenet5.lenet5_net import LeNet5 | |||||
from sup_privacy_config import mnist_cfg as cfg | |||||
from mindarmour.privacy.sup_privacy import SuppressModel | |||||
from mindarmour.privacy.sup_privacy import SuppressMasker | |||||
from mindarmour.privacy.sup_privacy import SuppressPrivacyFactory | |||||
from mindarmour.privacy.sup_privacy import MaskLayerDes | |||||
from mindarmour.utils.logger import LogUtil | |||||
LOGGER = LogUtil.get_instance() | |||||
LOGGER.set_level('INFO') | |||||
TAG = 'Lenet5_Suppress_train' | |||||
def generate_mnist_dataset(data_path, batch_size=32, repeat_size=1, samples=None, num_parallel_workers=1, sparse=True): | |||||
""" | |||||
create dataset for training or testing | |||||
""" | |||||
# define dataset | |||||
ds1 = ds.MnistDataset(data_path, num_samples=samples) | |||||
# define operation parameters | |||||
resize_height, resize_width = 32, 32 | |||||
rescale = 1.0 / 255.0 | |||||
shift = 0.0 | |||||
# define map operations | |||||
resize_op = CV.Resize((resize_height, resize_width), | |||||
interpolation=Inter.LINEAR) | |||||
rescale_op = CV.Rescale(rescale, shift) | |||||
hwc2chw_op = CV.HWC2CHW() | |||||
type_cast_op = C.TypeCast(mstype.int32) | |||||
# apply map operations on images | |||||
if not sparse: | |||||
one_hot_enco = C.OneHot(10) | |||||
ds1 = ds1.map(input_columns="label", operations=one_hot_enco, num_parallel_workers=num_parallel_workers) | |||||
type_cast_op = C.TypeCast(mstype.float32) | |||||
ds1 = ds1.map(input_columns="label", operations=type_cast_op, | |||||
num_parallel_workers=num_parallel_workers) | |||||
ds1 = ds1.map(input_columns="image", operations=resize_op, | |||||
num_parallel_workers=num_parallel_workers) | |||||
ds1 = ds1.map(input_columns="image", operations=rescale_op, | |||||
num_parallel_workers=num_parallel_workers) | |||||
ds1 = ds1.map(input_columns="image", operations=hwc2chw_op, | |||||
num_parallel_workers=num_parallel_workers) | |||||
# apply DatasetOps | |||||
buffer_size = 10000 | |||||
ds1 = ds1.shuffle(buffer_size=buffer_size) | |||||
ds1 = ds1.batch(batch_size, drop_remainder=True) | |||||
ds1 = ds1.repeat(repeat_size) | |||||
return ds1 | |||||
def mnist_suppress_train(epoch_size=10, start_epoch=3, lr=0.05, samples=10000, mask_times=1000, | |||||
sparse_thd=0.90, sparse_start=0.0, masklayers=None): | |||||
""" | |||||
local train by suppress-based privacy | |||||
""" | |||||
networks_l5 = LeNet5() | |||||
suppress_ctrl_instance = SuppressPrivacyFactory().create(policy="local_train", | |||||
end_epoch=epoch_size, | |||||
batch_num=(int)(samples/cfg.batch_size), | |||||
start_epoch=start_epoch, | |||||
mask_times=mask_times, | |||||
networks=networks_l5, | |||||
lr=lr, | |||||
sparse_end=sparse_thd, | |||||
sparse_start=sparse_start, | |||||
mask_layers=masklayers) | |||||
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") | |||||
net_opt = nn.SGD(networks_l5.trainable_params(), lr) | |||||
config_ck = CheckpointConfig(save_checkpoint_steps=(int)(samples/cfg.batch_size), | |||||
keep_checkpoint_max=10) | |||||
# Create the SuppressModel model for training. | |||||
model_instance = SuppressModel(network=networks_l5, | |||||
loss_fn=net_loss, | |||||
optimizer=net_opt, | |||||
metrics={"Accuracy": Accuracy()}) | |||||
model_instance.link_suppress_ctrl(suppress_ctrl_instance) | |||||
# Create a Masker for Suppress training. The function of the Masker is to | |||||
# enforce suppress operation while training. | |||||
suppress_masker = SuppressMasker(model=model_instance, suppress_ctrl=suppress_ctrl_instance) | |||||
mnist_path = "./MNIST_unzip/" #"../../MNIST_unzip/" | |||||
ds_train = generate_mnist_dataset(os.path.join(mnist_path, "train"), | |||||
batch_size=cfg.batch_size, repeat_size=1, samples=samples) | |||||
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", | |||||
directory="./trained_ckpt_file/", | |||||
config=config_ck) | |||||
print("============== Starting SUPP Training ==============") | |||||
model_instance.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(), suppress_masker], | |||||
dataset_sink_mode=False) | |||||
print("============== Starting SUPP Testing ==============") | |||||
ckpt_file_name = 'trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' | |||||
param_dict = load_checkpoint(ckpt_file_name) | |||||
load_param_into_net(networks_l5, param_dict) | |||||
ds_eval = generate_mnist_dataset(os.path.join(mnist_path, 'test'), | |||||
batch_size=cfg.batch_size) | |||||
acc = model_instance.eval(ds_eval, dataset_sink_mode=False) | |||||
print("============== SUPP Accuracy: %s ==============", acc) | |||||
if __name__ == "__main__": | |||||
# This configure can run in pynative mode | |||||
context.set_context(mode=context.PYNATIVE_MODE, device_target=cfg.device_target) | |||||
masklayers_lenet5 = [] # determine which layer should be masked | |||||
masklayers_lenet5.append(MaskLayerDes("conv1.weight", False, True, 10)) | |||||
masklayers_lenet5.append(MaskLayerDes("conv2.weight", False, True, 150)) | |||||
masklayers_lenet5.append(MaskLayerDes("fc1.weight", True, False, -1)) | |||||
masklayers_lenet5.append(MaskLayerDes("fc2.weight", True, False, -1)) | |||||
masklayers_lenet5.append(MaskLayerDes("fc3.weight", True, False, 50)) | |||||
# do suppreess privacy train, with stronger privacy protection and better performance than Differential Privacy | |||||
mnist_suppress_train(10, 3, 0.10, 60000, 1000, 0.95, 0.0, masklayers=masklayers_lenet5) # used |
@@ -0,0 +1,32 @@ | |||||
# Copyright 2021 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
""" | |||||
network config setting, will be used in sup_privacy.py | |||||
""" | |||||
from easydict import EasyDict as edict | |||||
mnist_cfg = edict({ | |||||
'num_classes': 10, # the number of classes of model's output | |||||
'epoch_size': 1, # training epochs | |||||
'batch_size': 32, # batch size for training | |||||
'image_height': 32, # the height of training samples | |||||
'image_width': 32, # the width of training samples | |||||
'save_checkpoint_steps': 1875, # the interval steps for saving checkpoint file of the model | |||||
'keep_checkpoint_max': 10, # the maximum number of checkpoint files would be saved | |||||
'device_target': 'Ascend', # device used | |||||
'data_path': './MNIST_unzip', # the path of training and testing data set | |||||
'dataset_sink_mode': False, # whether deliver all training data to device one time | |||||
}) |
@@ -0,0 +1,27 @@ | |||||
# Copyright 2021 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
""" | |||||
This module provides Suppress Privacy feature to protect user privacy. | |||||
""" | |||||
from .mask_monitor.masker import SuppressMasker | |||||
from .train.model import SuppressModel | |||||
from .sup_ctrl.conctrl import SuppressPrivacyFactory | |||||
from .sup_ctrl.conctrl import SuppressCtrl | |||||
from .sup_ctrl.conctrl import MaskLayerDes | |||||
__all__ = ['SuppressMasker', | |||||
'SuppressModel', | |||||
'SuppressPrivacyFactory', | |||||
'SuppressCtrl', | |||||
'MaskLayerDes'] |
@@ -0,0 +1,98 @@ | |||||
# Copyright 2021 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
""" | |||||
Masker module of suppress-based privacy.. | |||||
""" | |||||
from mindspore.train.callback import Callback | |||||
from mindarmour.utils.logger import LogUtil | |||||
from mindarmour.utils._check_param import check_param_type | |||||
from mindarmour.privacy.sup_privacy.train.model import SuppressModel | |||||
from mindarmour.privacy.sup_privacy.sup_ctrl.conctrl import SuppressCtrl | |||||
LOGGER = LogUtil.get_instance() | |||||
TAG = 'suppress masker' | |||||
class SuppressMasker(Callback): | |||||
""" | |||||
Args: | |||||
args (Union[int, float, numpy.ndarray, list, str]): Parameters | |||||
used for creating a suppress privacy monitor. | |||||
kwargs (Union[int, float, numpy.ndarray, list, str]): Keyword | |||||
parameters used for creating a suppress privacy monitor. | |||||
model (SuppressModel): SuppressModel instance. | |||||
suppress_ctrl (SuppressCtrl): SuppressCtrl instance. | |||||
Examples: | |||||
networks_l5 = LeNet5() | |||||
masklayers = [] | |||||
masklayers.append(MaskLayerDes("conv1.weight", 0, False, True, 10)) | |||||
suppress_ctrl_instance = SuppressPrivacyFactory().create(policy="local_train", | |||||
end_epoch=10, | |||||
batch_num=(int)(10000/cfg.batch_size), | |||||
start_epoch=3, | |||||
mask_times=100, | |||||
networks=networks_l5, | |||||
lr=lr, | |||||
sparse_end=0.90, | |||||
sparse_start=0.0, | |||||
mask_layers=masklayers) | |||||
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") | |||||
net_opt = nn.Momentum(params=networks_l5.trainable_params(), learning_rate=lr, momentum=0.0) | |||||
config_ck = CheckpointConfig(save_checkpoint_steps=(int)(samples/cfg.batch_size), keep_checkpoint_max=10) | |||||
model_instance = SuppressModel(network=networks_l5, | |||||
loss_fn=net_loss, | |||||
optimizer=net_opt, | |||||
metrics={"Accuracy": Accuracy()}) | |||||
model_instance.link_suppress_ctrl(suppress_ctrl_instance) | |||||
ds_train = generate_mnist_dataset("./MNIST_unzip/train", | |||||
batch_size=cfg.batch_size, repeat_size=1, samples=samples) | |||||
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", | |||||
directory="./trained_ckpt_file/", | |||||
config=config_ck) | |||||
model_instance.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(), suppress_masker], | |||||
dataset_sink_mode=False) | |||||
""" | |||||
def __init__(self, model=None, suppress_ctrl=None): | |||||
super(SuppressMasker, self).__init__() | |||||
self._model = check_param_type('model', model, SuppressModel) | |||||
self._suppress_ctrl = check_param_type('suppress_ctrl', suppress_ctrl, SuppressCtrl) | |||||
def step_end(self, run_context): | |||||
""" | |||||
Update mask matrix tensor used for SuppressModel instance. | |||||
Args: | |||||
run_context (RunContext): Include some information of the model. | |||||
""" | |||||
cb_params = run_context.original_args() | |||||
cur_step = cb_params.cur_step_num | |||||
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1 | |||||
if self._suppress_ctrl is not None and self._model.network_end is not None: | |||||
self._suppress_ctrl.update_status(cb_params.cur_epoch_num, cur_step, cur_step_in_epoch) | |||||
if not self._suppress_ctrl.mask_initialized: | |||||
raise ValueError("Not initialize network!") | |||||
if self._suppress_ctrl.to_do_mask: | |||||
self._suppress_ctrl.update_mask(self._suppress_ctrl.networks, cur_step) | |||||
LOGGER.info(TAG, "suppress update") | |||||
elif not self._suppress_ctrl.to_do_mask and self._suppress_ctrl.mask_started: | |||||
self._suppress_ctrl.reset_zeros() | |||||
if cur_step_in_epoch % 100 == 1: | |||||
self._suppress_ctrl.calc_theoretical_sparse_for_conv() | |||||
_, _, _ = self._suppress_ctrl.calc_actual_sparse_for_conv( | |||||
self._suppress_ctrl.networks) |
@@ -0,0 +1,640 @@ | |||||
# Copyright 2021 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
""" | |||||
control function of suppress-based privacy. | |||||
""" | |||||
import math | |||||
import numpy as np | |||||
from mindspore import Tensor | |||||
from mindspore.ops import operations as P | |||||
from mindspore.common import dtype as mstype | |||||
from mindspore.nn import Cell | |||||
from mindarmour.utils.logger import LogUtil | |||||
from mindarmour.utils._check_param import check_int_positive, check_value_positive, \ | |||||
check_value_non_negative, check_param_type | |||||
LOGGER = LogUtil.get_instance() | |||||
TAG = 'Suppression training.' | |||||
class SuppressPrivacyFactory: | |||||
""" Factory class of SuppressCtrl mechanisms""" | |||||
def __init__(self): | |||||
pass | |||||
@staticmethod | |||||
def create(policy="local_train", end_epoch=10, batch_num=2, start_epoch=3, mask_times=100, networks=None, | |||||
lr=0.05, sparse_end=0.60, sparse_start=0.0, mask_layers=None): | |||||
""" | |||||
Args: | |||||
policy (str): Training policy for suppress privacy training. "local_train" means local training. | |||||
end_epoch (int): The last epoch in suppress operations, 0 < start_epoch <= end_epoch <= 100 . | |||||
batch_num (int): The num of batch in an epoch, should be equal to num_samples/batch_size . | |||||
start_epoch (int): The first epoch in suppress operations, 0 < start_epoch <= end_epoch <= 100 . | |||||
mask_times (int): The num of suppress operations. | |||||
networks (Cell): The training network. | |||||
lr (float): Learning rate. | |||||
sparse_end (float): The sparsity to reach, 0.0 <= sparse_start < sparse_end < 1.0 . | |||||
sparse_start (float): The sparsity to start, 0.0 <= sparse_start < sparse_end < 1.0 . | |||||
mask_layers (list): Description of the training network layers that need to be suppressed. | |||||
Returns: | |||||
SuppressCtrl, class of Suppress Privavy Mechanism. | |||||
Examples: | |||||
networks_l5 = LeNet5() | |||||
masklayers = [] | |||||
masklayers.append(MaskLayerDes("conv1.weight", 0, False, True, 10)) | |||||
suppress_ctrl_instance = SuppressPrivacyFactory().create(policy="local_train", | |||||
end_epoch=10, | |||||
batch_num=(int)(10000/cfg.batch_size), | |||||
start_epoch=3, | |||||
mask_times=100, | |||||
networks=networks_l5, | |||||
lr=lr, | |||||
sparse_end=0.90, | |||||
sparse_start=0.0, | |||||
mask_layers=masklayers) | |||||
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") | |||||
net_opt = nn.Momentum(params=networks_l5.trainable_params(), learning_rate=lr, momentum=0.0) | |||||
config_ck = CheckpointConfig(save_checkpoint_steps=(int)(samples/cfg.batch_size), keep_checkpoint_max=10) | |||||
model_instance = SuppressModel(network=networks_l5, | |||||
loss_fn=net_loss, | |||||
optimizer=net_opt, | |||||
metrics={"Accuracy": Accuracy()}) | |||||
model_instance.link_suppress_ctrl(suppress_ctrl_instance) | |||||
ds_train = generate_mnist_dataset("./MNIST_unzip/train", | |||||
batch_size=cfg.batch_size, repeat_size=1, samples=samples) | |||||
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", | |||||
directory="./trained_ckpt_file/", | |||||
config=config_ck) | |||||
model_instance.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(), suppress_masker], | |||||
dataset_sink_mode=False) | |||||
""" | |||||
if policy == "local_train": | |||||
return SuppressCtrl(networks, end_epoch, batch_num, start_epoch, mask_times, lr, sparse_end, | |||||
sparse_start, mask_layers) | |||||
msg = "Only local training is supported now, federal training will be supported " \ | |||||
"in the future. But got {}.".format(policy) | |||||
LOGGER.error(TAG, msg) | |||||
raise ValueError(msg) | |||||
class SuppressCtrl(Cell): | |||||
""" | |||||
Args: | |||||
networks (Cell): The training network. | |||||
end_epoch (int): The last epoch in suppress operations. | |||||
batch_num (int): The num of grad operation in an epoch. | |||||
mask_start_epoch (int): The first epoch in suppress operations. | |||||
mask_times (int): The num of suppress operations. | |||||
lr (Union[float, int]): Learning rate. | |||||
sparse_end (Union[float, int]): The sparsity to reach. | |||||
sparse_start (float): The sparsity to start. | |||||
mask_layers (list): Description of those layers that need to be suppressed. | |||||
""" | |||||
def __init__(self, networks, end_epoch, batch_num, mask_start_epoch=3, mask_times=500, lr=0.05, | |||||
sparse_end=0.60, | |||||
sparse_start=0.0, | |||||
mask_layers=None): | |||||
super(SuppressCtrl, self).__init__() | |||||
self.networks = check_param_type('networks', networks, Cell) | |||||
self.mask_end_epoch = check_int_positive('end_epoch', end_epoch) | |||||
self.batch_num = check_int_positive('batch_num', batch_num) | |||||
self.mask_start_epoch = check_int_positive('mask_start_epoch', mask_start_epoch) | |||||
self.mask_times = check_int_positive('mask_times', mask_times) | |||||
self.lr = check_value_positive('lr', lr) | |||||
self.sparse_end = check_value_non_negative('sparse_end', sparse_end) | |||||
self.sparse_start = check_value_non_negative('sparse_start', sparse_start) | |||||
self.mask_layers = check_param_type('mask_layers', mask_layers, list) | |||||
self.weight_lower_bound = 0.005 # all network weight will be larger than this value | |||||
self.sparse_vibra = 0.02 # the sparsity may have certain range of variations | |||||
self.sparse_valid_max_weight = 0.20 # if max network weight is less than this value, suppress operation stop temporarily | |||||
self.add_noise_thd = 0.50 # if network weight is more than this value, noise is forced | |||||
self.noise_volume = 0.01 # noise volume 0.01 | |||||
self.base_ground_thd = 0.0000001 # if network weight is less than this value, will be considered as 0 | |||||
self.model = None # SuppressModel instance | |||||
self.grads_mask_list = [] # list for Grad Mask Matrix tensor | |||||
self.de_weight_mask_list = [] # list for weight Mask Matrix tensor | |||||
self.to_do_mask = False # the flag means suppress operation is toggled immediately | |||||
self.mask_started = False # the flag means suppress operation has been started | |||||
self.mask_start_step = 0 # suppress operation is actually started at this step | |||||
self.mask_prev_step = 0 # previous suppress operation is done at this step | |||||
self.cur_sparse = 0.0 # current sparsity to which one suppress will get | |||||
self.mask_all_steps = (self.mask_end_epoch-mask_start_epoch+1)*batch_num # the amount of step contained in all suppress operation | |||||
self.mask_step_interval = self.mask_all_steps/mask_times # the amount of step contaied in one suppress operation | |||||
self.mask_initialized = False # flag means the initialization is done | |||||
if self.mask_start_epoch > self.mask_end_epoch: | |||||
msg = "start_epoch error: {}".format(self.mask_start_epoch) | |||||
LOGGER.error(TAG, msg) | |||||
raise ValueError(msg) | |||||
if self.mask_end_epoch > 100: | |||||
msg = "end_epoch error: {}".format(self.mask_end_epoch) | |||||
LOGGER.error(TAG, msg) | |||||
raise ValueError(msg) | |||||
if self.mask_step_interval < 0: | |||||
msg = "step_interval error: {}".format(self.mask_step_interval) | |||||
LOGGER.error(TAG, msg) | |||||
raise ValueError(msg) | |||||
if self.sparse_end > 1.00 or self.sparse_end <= 0: | |||||
msg = "sparse_end error: {}".format(self.sparse_end) | |||||
LOGGER.error(TAG, msg) | |||||
raise ValueError(msg) | |||||
if self.sparse_start >= self.sparse_end: | |||||
msg = "sparse_start error: {}".format(self.sparse_start) | |||||
LOGGER.error(TAG, msg) | |||||
raise ValueError(msg) | |||||
if mask_layers is not None: | |||||
mask_layer_id = 0 | |||||
for one_mask_layer in mask_layers: | |||||
if not isinstance(one_mask_layer, MaskLayerDes): | |||||
msg = "mask_layer instance error!" | |||||
LOGGER.error(TAG, msg) | |||||
raise ValueError(msg) | |||||
layer_name = one_mask_layer.layer_name | |||||
mask_layer_id2 = 0 | |||||
for one_mask_layer_2 in mask_layers: | |||||
if mask_layer_id != mask_layer_id2 and layer_name in one_mask_layer_2.layer_name: | |||||
msg = "mask_layers repeat item : {} in {} and {}".format(layer_name, | |||||
mask_layer_id, | |||||
mask_layer_id2) | |||||
LOGGER.error(TAG, msg) | |||||
raise ValueError(msg) | |||||
mask_layer_id2 = mask_layer_id2 + 1 | |||||
mask_layer_id = mask_layer_id + 1 | |||||
if networks is not None: | |||||
m = 0 | |||||
for layer in networks.get_parameters(expand=True): | |||||
one_mask_layer = None | |||||
if mask_layers is not None: | |||||
one_mask_layer = get_one_mask_layer(mask_layers, layer.name) | |||||
if one_mask_layer is not None and not one_mask_layer.inited: | |||||
one_mask_layer.inited = True | |||||
shape = P.Shape()(layer) | |||||
mul_mask_array = np.ones(shape, dtype=np.float32) | |||||
grad_mask_cell = GradMaskInCell(mul_mask_array, | |||||
one_mask_layer.is_add_noise, | |||||
one_mask_layer.is_lower_clip, | |||||
one_mask_layer.min_num, | |||||
one_mask_layer.upper_bound) | |||||
grad_mask_cell.mask_able = True | |||||
self.grads_mask_list.append(grad_mask_cell) | |||||
add_mask_array = np.zeros(shape, dtype=np.float32) | |||||
de_weight_cell = DeWeightInCell(add_mask_array) | |||||
de_weight_cell.mask_able = True | |||||
self.de_weight_mask_list.append(de_weight_cell) | |||||
msg = "do mask {}, {}".format(m, one_mask_layer.layer_name) | |||||
LOGGER.info(TAG, msg) | |||||
elif one_mask_layer is not None and one_mask_layer.inited: | |||||
msg = "repeated match masked setting {}=>{}.".format(one_mask_layer.layer_name, layer.name) | |||||
LOGGER.error(TAG, msg) | |||||
raise ValueError(msg) | |||||
else: | |||||
shape = np.shape([1]) | |||||
mul_mask_array = np.ones(shape, dtype=np.float32) | |||||
grad_mask_cell = GradMaskInCell(mul_mask_array, False, False, -1) | |||||
grad_mask_cell.mask_able = False | |||||
self.grads_mask_list.append(grad_mask_cell) | |||||
add_mask_array = np.zeros(shape, dtype=np.float32) | |||||
de_weight_cell = DeWeightInCell(add_mask_array) | |||||
de_weight_cell.mask_able = False | |||||
self.de_weight_mask_list.append(de_weight_cell) | |||||
m += 1 | |||||
self.mask_initialized = True | |||||
msg = "init SuppressCtrl by networks" | |||||
LOGGER.info(TAG, msg) | |||||
msg = "complete init mask for lenet5.step_interval: {}".format(self.mask_step_interval) | |||||
LOGGER.info(TAG, msg) | |||||
for one_mask_layer in mask_layers: | |||||
if not one_mask_layer.inited: | |||||
msg = "can't match this mask layer: {} ".format(one_mask_layer.layer_name) | |||||
LOGGER.error(TAG, msg) | |||||
raise ValueError(msg) | |||||
def update_status(self, cur_epoch, cur_step, cur_step_in_epoch): | |||||
""" | |||||
Update the suppress operation status. | |||||
Args: | |||||
cur_epoch (int): Current epoch of the whole training process. | |||||
cur_step (int): Current step of the whole training process. | |||||
cur_step_in_epoch (int): Current step of the current epoch. | |||||
""" | |||||
if not self.mask_initialized: | |||||
self.mask_started = False | |||||
elif (self.mask_start_epoch <= cur_epoch <= self.mask_end_epoch) or self.mask_started: | |||||
if not self.mask_started: | |||||
self.mask_started = True | |||||
self.mask_start_step = cur_step | |||||
if cur_step >= (self.mask_prev_step + self.mask_step_interval): | |||||
self.mask_prev_step = cur_step | |||||
self.to_do_mask = True | |||||
# execute the last suppression operation | |||||
elif cur_epoch == self.mask_end_epoch and cur_step_in_epoch == self.batch_num-2: | |||||
self.mask_prev_step = cur_step | |||||
self.to_do_mask = True | |||||
else: | |||||
self.to_do_mask = False | |||||
else: | |||||
self.to_do_mask = False | |||||
self.mask_started = False | |||||
def update_mask(self, networks, cur_step): | |||||
""" | |||||
Update add mask arrays and multiply mask arrays of network layers. | |||||
Args: | |||||
networks (Cell): The training network. | |||||
cur_step (int): Current epoch of the whole training process. | |||||
""" | |||||
if self.sparse_end <= 0.0: | |||||
return | |||||
self.cur_sparse = self.sparse_end +\ | |||||
(self.sparse_start - self.sparse_end)*\ | |||||
math.pow((1.0 - (cur_step + 0.0 - self.mask_start_step) / self.mask_all_steps), 3) | |||||
m = 0 | |||||
for layer in networks.get_parameters(expand=True): | |||||
if self.grads_mask_list[m].mask_able: | |||||
weight_array = layer.data.asnumpy() | |||||
weight_avg = np.mean(weight_array) | |||||
weight_array_flat = weight_array.flatten() | |||||
weight_array_flat_abs = np.abs(weight_array_flat) | |||||
weight_abs_avg = np.mean(weight_array_flat_abs) | |||||
weight_array_flat_abs.sort() | |||||
len_array = weight_array.size | |||||
weight_abs_max = np.max(weight_array_flat_abs) | |||||
if m == 0 and weight_abs_max < self.sparse_valid_max_weight: | |||||
msg = "give up this masking .." | |||||
LOGGER.info(TAG, msg) | |||||
return | |||||
if self.grads_mask_list[m].min_num > 0: | |||||
sparse_weight_thd, _, actual_stop_pos = self.calc_sparse_thd(weight_array_flat_abs, | |||||
self.cur_sparse, m) | |||||
else: | |||||
actual_stop_pos = int(len_array * self.cur_sparse) | |||||
sparse_weight_thd = weight_array_flat_abs[actual_stop_pos] | |||||
self.update_mask_layer(weight_array_flat, sparse_weight_thd, actual_stop_pos, weight_abs_max, m) | |||||
msg = "{} len={}, sparse={}, current sparse thd={}, max={}, avg={}, avg_abs={} \n".format( | |||||
layer.name, len_array, actual_stop_pos/len_array, sparse_weight_thd, | |||||
weight_abs_max, weight_avg, weight_abs_avg) | |||||
LOGGER.info(TAG, msg) | |||||
m = m + 1 | |||||
def update_mask_layer(self, weight_array_flat, sparse_weight_thd, sparse_stop_pos, weight_abs_max, layer_index): | |||||
""" | |||||
Update add mask arrays and multiply mask arrays of one single layer. | |||||
Args: | |||||
weight_array (numpy.ndarray): The weight array of layer's parameters. | |||||
sparse_weight_thd (float): The weight threshold of sparse operation. | |||||
sparse_stop_pos (int): The maximum number of elements to be suppressed. | |||||
weight_abs_max (float): The maximum absolute value of weights. | |||||
layer_index (int): The index of target layer. | |||||
""" | |||||
grad_mask_cell = self.grads_mask_list[layer_index] | |||||
mul_mask_array_flat = grad_mask_cell.mul_mask_array_flat | |||||
de_weight_cell = self.de_weight_mask_list[layer_index] | |||||
add_mask_array_flat = de_weight_cell.add_mask_array_flat | |||||
min_num = grad_mask_cell.min_num | |||||
is_add_noise = grad_mask_cell.is_add_noise | |||||
is_lower_clip = grad_mask_cell.is_lower_clip | |||||
upper_bound = grad_mask_cell.upper_bound | |||||
if not self.grads_mask_list[layer_index].mask_able: | |||||
return | |||||
m = 0 | |||||
n = 0 | |||||
p = 0 | |||||
q = 0 | |||||
# add noise on weights if not masking or clipping. | |||||
weight_noise_bound = min(self.add_noise_thd, max(self.noise_volume*10, weight_abs_max*0.75)) | |||||
for i in range(0, weight_array_flat.size): | |||||
if abs(weight_array_flat[i]) <= sparse_weight_thd: | |||||
if m < weight_array_flat.size - min_num and m < sparse_stop_pos: | |||||
# to mask | |||||
mul_mask_array_flat[i] = 0.0 | |||||
add_mask_array_flat[i] = weight_array_flat[i] / self.lr | |||||
m = m + 1 | |||||
else: | |||||
# not mask | |||||
if weight_array_flat[i] > 0.0: | |||||
add_mask_array_flat[i] = (weight_array_flat[i] - self.weight_lower_bound) / self.lr | |||||
else: | |||||
add_mask_array_flat[i] = (weight_array_flat[i] + self.weight_lower_bound) / self.lr | |||||
p = p + 1 | |||||
elif is_lower_clip and abs(weight_array_flat[i]) <= \ | |||||
self.weight_lower_bound and sparse_weight_thd > self.weight_lower_bound*0.5: | |||||
# not mask | |||||
mul_mask_array_flat[i] = 1.0 | |||||
if weight_array_flat[i] > 0.0: | |||||
add_mask_array_flat[i] = (weight_array_flat[i] - self.weight_lower_bound) / self.lr | |||||
else: | |||||
add_mask_array_flat[i] = (weight_array_flat[i] + self.weight_lower_bound) / self.lr | |||||
p = p + 1 | |||||
elif abs(weight_array_flat[i]) > upper_bound: | |||||
mul_mask_array_flat[i] = 1.0 | |||||
if weight_array_flat[i] > 0.0: | |||||
add_mask_array_flat[i] = (weight_array_flat[i] - upper_bound) / self.lr | |||||
else: | |||||
add_mask_array_flat[i] = (weight_array_flat[i] + upper_bound) / self.lr | |||||
n = n + 1 | |||||
else: | |||||
# not mask | |||||
mul_mask_array_flat[i] = 1.0 | |||||
if is_add_noise and abs(weight_array_flat[i]) > weight_noise_bound > 0.0: | |||||
# add noise | |||||
add_mask_array_flat[i] = np.random.uniform(-self.noise_volume, self.noise_volume) / self.lr | |||||
q = q + 1 | |||||
else: | |||||
add_mask_array_flat[i] = 0.0 | |||||
grad_mask_cell.update() | |||||
de_weight_cell.update() | |||||
msg = "Dimension of mask tensor is {}D, which located in the {}-th layer of the network. \n The number of " \ | |||||
"suppressed elements, max-clip elements, min-clip elements and noised elements are {}, {}, {}, {}"\ | |||||
.format(len(grad_mask_cell.mul_mask_array_shape), layer_index, m, n, p, q) | |||||
LOGGER.info(TAG, msg) | |||||
def calc_sparse_thd(self, array_flat, sparse_value, layer_index): | |||||
""" | |||||
Calculate the suppression threshold of one weight array. | |||||
Args: | |||||
array_flat (numpy.ndarray): The flattened weight array. | |||||
sparse_value (float): The target sparse value of weight array. | |||||
Returns: | |||||
- float, the sparse threshold of this array. | |||||
- int, the number of weight elements to be suppressed. | |||||
- int, the larger number of weight elements to be suppressed. | |||||
""" | |||||
size = len(array_flat) | |||||
sparse_max_thd = 1.0 - min(self.grads_mask_list[layer_index].min_num, size) / size | |||||
pos = int(size*min(sparse_max_thd, sparse_value)) | |||||
thd = array_flat[pos] | |||||
farther_stop_pos = int(size*min(sparse_max_thd, max(0, sparse_value + self.sparse_vibra / 2.0))) | |||||
return thd, pos, farther_stop_pos | |||||
def reset_zeros(self): | |||||
""" | |||||
Set add mask arrays to be zero. | |||||
""" | |||||
for de_weight_cell in self.de_weight_mask_list: | |||||
de_weight_cell.reset_zeros() | |||||
def calc_theoretical_sparse_for_conv(self): | |||||
""" | |||||
Compute actually sparsity of mask matrix for conv1 layer and conv2 layer. | |||||
""" | |||||
array_mul_mask_flat_conv1 = self.grads_mask_list[0].mul_mask_array_flat | |||||
array_mul_mask_flat_conv2 = self.grads_mask_list[1].mul_mask_array_flat | |||||
sparse = 0.0 | |||||
sparse_value_1 = 0.0 | |||||
sparse_value_2 = 0.0 | |||||
full = 0.0 | |||||
full_conv1 = 0.0 | |||||
full_conv2 = 0.0 | |||||
for i in range(0, array_mul_mask_flat_conv1.size): | |||||
full += 1.0 | |||||
full_conv1 += 1.0 | |||||
if array_mul_mask_flat_conv1[i] <= 0.0: | |||||
sparse += 1.0 | |||||
sparse_value_1 += 1.0 | |||||
for i in range(0, array_mul_mask_flat_conv2.size): | |||||
full = full + 1.0 | |||||
full_conv2 = full_conv2 + 1.0 | |||||
if array_mul_mask_flat_conv2[i] <= 0.0: | |||||
sparse = sparse + 1.0 | |||||
sparse_value_2 += 1.0 | |||||
sparse = sparse/full | |||||
sparse_value_1 = sparse_value_1/full_conv1 | |||||
sparse_value_2 = sparse_value_2/full_conv2 | |||||
msg = "conv sparse mask={}, sparse_1={}, sparse_2={}".format(sparse, sparse_value_1, sparse_value_2) | |||||
LOGGER.info(TAG, msg) | |||||
return sparse, sparse_value_1, sparse_value_2 | |||||
def calc_actual_sparse_for_conv(self, networks): | |||||
""" | |||||
Compute actually sparsity of network for conv1 layer and conv2 layer. | |||||
Args: | |||||
networks (Cell): The training network. | |||||
""" | |||||
sparse = 0.0 | |||||
sparse_value_1 = 0.0 | |||||
sparse_value_2 = 0.0 | |||||
full = 0.0 | |||||
full_conv1 = 0.0 | |||||
full_conv2 = 0.0 | |||||
array_cur_conv1 = np.ones(np.shape([1]), dtype=np.float32) | |||||
array_cur_conv2 = np.ones(np.shape([1]), dtype=np.float32) | |||||
for layer in networks.get_parameters(expand=True): | |||||
if "conv1.weight" in layer.name: | |||||
array_cur_conv1 = layer.data.asnumpy() | |||||
if "conv2.weight" in layer.name: | |||||
array_cur_conv2 = layer.data.asnumpy() | |||||
array_mul_mask_flat_conv1 = array_cur_conv1.flatten() | |||||
array_mul_mask_flat_conv2 = array_cur_conv2.flatten() | |||||
for i in range(0, array_mul_mask_flat_conv1.size): | |||||
full += 1.0 | |||||
full_conv1 += 1.0 | |||||
if abs(array_mul_mask_flat_conv1[i]) <= self.base_ground_thd: | |||||
sparse += 1.0 | |||||
sparse_value_1 += 1.0 | |||||
for i in range(0, array_mul_mask_flat_conv2.size): | |||||
full = full + 1.0 | |||||
full_conv2 = full_conv2 + 1.0 | |||||
if abs(array_mul_mask_flat_conv2[i]) <= self.base_ground_thd: | |||||
sparse = sparse + 1.0 | |||||
sparse_value_2 += 1.0 | |||||
sparse = sparse / full | |||||
sparse_value_1 = sparse_value_1 / full_conv1 | |||||
sparse_value_2 = sparse_value_2 / full_conv2 | |||||
msg = "conv sparse fact={}, sparse_1={}, sparse_2={}".format(sparse, sparse_value_1, sparse_value_2) | |||||
LOGGER.info(TAG, msg) | |||||
return sparse, sparse_value_1, sparse_value_2 | |||||
def calc_actual_sparse_for_fc1(self, networks): | |||||
self.calc_actual_sparse_for_layer(networks, "fc1.weight") | |||||
def calc_actual_sparse_for_layer(self, networks, layer_name): | |||||
""" | |||||
Compute actually sparsity of one network layer | |||||
Args: | |||||
networks (Cell): The training network. | |||||
layer_name (str): The name of target layer. | |||||
""" | |||||
check_param_type('networks', networks, Cell) | |||||
check_param_type('layer_name', layer_name, str) | |||||
sparse = 0.0 | |||||
full = 0.0 | |||||
array_cur = None | |||||
for layer in networks.get_parameters(expand=True): | |||||
if layer_name in layer.name: | |||||
array_cur = layer.data.asnumpy() | |||||
if array_cur is None: | |||||
msg = "no such layer to calc sparse: {} ".format(layer_name) | |||||
LOGGER.info(TAG, msg) | |||||
return | |||||
array_cur_flat = array_cur.flatten() | |||||
for i in range(0, array_cur_flat.size): | |||||
full += 1.0 | |||||
if abs(array_cur_flat[i]) <= self.base_ground_thd: | |||||
sparse += 1.0 | |||||
sparse = sparse / full | |||||
msg = "{} sparse fact={} ".format(layer_name, sparse) | |||||
LOGGER.info(TAG, msg) | |||||
def get_one_mask_layer(mask_layers, layer_name): | |||||
""" | |||||
Returns the layer definitions that need to be suppressed. | |||||
Args: | |||||
mask_layers (list): The layers that need to be suppressed. | |||||
layer_name (str): The name of target layer. | |||||
Returns: | |||||
Union[MaskLayerDes, None], the layer definitions that need to be suppressed. | |||||
""" | |||||
for each_mask_layer in mask_layers: | |||||
if each_mask_layer.layer_name in layer_name: | |||||
return each_mask_layer | |||||
return None | |||||
class MaskLayerDes: | |||||
""" | |||||
Describe the layer that need to be suppressed. | |||||
Args: | |||||
layer_name (str): Layer name, get the name of one layer as following: | |||||
for layer in networks.get_parameters(expand=True): | |||||
if layer.name == "conv": ... | |||||
is_add_noise (bool): If True, the weight of this layer can add noise. | |||||
If False, the weight of this layer can not add noise. | |||||
is_lower_clip (bool): If true, the weights of this layer would be clipped to greater than an lower bound value. | |||||
If False, the weights of this layer won't be clipped. | |||||
min_num (int): The number of weights left that not be suppressed, which need to be greater than 0. | |||||
upper_bound (float): max value of weight in this layer, default value is 1.20 . | |||||
""" | |||||
def __init__(self, layer_name, is_add_noise, is_lower_clip, min_num, upper_bound=1.20): | |||||
self.layer_name = check_param_type('layer_name', layer_name, str) | |||||
self.is_add_noise = check_param_type('is_add_noise', is_add_noise, bool) | |||||
self.is_lower_clip = check_param_type('is_lower_clip', is_lower_clip, bool) | |||||
self.min_num = check_param_type('min_num', min_num, int) | |||||
self.upper_bound = check_value_positive('upper_bound', upper_bound) | |||||
self.inited = False | |||||
class GradMaskInCell(Cell): | |||||
""" | |||||
Define the mask matrix for gradients masking. | |||||
Args: | |||||
array (numpy.ndarray): The mask array. | |||||
is_add_noise (bool): If True, the weight of this layer can add noise. | |||||
If False, the weight of this layer can not add noise. | |||||
is_lower_clip (bool): If true, the weights of this layer would be clipped to greater than an lower bound value. | |||||
If False, the weights of this layer won't be clipped. | |||||
min_num (int): The number of weights left that not be suppressed, which need to be greater than 0. | |||||
upper_bound (float): max value of weight in this layer, default value is 1.20 | |||||
""" | |||||
def __init__(self, array, is_add_noise, is_lower_clip, min_num, upper_bound=1.20): | |||||
super(GradMaskInCell, self).__init__() | |||||
self.mul_mask_array_shape = array.shape | |||||
mul_mask_array = array.copy() | |||||
self.mul_mask_array_flat = mul_mask_array.flatten() | |||||
self.mul_mask_tensor = Tensor(array, mstype.float32) | |||||
self.mask_able = False | |||||
self.is_add_noise = is_add_noise | |||||
self.is_lower_clip = is_lower_clip | |||||
self.min_num = min_num | |||||
self.upper_bound = check_value_positive('upper_bound', upper_bound) | |||||
def construct(self): | |||||
""" | |||||
Return the mask matrix for optimization. | |||||
""" | |||||
return self.mask_able, self.mul_mask_tensor | |||||
def update(self): | |||||
""" | |||||
Update the mask tensor. | |||||
""" | |||||
self.mul_mask_tensor = Tensor(self.mul_mask_array_flat.reshape(self.mul_mask_array_shape), mstype.float32) | |||||
class DeWeightInCell(Cell): | |||||
""" | |||||
Define the mask matrix for de-weight masking. | |||||
Args: | |||||
array (numpy.ndarray): The mask array. | |||||
""" | |||||
def __init__(self, array): | |||||
super(DeWeightInCell, self).__init__() | |||||
self.add_mask_array_shape = array.shape | |||||
add_mask_array = array.copy() | |||||
self.add_mask_array_flat = add_mask_array.flatten() | |||||
self.add_mask_tensor = Tensor(array, mstype.float32) | |||||
self.mask_able = False | |||||
self.zero_mask_tensor = Tensor(np.zeros(array.shape, np.float32), mstype.float32) | |||||
self.just_update = -1.0 | |||||
def construct(self): | |||||
""" | |||||
Return the mask matrix for optimization. | |||||
""" | |||||
if self.just_update > 0.0: | |||||
return self.mask_able, self.add_mask_tensor | |||||
return self.mask_able, self.zero_mask_tensor | |||||
def update(self): | |||||
""" | |||||
Update the mask tensor. | |||||
""" | |||||
self.just_update = 1.0 | |||||
self.add_mask_tensor = Tensor(self.add_mask_array_flat.reshape(self.add_mask_array_shape), mstype.float32) | |||||
def reset_zeros(self): | |||||
""" | |||||
Make the de-weight operation expired. | |||||
""" | |||||
self.just_update = -1.0 |
@@ -0,0 +1,325 @@ | |||||
# Copyright 2021 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
""" | |||||
suppress-basd privacy model. | |||||
""" | |||||
from easydict import EasyDict as edict | |||||
from mindspore.train.model import Model | |||||
from mindspore._checkparam import Validator as validator | |||||
from mindspore._checkparam import Rel | |||||
from mindspore.train.amp import _config_level | |||||
from mindspore.common import dtype as mstype | |||||
from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell | |||||
from mindspore.parallel._utils import _get_parallel_mode | |||||
from mindspore.train.model import ParallelMode | |||||
from mindspore.train.amp import _do_keep_batchnorm_fp32 | |||||
from mindspore.train.amp import _add_loss_network | |||||
from mindspore import nn | |||||
from mindspore import context | |||||
from mindspore.ops import composite as C | |||||
from mindspore.ops import operations as P | |||||
from mindspore.ops import functional as F | |||||
from mindspore.parallel._utils import _get_gradients_mean | |||||
from mindspore.parallel._utils import _get_device_num | |||||
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer | |||||
from mindspore.nn import Cell | |||||
from mindarmour.utils._check_param import check_param_type | |||||
from mindarmour.utils.logger import LogUtil | |||||
from mindarmour.privacy.sup_privacy.sup_ctrl.conctrl import SuppressCtrl | |||||
LOGGER = LogUtil.get_instance() | |||||
TAG = 'Mask model' | |||||
GRADIENT_CLIP_TYPE = 1 | |||||
_grad_scale = C.MultitypeFuncGraph("grad_scale") | |||||
_reciprocal = P.Reciprocal() | |||||
@_grad_scale.register("Tensor", "Tensor") | |||||
def tensor_grad_scale(scale, grad): | |||||
""" grad scaling """ | |||||
return grad*F.cast(_reciprocal(scale), F.dtype(grad)) | |||||
class SuppressModel(Model): | |||||
""" | |||||
This class is overload mindspore.train.model.Model. | |||||
Args: | |||||
network (Cell): The training network. | |||||
loss_fn (Cell): Computes softmax cross entropy between logits and labels. | |||||
optimizer (Optimizer): optimizer instance. | |||||
metrics (Union[dict, set]): Calculates the accuracy for classification and multilabel data. | |||||
kwargs: Keyword parameters used for creating a suppress model. | |||||
Examples: | |||||
networks_l5 = LeNet5() | |||||
masklayers = [] | |||||
masklayers.append(MaskLayerDes("conv1.weight", 0, False, True, 10)) | |||||
suppress_ctrl_instance = SuppressPrivacyFactory().create(policy="local_train", | |||||
end_epoch=10, | |||||
batch_num=(int)(10000/cfg.batch_size), | |||||
start_epoch=3, | |||||
mask_times=100, | |||||
networks=networks_l5, | |||||
lr=lr, | |||||
sparse_end=0.90, | |||||
sparse_start=0.0, | |||||
mask_layers=masklayers) | |||||
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") | |||||
net_opt = nn.Momentum(params=networks_l5.trainable_params(), learning_rate=lr, momentum=0.0) | |||||
config_ck = CheckpointConfig(save_checkpoint_steps=(int)(samples/cfg.batch_size), keep_checkpoint_max=10) | |||||
model_instance = SuppressModel(network=networks_l5, | |||||
loss_fn=net_loss, | |||||
optimizer=net_opt, | |||||
metrics={"Accuracy": Accuracy()}) | |||||
model_instance.link_suppress_ctrl(suppress_ctrl_instance) | |||||
ds_train = generate_mnist_dataset("./MNIST_unzip/train", | |||||
batch_size=cfg.batch_size, repeat_size=1, samples=samples) | |||||
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", | |||||
directory="./trained_ckpt_file/", | |||||
config=config_ck) | |||||
model_instance.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(), suppress_masker], | |||||
dataset_sink_mode=False) | |||||
""" | |||||
def __init__(self, | |||||
network=None, | |||||
**kwargs): | |||||
check_param_type('networks', network, Cell) | |||||
self.network_end = None | |||||
self._train_one_step = None | |||||
super(SuppressModel, self).__init__(network, **kwargs) | |||||
def link_suppress_ctrl(self, suppress_pri_ctrl): | |||||
""" | |||||
Link self and SuppressCtrl instance. | |||||
Args: | |||||
suppress_pri_ctrl (SuppressCtrl): SuppressCtrl instance. | |||||
""" | |||||
check_param_type('suppress_pri_ctrl', suppress_pri_ctrl, Cell) | |||||
if not isinstance(suppress_pri_ctrl, SuppressCtrl): | |||||
msg = "SuppressCtrl instance error!" | |||||
LOGGER.error(TAG, msg) | |||||
raise ValueError(msg) | |||||
suppress_pri_ctrl.model = self | |||||
if self._train_one_step is not None: | |||||
self._train_one_step.link_suppress_ctrl(suppress_pri_ctrl) | |||||
def _build_train_network(self): | |||||
"""Build train network""" | |||||
network = self._network | |||||
ms_mode = context.get_context("mode") | |||||
if ms_mode != context.PYNATIVE_MODE: | |||||
raise ValueError("Only PYNATIVE_MODE is supported for suppress privacy now.") | |||||
if self._optimizer: | |||||
network = self._amp_build_train_network(network, | |||||
self._optimizer, | |||||
self._loss_fn, | |||||
level=self._amp_level, | |||||
keep_batchnorm_fp32=self._keep_bn_fp32) | |||||
else: | |||||
raise ValueError("_optimizer is none") | |||||
self._train_one_step = network | |||||
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, | |||||
ParallelMode.AUTO_PARALLEL): | |||||
network.set_auto_parallel() | |||||
self.network_end = self._train_one_step.network | |||||
return network | |||||
def _amp_build_train_network(self, network, optimizer, loss_fn=None, | |||||
level='O0', **kwargs): | |||||
""" | |||||
Build the mixed precision training cell automatically. | |||||
Args: | |||||
network (Cell): Definition of the network. | |||||
loss_fn (Union[None, Cell]): Definition of the loss_fn. If None, | |||||
the `network` should have the loss inside. Default: None. | |||||
optimizer (Optimizer): Optimizer to update the Parameter. | |||||
level (str): Supports [O0, O2]. Default: "O0". | |||||
- O0: Do not change. | |||||
- O2: Cast network to float16, keep batchnorm and `loss_fn` | |||||
(if set) run in float32, using dynamic loss scale. | |||||
cast_model_type (:class:`mindspore.dtype`): Supports `mstype.float16` | |||||
or `mstype.float32`. If set to `mstype.float16`, use `float16` | |||||
mode to train. If set, overwrite the level setting. | |||||
keep_batchnorm_fp32 (bool): Keep Batchnorm run in `float32`. If set, | |||||
overwrite the level setting. | |||||
loss_scale_manager (Union[None, LossScaleManager]): If None, not | |||||
scale the loss, or else scale the loss by LossScaleManager. | |||||
If set, overwrite the level setting. | |||||
""" | |||||
validator.check_value_type('network', network, nn.Cell, None) | |||||
validator.check_value_type('optimizer', optimizer, nn.Optimizer, None) | |||||
validator.check('level', level, "", ['O0', 'O2'], Rel.IN, None) | |||||
self._check_kwargs(kwargs) | |||||
config = dict(_config_level[level], **kwargs) | |||||
config = edict(config) | |||||
if config.cast_model_type == mstype.float16: | |||||
network.to_float(mstype.float16) | |||||
if config.keep_batchnorm_fp32: | |||||
_do_keep_batchnorm_fp32(network) | |||||
if loss_fn: | |||||
network = _add_loss_network(network, loss_fn, | |||||
config.cast_model_type) | |||||
if _get_parallel_mode() in ( | |||||
ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): | |||||
network = _VirtualDatasetCell(network) | |||||
loss_scale = 1.0 | |||||
if config.loss_scale_manager is not None: | |||||
print("----model config have loss scale manager !") | |||||
network = TrainOneStepCell(network, optimizer, sens=loss_scale).set_train() | |||||
return network | |||||
class _TupleAdd(nn.Cell): | |||||
""" | |||||
Add two tuple of data. | |||||
""" | |||||
def __init__(self): | |||||
super(_TupleAdd, self).__init__() | |||||
self.add = P.TensorAdd() | |||||
self.hyper_map = C.HyperMap() | |||||
def construct(self, input1, input2): | |||||
"""Add two tuple of data.""" | |||||
out = self.hyper_map(self.add, input1, input2) | |||||
return out | |||||
class _TupleMul(nn.Cell): | |||||
""" | |||||
Mul two tuple of data. | |||||
""" | |||||
def __init__(self): | |||||
super(_TupleMul, self).__init__() | |||||
self.mul = P.Mul() | |||||
self.hyper_map = C.HyperMap() | |||||
def construct(self, input1, input2): | |||||
"""Add two tuple of data.""" | |||||
out = self.hyper_map(self.mul, input1, input2) | |||||
#print(out) | |||||
return out | |||||
# come from nn.cell_wrapper.TrainOneStepCell | |||||
class TrainOneStepCell(Cell): | |||||
r""" | |||||
Network training package class. | |||||
Wraps the network with an optimizer. The resulting Cell be trained with input data and label. | |||||
Backward graph will be created in the construct function to do parameter updating. Different | |||||
parallel modes are available to run the training. | |||||
Args: | |||||
network (Cell): The training network. | |||||
optimizer (Cell): Optimizer for updating the weights. | |||||
sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0. | |||||
Inputs: | |||||
- **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`. | |||||
- **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`. | |||||
Outputs: | |||||
Tensor, a scalar Tensor with shape :math:`()`. | |||||
""" | |||||
def __init__(self, network, optimizer, sens=1.0): | |||||
super(TrainOneStepCell, self).__init__(auto_prefix=False) | |||||
self.network = network | |||||
self.network.set_grad() | |||||
self.network.add_flags(defer_inline=True) | |||||
self.weights = optimizer.parameters | |||||
self.optimizer = optimizer | |||||
self.grad = C.GradOperation(get_by_list=True, sens_param=True) # for mindspore 0.7x | |||||
self.sens = sens | |||||
self.reducer_flag = False | |||||
self.grad_reducer = None | |||||
self._tuple_add = _TupleAdd() | |||||
self._tuple_mul = _TupleMul() | |||||
parallel_mode = _get_parallel_mode() | |||||
if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL): | |||||
self.reducer_flag = True | |||||
if self.reducer_flag: | |||||
mean = _get_gradients_mean() # for mindspore 0.7x | |||||
degree = _get_device_num() | |||||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) | |||||
self.do_privacy = False | |||||
self.grad_mask_tup = () # tuple containing grad_mask(cell) | |||||
self.de_weight_tup = () # tuple containing de_weight(cell) | |||||
self._suppress_pri_ctrl = None | |||||
def link_suppress_ctrl(self, suppress_pri_ctrl): | |||||
""" | |||||
Set Suppress Mask for grad_mask_tup and de_weight_tup. | |||||
Args: | |||||
suppress_pri_ctrl (SuppressCtrl): SuppressCtrl instance. | |||||
""" | |||||
self._suppress_pri_ctrl = suppress_pri_ctrl | |||||
if self._suppress_pri_ctrl.grads_mask_list: | |||||
for grad_mask_cell in self._suppress_pri_ctrl.grads_mask_list: | |||||
self.grad_mask_tup += (grad_mask_cell,) | |||||
self.do_privacy = True | |||||
for de_weight_cell in self._suppress_pri_ctrl.de_weight_mask_list: | |||||
self.de_weight_tup += (de_weight_cell,) | |||||
else: | |||||
self.do_privacy = False | |||||
def construct(self, data, label): | |||||
""" | |||||
Construct a compute flow. | |||||
""" | |||||
weights = self.weights | |||||
loss = self.network(data, label) | |||||
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) | |||||
grads = self.grad(self.network, weights)(data, label, sens) | |||||
new_grads = () | |||||
m = 0 | |||||
for grad in grads: | |||||
if self.do_privacy and self._suppress_pri_ctrl.mask_started: | |||||
enable_mask, grad_mask = self.grad_mask_tup[m]() | |||||
enable_de_weight, de_weight_array = self.de_weight_tup[m]() | |||||
if enable_mask and enable_de_weight: | |||||
grad_n = self._tuple_add(de_weight_array, self._tuple_mul(grad, grad_mask)) | |||||
new_grads = new_grads + (grad_n,) | |||||
else: | |||||
new_grads = new_grads + (grad,) | |||||
else: | |||||
new_grads = new_grads + (grad,) | |||||
m = m + 1 | |||||
if self.reducer_flag: | |||||
new_grads = self.grad_reducer(new_grads) | |||||
return F.depend(loss, self.optimizer(new_grads)) |
@@ -1,4 +1,4 @@ | |||||
# Copyright 2020 Huawei Technologies Co., Ltd | |||||
# Copyright 2021 Huawei Technologies Co., Ltd | |||||
# | # | ||||
# Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
# you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
@@ -12,6 +12,6 @@ | |||||
# See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
# limitations under the License. | # limitations under the License. | ||||
""" | """ | ||||
This package includes unit tests for differential-privacy training and | |||||
privacy breach estimation. | |||||
This package includes unit tests for differential-privacy training, | |||||
suppress-privacy training and privacy breach estimation. | |||||
""" | """ |
@@ -0,0 +1,16 @@ | |||||
# Copyright 2021 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
""" | |||||
This package includes unit tests for suppress-privacy training. | |||||
""" |
@@ -0,0 +1,85 @@ | |||||
# Copyright 2021 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
""" | |||||
Suppress Privacy model test. | |||||
""" | |||||
import pytest | |||||
import numpy as np | |||||
from mindspore import nn | |||||
from mindspore import context | |||||
from mindspore.train.callback import ModelCheckpoint | |||||
from mindspore.train.callback import CheckpointConfig | |||||
from mindspore.train.callback import LossMonitor | |||||
from mindspore.nn.metrics import Accuracy | |||||
import mindspore.dataset as ds | |||||
from ut.python.utils.mock_net import Net as LeNet5 | |||||
from mindarmour.privacy.sup_privacy import SuppressModel | |||||
from mindarmour.privacy.sup_privacy import SuppressMasker | |||||
from mindarmour.privacy.sup_privacy import SuppressPrivacyFactory | |||||
from mindarmour.privacy.sup_privacy import MaskLayerDes | |||||
def dataset_generator(batch_size, batches): | |||||
"""mock training data.""" | |||||
data = np.random.random((batches*batch_size, 1, 32, 32)).astype( | |||||
np.float32) | |||||
label = np.random.randint(0, 10, batches*batch_size).astype(np.int32) | |||||
for i in range(batches): | |||||
yield data[i*batch_size:(i + 1)*batch_size],\ | |||||
label[i*batch_size:(i + 1)*batch_size] | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_arm_ascend_training | |||||
@pytest.mark.platform_x86_ascend_training | |||||
@pytest.mark.env_card | |||||
@pytest.mark.component_mindarmour | |||||
def test_suppress_model_with_pynative_mode(): | |||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||||
networks_l5 = LeNet5() | |||||
epochs = 5 | |||||
batch_num = 10 | |||||
batch_size = 32 | |||||
mask_times = 10 | |||||
lr = 0.01 | |||||
masklayers_lenet5 = [] | |||||
masklayers_lenet5.append(MaskLayerDes("conv1.weight", False, False, -1)) | |||||
suppress_ctrl_instance = SuppressPrivacyFactory().create(policy="local_train", | |||||
end_epoch=epochs, | |||||
batch_num=batch_num, | |||||
start_epoch=1, | |||||
mask_times=mask_times, | |||||
networks=networks_l5, | |||||
lr=lr, | |||||
sparse_end=0.50, | |||||
sparse_start=0.0, | |||||
mask_layers=masklayers_lenet5) | |||||
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") | |||||
net_opt = nn.SGD(networks_l5.trainable_params(), lr) | |||||
model_instance = SuppressModel( | |||||
network=networks_l5, | |||||
loss_fn=net_loss, | |||||
optimizer=net_opt, | |||||
metrics={"Accuracy": Accuracy()}) | |||||
model_instance.link_suppress_ctrl(suppress_ctrl_instance) | |||||
suppress_masker = SuppressMasker(model=model_instance, suppress_ctrl=suppress_ctrl_instance) | |||||
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num, keep_checkpoint_max=10) | |||||
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", | |||||
directory="./trained_ckpt_file/", | |||||
config=config_ck) | |||||
ds_train = ds.GeneratorDataset(dataset_generator(batch_size, batch_num), ['data', 'label']) | |||||
model_instance.train(epochs, ds_train, callbacks=[ckpoint_cb, LossMonitor(), suppress_masker], | |||||
dataset_sink_mode=False) |