@@ -0,0 +1,132 @@ | |||
# Copyright 2020 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""Eval""" | |||
import os | |||
import argparse | |||
import datetime | |||
import mindspore.nn as nn | |||
from mindspore import context | |||
from mindspore.nn.optim.momentum import Momentum | |||
from mindspore.train.model import Model | |||
from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
from mindspore.ops import operations as P | |||
from mindspore.ops import functional as F | |||
from mindspore.common import dtype as mstype | |||
from mindarmour.utils import LogUtil | |||
from vgg.vgg import vgg16 | |||
from vgg.dataset import vgg_create_dataset100 | |||
from vgg.config import cifar_cfg as cfg | |||
class ParameterReduce(nn.Cell): | |||
"""ParameterReduce""" | |||
def __init__(self): | |||
super(ParameterReduce, self).__init__() | |||
self.cast = P.Cast() | |||
self.reduce = P.AllReduce() | |||
def construct(self, x): | |||
one = self.cast(F.scalar_to_array(1.0), mstype.float32) | |||
out = x*one | |||
ret = self.reduce(out) | |||
return ret | |||
def parse_args(cloud_args=None): | |||
"""parse_args""" | |||
parser = argparse.ArgumentParser('mindspore classification test') | |||
    parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
                        help='device on which the code will run. (Default: Ascend)')
# dataset related | |||
parser.add_argument('--data_path', type=str, default='', help='eval data dir') | |||
parser.add_argument('--per_batch_size', default=32, type=int, help='batch size for per npu') | |||
# network related | |||
parser.add_argument('--graph_ckpt', type=int, default=1, help='graph ckpt or feed ckpt') | |||
    parser.add_argument('--pre_trained', default='', type=str, help='full path of the pretrained model to load. '
                        'If it is a directory, all checkpoints in it will be tested')
# logging related | |||
parser.add_argument('--log_path', type=str, default='outputs/', help='path to save log') | |||
parser.add_argument('--rank', type=int, default=0, help='local rank of distributed') | |||
parser.add_argument('--group_size', type=int, default=1, help='world size of distributed') | |||
args_opt = parser.parse_args() | |||
args_opt = merge_args(args_opt, cloud_args) | |||
args_opt.image_size = cfg.image_size | |||
args_opt.num_classes = cfg.num_classes | |||
args_opt.per_batch_size = cfg.batch_size | |||
args_opt.momentum = cfg.momentum | |||
args_opt.weight_decay = cfg.weight_decay | |||
args_opt.buffer_size = cfg.buffer_size | |||
args_opt.pad_mode = cfg.pad_mode | |||
args_opt.padding = cfg.padding | |||
args_opt.has_bias = cfg.has_bias | |||
args_opt.batch_norm = cfg.batch_norm | |||
args_opt.initialize_mode = cfg.initialize_mode | |||
args_opt.has_dropout = cfg.has_dropout | |||
args_opt.image_size = list(map(int, args_opt.image_size.split(','))) | |||
return args_opt | |||
def merge_args(args, cloud_args): | |||
"""merge_args""" | |||
args_dict = vars(args) | |||
if isinstance(cloud_args, dict): | |||
for key in cloud_args.keys(): | |||
val = cloud_args[key] | |||
if key in args_dict and val: | |||
arg_type = type(args_dict[key]) | |||
if arg_type is not type(None): | |||
val = arg_type(val) | |||
args_dict[key] = val | |||
return args | |||
def test(cloud_args=None): | |||
"""test""" | |||
args = parse_args(cloud_args) | |||
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, | |||
device_target=args.device_target, save_graphs=False) | |||
if os.getenv('DEVICE_ID', "not_set").isdigit(): | |||
context.set_context(device_id=int(os.getenv('DEVICE_ID'))) | |||
args.outputs_dir = os.path.join(args.log_path, | |||
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S')) | |||
args.logger = LogUtil.get_instance() | |||
args.logger.set_level(20) | |||
net = vgg16(num_classes=args.num_classes, args=args) | |||
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, args.momentum, | |||
weight_decay=args.weight_decay) | |||
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False) | |||
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}) | |||
param_dict = load_checkpoint(args.pre_trained) | |||
load_param_into_net(net, param_dict) | |||
net.set_train(False) | |||
dataset_test = vgg_create_dataset100(args.data_path, args.image_size, args.per_batch_size, training=False) | |||
res = model.eval(dataset_test) | |||
print("result: ", res) | |||
if __name__ == "__main__": | |||
test() |
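# Example invocation (hypothetical paths, for illustration only):
#   python eval.py --data_path=/path/to/cifar100 --pre_trained=/path/to/vgg16.ckpt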
@@ -0,0 +1,122 @@ | |||
# Copyright 2020 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
""" | |||
Examples of membership inference | |||
""" | |||
import argparse | |||
import sys | |||
sys.path.append("../../")
import numpy as np
import mindspore.nn as nn
from mindspore.train import Model
from mindspore.train.serialization import load_param_into_net, load_checkpoint
from mindarmour.diff_privacy.evaluation.membership_inference import MembershipInference
from mindarmour.utils import LogUtil
from vgg.vgg import vgg16
from vgg.config import cifar_cfg as cfg
from vgg.utils.util import get_param_groups
from vgg.dataset import vgg_create_dataset100
logging = LogUtil.get_instance()
logging.set_level(20)
TAG = "membership inference example" | |||
if __name__ == "__main__": | |||
parser = argparse.ArgumentParser("main case arg parser.") | |||
parser.add_argument("--device_target", type=str, default="Ascend", | |||
choices=["Ascend"]) | |||
parser.add_argument("--data_path", type=str, required=True, | |||
help="Data home path for Cifar100.") | |||
parser.add_argument("--pre_trained", type=str, required=True, | |||
help="Checkpoint path.") | |||
args = parser.parse_args() | |||
args.num_classes = cfg.num_classes | |||
args.batch_norm = cfg.batch_norm | |||
args.has_dropout = cfg.has_dropout | |||
args.has_bias = cfg.has_bias | |||
args.initialize_mode = cfg.initialize_mode | |||
args.padding = cfg.padding | |||
args.pad_mode = cfg.pad_mode | |||
args.weight_decay = cfg.weight_decay | |||
args.loss_scale = cfg.loss_scale | |||
# load the pretrained model | |||
net = vgg16(args.num_classes, args) | |||
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) | |||
opt = nn.Momentum(params=get_param_groups(net), learning_rate=0.1, momentum=0.9, | |||
weight_decay=args.weight_decay, loss_scale=args.loss_scale) | |||
load_param_into_net(net, load_checkpoint(args.pre_trained)) | |||
model = Model(network=net, loss_fn=loss, optimizer=opt) | |||
logging.info(TAG, "The model is loaded.") | |||
attacker = MembershipInference(model) | |||
config = [ | |||
{ | |||
"method": "knn", | |||
"params": { | |||
"n_neighbors": [3, 5, 7] | |||
} | |||
}, | |||
{ | |||
"method": "lr", | |||
"params": { | |||
"C": np.logspace(-4, 2, 10) | |||
} | |||
}, | |||
{ | |||
"method": "mlp", | |||
"params": { | |||
"hidden_layer_sizes": [(64,), (32, 32)], | |||
"solver": ["adam"], | |||
"alpha": [0.0001, 0.001, 0.01] | |||
} | |||
}, | |||
{ | |||
"method": "rf", | |||
"params": { | |||
"n_estimators": [100], | |||
"max_features": ["auto", "sqrt"], | |||
"max_depth": [5, 10, 20, None], | |||
"min_samples_split": [2, 5, 10], | |||
"min_samples_leaf": [1, 2, 4] | |||
} | |||
} | |||
] | |||
# load and split dataset | |||
train_dataset = vgg_create_dataset100(data_home=args.data_path, image_size=(224, 224), | |||
batch_size=64, num_samples=10000, shuffle=False) | |||
test_dataset = vgg_create_dataset100(data_home=args.data_path, image_size=(224, 224), | |||
batch_size=64, num_samples=10000, shuffle=False, training=False) | |||
train_train, eval_train = train_dataset.split([0.8, 0.2]) | |||
train_test, eval_test = test_dataset.split([0.8, 0.2]) | |||
logging.info(TAG, "Data loading is complete.") | |||
logging.info(TAG, "Start training the inference model.") | |||
attacker.train(train_train, train_test, config) | |||
logging.info(TAG, "The inference model is training complete.") | |||
logging.info(TAG, "Start the evaluation phase") | |||
metrics = ["precision", "accuracy", "recall"] | |||
result = attacker.eval(eval_train, eval_test, metrics) | |||
# Show the metrics for each attack method. | |||
count = len(config) | |||
for i in range(count): | |||
print("Method: {}, {}".format(config[i]["method"], result[i])) |
@@ -0,0 +1,198 @@ | |||
# Copyright 2020 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
""" | |||
#################train vgg16 example on cifar100########################
python train.py --data_path=$DATA_HOME --device_id=$DEVICE_ID | |||
""" | |||
import argparse | |||
import datetime | |||
import os | |||
import random | |||
import numpy as np | |||
import mindspore.nn as nn | |||
from mindspore import Tensor | |||
from mindspore import context | |||
from mindspore.nn.optim.momentum import Momentum | |||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig | |||
from mindspore.train.model import Model | |||
from mindspore.train.serialization import load_param_into_net, load_checkpoint | |||
from mindarmour.utils import LogUtil | |||
from vgg.dataset import vgg_create_dataset100 | |||
from vgg.warmup_step_lr import warmup_step_lr | |||
from vgg.warmup_cosine_annealing_lr import warmup_cosine_annealing_lr | |||
from vgg.warmup_step_lr import lr_steps | |||
from vgg.utils.util import get_param_groups | |||
from vgg.vgg import vgg16 | |||
from vgg.config import cifar_cfg as cfg | |||
TAG = "train" | |||
random.seed(1) | |||
np.random.seed(1) | |||
def parse_args(cloud_args=None): | |||
"""parameters""" | |||
parser = argparse.ArgumentParser('mindspore classification training') | |||
    parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
                        help='device on which the code will run. (Default: Ascend)')
    parser.add_argument('--device_id', type=int, default=1, help='device id of GPU or Ascend. (Default: 1)')
# dataset related | |||
parser.add_argument('--data_path', type=str, default='', help='train data dir') | |||
# network related | |||
parser.add_argument('--pre_trained', default='', type=str, help='model_path, local pretrained model to load') | |||
parser.add_argument('--lr_gamma', type=float, default=0.1, | |||
help='decrease lr by a factor of exponential lr_scheduler') | |||
parser.add_argument('--eta_min', type=float, default=0., help='eta_min in cosine_annealing scheduler') | |||
parser.add_argument('--T_max', type=int, default=150, help='T-max in cosine_annealing scheduler') | |||
# logging and checkpoint related | |||
parser.add_argument('--log_interval', type=int, default=100, help='logging interval') | |||
parser.add_argument('--ckpt_path', type=str, default='outputs/', help='checkpoint save location') | |||
parser.add_argument('--ckpt_interval', type=int, default=2, help='ckpt_interval') | |||
parser.add_argument('--is_save_on_master', type=int, default=1, help='save ckpt on master or all rank') | |||
args_opt = parser.parse_args() | |||
args_opt = merge_args(args_opt, cloud_args) | |||
args_opt.rank = 0 | |||
args_opt.group_size = 1 | |||
args_opt.label_smooth = cfg.label_smooth | |||
args_opt.label_smooth_factor = cfg.label_smooth_factor | |||
args_opt.lr_scheduler = cfg.lr_scheduler | |||
args_opt.loss_scale = cfg.loss_scale | |||
args_opt.max_epoch = cfg.max_epoch | |||
args_opt.warmup_epochs = cfg.warmup_epochs | |||
args_opt.lr = cfg.lr | |||
args_opt.lr_init = cfg.lr_init | |||
args_opt.lr_max = cfg.lr_max | |||
args_opt.momentum = cfg.momentum | |||
args_opt.weight_decay = cfg.weight_decay | |||
args_opt.per_batch_size = cfg.batch_size | |||
args_opt.num_classes = cfg.num_classes | |||
args_opt.buffer_size = cfg.buffer_size | |||
args_opt.ckpt_save_max = cfg.keep_checkpoint_max | |||
args_opt.pad_mode = cfg.pad_mode | |||
args_opt.padding = cfg.padding | |||
args_opt.has_bias = cfg.has_bias | |||
args_opt.batch_norm = cfg.batch_norm | |||
args_opt.initialize_mode = cfg.initialize_mode | |||
args_opt.has_dropout = cfg.has_dropout | |||
args_opt.lr_epochs = list(map(int, cfg.lr_epochs.split(','))) | |||
args_opt.image_size = list(map(int, cfg.image_size.split(','))) | |||
return args_opt | |||
def merge_args(args_opt, cloud_args): | |||
"""dictionary""" | |||
args_dict = vars(args_opt) | |||
if isinstance(cloud_args, dict): | |||
for key_arg in cloud_args.keys(): | |||
val = cloud_args[key_arg] | |||
if key_arg in args_dict and val: | |||
arg_type = type(args_dict[key_arg]) | |||
                if arg_type is not type(None):
val = arg_type(val) | |||
args_dict[key_arg] = val | |||
return args_opt | |||
if __name__ == '__main__': | |||
args = parse_args() | |||
device_num = int(os.environ.get("DEVICE_NUM", 1)) | |||
context.set_context(device_id=args.device_id) | |||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) | |||
    # choose whether only the master rank saves checkpoints or all ranks do; compatible with model parallel
args.rank_save_ckpt_flag = 0 | |||
if args.is_save_on_master: | |||
if args.rank == 0: | |||
args.rank_save_ckpt_flag = 1 | |||
else: | |||
args.rank_save_ckpt_flag = 1 | |||
# logger | |||
args.outputs_dir = os.path.join(args.ckpt_path, | |||
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S')) | |||
args.logger = LogUtil.get_instance() | |||
args.logger.set_level(20) | |||
# load train data set | |||
dataset = vgg_create_dataset100(args.data_path, args.image_size, args.per_batch_size, args.rank, args.group_size) | |||
batch_num = dataset.get_dataset_size() | |||
    args.steps_per_epoch = batch_num
# network | |||
    args.logger.info(TAG, 'start creating network')
# get network and init | |||
network = vgg16(args.num_classes, args) | |||
# pre_trained | |||
if args.pre_trained: | |||
load_param_into_net(network, load_checkpoint(args.pre_trained)) | |||
# lr scheduler | |||
if args.lr_scheduler == 'exponential': | |||
lr = warmup_step_lr(args.lr, | |||
args.lr_epochs, | |||
args.steps_per_epoch, | |||
args.warmup_epochs, | |||
args.max_epoch, | |||
gamma=args.lr_gamma, | |||
) | |||
elif args.lr_scheduler == 'cosine_annealing': | |||
lr = warmup_cosine_annealing_lr(args.lr, | |||
args.steps_per_epoch, | |||
args.warmup_epochs, | |||
args.max_epoch, | |||
args.T_max, | |||
args.eta_min) | |||
elif args.lr_scheduler == 'step': | |||
lr = lr_steps(0, lr_init=args.lr_init, lr_max=args.lr_max, warmup_epochs=args.warmup_epochs, | |||
total_epochs=args.max_epoch, steps_per_epoch=batch_num) | |||
else: | |||
raise NotImplementedError(args.lr_scheduler) | |||
# optimizer | |||
opt = Momentum(params=get_param_groups(network), | |||
learning_rate=Tensor(lr), | |||
momentum=args.momentum, | |||
weight_decay=args.weight_decay, | |||
loss_scale=args.loss_scale) | |||
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False) | |||
model = Model(network, loss_fn=loss, optimizer=opt, metrics={'acc'}, | |||
amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None) | |||
# checkpoint save | |||
    callbacks = None
    if args.rank_save_ckpt_flag:
        ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval*args.steps_per_epoch,
                                       keep_checkpoint_max=args.ckpt_save_max)
        ckpt_cb = ModelCheckpoint(config=ckpt_config,
                                  directory=args.outputs_dir,
                                  prefix='{}'.format(args.rank))
        callbacks = ckpt_cb
    model.train(args.max_epoch, dataset, callbacks=callbacks)
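# Example invocation (hypothetical paths), matching the module docstring:
#   python train.py --data_path=/path/to/cifar100 --device_id=0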
@@ -0,0 +1,14 @@ | |||
# Copyright 2020 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ |
@@ -0,0 +1,45 @@ | |||
# Copyright 2020 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
""" | |||
Network config settings, used in train.py and eval.py.
""" | |||
from easydict import EasyDict as edict | |||
# config for vgg16, cifar100 | |||
cifar_cfg = edict({ | |||
"num_classes": 100, | |||
"lr": 0.01, | |||
"lr_init": 0.01, | |||
"lr_max": 0.1, | |||
"lr_epochs": '30,60,90,120', | |||
"lr_scheduler": "step", | |||
"warmup_epochs": 5, | |||
"batch_size": 64, | |||
"max_epoch": 100, | |||
"momentum": 0.9, | |||
"weight_decay": 5e-4, | |||
"loss_scale": 1.0, | |||
"label_smooth": 0, | |||
"label_smooth_factor": 0, | |||
"buffer_size": 10, | |||
"image_size": '224,224', | |||
"pad_mode": 'same', | |||
"padding": 0, | |||
"has_bias": False, | |||
"batch_norm": True, | |||
"keep_checkpoint_max": 10, | |||
"initialize_mode": "XavierUniform", | |||
"has_dropout": False | |||
}) |
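# The dict above is attribute-accessible via EasyDict, e.g. (illustration only):
#   from vgg.config import cifar_cfg as cfg
#   assert cfg.num_classes == 100
#   image_size = list(map(int, cfg.image_size.split(',')))  # -> [224, 224]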
@@ -0,0 +1,39 @@ | |||
# Copyright 2020 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""define loss function for network""" | |||
from mindspore.nn.loss.loss import _Loss | |||
from mindspore.ops import operations as P | |||
from mindspore.ops import functional as F | |||
from mindspore import Tensor | |||
from mindspore.common import dtype as mstype | |||
import mindspore.nn as nn | |||
class CrossEntropy(_Loss): | |||
"""the redefined loss function with SoftmaxCrossEntropyWithLogits""" | |||
def __init__(self, smooth_factor=0., num_classes=1001): | |||
super(CrossEntropy, self).__init__() | |||
self.onehot = P.OneHot() | |||
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32) | |||
self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32) | |||
self.ce = nn.SoftmaxCrossEntropyWithLogits() | |||
self.mean = P.ReduceMean(False) | |||
def construct(self, logit, label): | |||
one_hot_label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value) | |||
loss = self.ce(logit, one_hot_label) | |||
loss = self.mean(loss, 0) | |||
return loss |
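# A minimal usage sketch (not part of the original file; assumes the same
# MindSpore version as the rest of this repo):
if __name__ == "__main__":
    import numpy as np
    criterion = CrossEntropy(smooth_factor=0.1, num_classes=10)
    fake_logits = Tensor(np.random.randn(2, 10).astype(np.float32))
    fake_labels = Tensor(np.array([1, 7]).astype(np.int32))
    print(criterion(fake_logits, fake_labels))  # scalar label-smoothed loss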
@@ -0,0 +1,75 @@ | |||
# Copyright 2020 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
""" | |||
dataset processing. | |||
""" | |||
import os | |||
from mindspore.common import dtype as mstype | |||
import mindspore.dataset as de | |||
import mindspore.dataset.transforms.c_transforms as C | |||
import mindspore.dataset.transforms.vision.c_transforms as vision | |||
def vgg_create_dataset100(data_home, image_size, batch_size, rank_id=0, rank_size=1, repeat_num=1, | |||
training=True, num_samples=None, shuffle=True): | |||
"""Data operations.""" | |||
de.config.set_seed(1) | |||
data_dir = os.path.join(data_home, "train") | |||
if not training: | |||
data_dir = os.path.join(data_home, "test") | |||
if num_samples is not None: | |||
data_set = de.Cifar100Dataset(data_dir, num_shards=rank_size, shard_id=rank_id, | |||
num_samples=num_samples, shuffle=shuffle) | |||
else: | |||
data_set = de.Cifar100Dataset(data_dir, num_shards=rank_size, shard_id=rank_id) | |||
input_columns = ["fine_label"] | |||
output_columns = ["label"] | |||
data_set = data_set.rename(input_columns=input_columns, output_columns=output_columns) | |||
data_set = data_set.project(["image", "label"]) | |||
rescale = 1.0 / 255.0 | |||
shift = 0.0 | |||
# define map operations | |||
random_crop_op = vision.RandomCrop((32, 32), (4, 4, 4, 4)) # padding_mode default CONSTANT | |||
random_horizontal_op = vision.RandomHorizontalFlip() | |||
resize_op = vision.Resize(image_size) # interpolation default BILINEAR | |||
rescale_op = vision.Rescale(rescale, shift) | |||
normalize_op = vision.Normalize((0.4465, 0.4822, 0.4914), (0.2010, 0.1994, 0.2023)) | |||
changeswap_op = vision.HWC2CHW() | |||
type_cast_op = C.TypeCast(mstype.int32) | |||
c_trans = [] | |||
if training: | |||
c_trans = [random_crop_op, random_horizontal_op] | |||
c_trans += [resize_op, rescale_op, normalize_op, | |||
changeswap_op] | |||
# apply map operations on images | |||
data_set = data_set.map(input_columns="label", operations=type_cast_op) | |||
data_set = data_set.map(input_columns="image", operations=c_trans) | |||
# apply repeat operations | |||
data_set = data_set.repeat(repeat_num) | |||
# apply shuffle operations | |||
# data_set = data_set.shuffle(buffer_size=1000) | |||
# apply batch operations | |||
data_set = data_set.batch(batch_size=batch_size, drop_remainder=True) | |||
return data_set |
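# Usage sketch (hypothetical path; assumes the CIFAR-100 binaries are extracted
# into <data_home>/train and <data_home>/test):
if __name__ == "__main__":
    demo_set = vgg_create_dataset100("/path/to/cifar100", image_size=(224, 224),
                                     batch_size=32, training=False)
    print(demo_set.get_dataset_size())  # number of batches per epoch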
@@ -0,0 +1,23 @@ | |||
# Copyright 2020 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
""" | |||
linear warm up learning rate. | |||
""" | |||
def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
    """Linearly interpolate the learning rate from init_lr to base_lr over warmup_steps."""
    lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
lr = float(init_lr) + lr_inc*current_step | |||
return lr |
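# Quick sanity check (not part of the original file): with 100 warmup steps,
# base_lr=0.1 and init_lr=0.0, the rate ramps linearly up to base_lr.
if __name__ == "__main__":
    for step in (1, 50, 100):
        print(step, linear_warmup_lr(step, 100, 0.1, 0.0))  # 0.001, 0.05, 0.1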
@@ -0,0 +1,36 @@ | |||
# Copyright 2020 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""Util class or function.""" | |||
def get_param_groups(network): | |||
"""Param groups for optimizer.""" | |||
decay_params = [] | |||
no_decay_params = [] | |||
for x in network.trainable_params(): | |||
parameter_name = x.name | |||
if parameter_name.endswith('.bias'): | |||
            # biases do not use weight decay
            no_decay_params.append(x)
        elif parameter_name.endswith('.gamma'):
            # BatchNorm gamma does not use weight decay; relies on the '.gamma' naming convention
            no_decay_params.append(x)
        elif parameter_name.endswith('.beta'):
            # BatchNorm beta does not use weight decay; relies on the '.beta' naming convention
no_decay_params.append(x) | |||
else: | |||
decay_params.append(x) | |||
return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}] |
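# Minimal sketch (not part of the original file; assumes parameter names are
# prefixed with their cell path, e.g. '0.weight' inside a SequentialCell):
if __name__ == "__main__":
    import mindspore.nn as nn
    demo_net = nn.SequentialCell([nn.Dense(4, 2, has_bias=True)])
    no_decay, decay = get_param_groups(demo_net)
    print([p.name for p in no_decay['params']])  # e.g. ['0.bias']
    print([p.name for p in decay['params']])     # e.g. ['0.weight']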
@@ -0,0 +1,214 @@ | |||
# Copyright 2020 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
""" | |||
Initialize. | |||
""" | |||
import math | |||
from functools import reduce | |||
import numpy as np | |||
import mindspore.nn as nn | |||
from mindspore.common import initializer as init | |||
def _calculate_gain(nonlinearity, param=None): | |||
r""" | |||
Return the recommended gain value for the given nonlinearity function. | |||
The values are as follows: | |||
================= ==================================================== | |||
nonlinearity gain | |||
================= ==================================================== | |||
Linear / Identity :math:`1` | |||
Conv{1,2,3}D :math:`1` | |||
Sigmoid :math:`1` | |||
Tanh :math:`\frac{5}{3}` | |||
ReLU :math:`\sqrt{2}` | |||
Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}` | |||
================= ==================================================== | |||
Args: | |||
nonlinearity: the non-linear function | |||
param: optional parameter for the non-linear function | |||
Examples: | |||
        >>> gain = _calculate_gain('leaky_relu', 0.2)  # leaky_relu with negative_slope=0.2
""" | |||
linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d'] | |||
if nonlinearity in linear_fns or nonlinearity == 'sigmoid': | |||
return 1 | |||
if nonlinearity == 'tanh': | |||
return 5.0 / 3 | |||
if nonlinearity == 'relu': | |||
return math.sqrt(2.0) | |||
if nonlinearity == 'leaky_relu': | |||
if param is None: | |||
negative_slope = 0.01 | |||
    elif not isinstance(param, bool) and (isinstance(param, int) or isinstance(param, float)):
negative_slope = param | |||
else: | |||
raise ValueError("negative_slope {} not a valid number".format(param)) | |||
return math.sqrt(2.0 / (1 + negative_slope**2)) | |||
raise ValueError("Unsupported nonlinearity {}".format(nonlinearity)) | |||
def _assignment(arr, num): | |||
"""Assign the value of `num` to `arr`.""" | |||
if arr.shape == (): | |||
arr = arr.reshape((1)) | |||
arr[:] = num | |||
arr = arr.reshape(()) | |||
else: | |||
if isinstance(num, np.ndarray): | |||
arr[:] = num[:] | |||
else: | |||
arr[:] = num | |||
return arr | |||
def _calculate_in_and_out(arr): | |||
""" | |||
Calculate n_in and n_out. | |||
Args: | |||
arr (Array): Input array. | |||
Returns: | |||
Tuple, a tuple with two elements, the first element is `n_in` and the second element is `n_out`. | |||
""" | |||
dim = len(arr.shape) | |||
if dim < 2: | |||
raise ValueError("If initialize data with xavier uniform, the dimension of data must greater than 1.") | |||
n_in = arr.shape[1] | |||
n_out = arr.shape[0] | |||
if dim > 2: | |||
counter = reduce(lambda x, y: x*y, arr.shape[2:]) | |||
n_in *= counter | |||
n_out *= counter | |||
return n_in, n_out | |||
def _select_fan(array, mode): | |||
mode = mode.lower() | |||
valid_modes = ['fan_in', 'fan_out'] | |||
if mode not in valid_modes: | |||
raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes)) | |||
fan_in, fan_out = _calculate_in_and_out(array) | |||
return fan_in if mode == 'fan_in' else fan_out | |||
class KaimingInit(init.Initializer): | |||
r""" | |||
    Base class. Initialize the array with the Kaiming He algorithm.
Args: | |||
a: the negative slope of the rectifier used after this layer (only | |||
used with ``'leaky_relu'``) | |||
mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'`` | |||
preserves the magnitude of the variance of the weights in the | |||
forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the | |||
backwards pass. | |||
nonlinearity: the non-linear function, recommended to use only with | |||
``'relu'`` or ``'leaky_relu'`` (default). | |||
""" | |||
def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'): | |||
super(KaimingInit, self).__init__() | |||
self.mode = mode | |||
self.gain = _calculate_gain(nonlinearity, a) | |||
def _initialize(self, arr): | |||
pass | |||
class KaimingUniform(KaimingInit): | |||
r""" | |||
    Initialize the array with the Kaiming He uniform algorithm. The resulting tensor will
have values sampled from :math:`\mathcal{U}(-\text{bound}, \text{bound})` where | |||
.. math:: | |||
\text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}} | |||
Input: | |||
arr (Array): The array to be assigned. | |||
Returns: | |||
Array, assigned array. | |||
Examples: | |||
        >>> w = np.empty((3, 5))
>>> KaimingUniform(w, mode='fan_in', nonlinearity='relu') | |||
""" | |||
def _initialize(self, arr): | |||
fan = _select_fan(arr, self.mode) | |||
bound = math.sqrt(3.0)*self.gain / math.sqrt(fan) | |||
np.random.seed(0) | |||
data = np.random.uniform(-bound, bound, arr.shape) | |||
_assignment(arr, data) | |||
class KaimingNormal(KaimingInit): | |||
r""" | |||
    Initialize the array with the Kaiming He normal algorithm. The resulting tensor will
have values sampled from :math:`\mathcal{N}(0, \text{std}^2)` where | |||
.. math:: | |||
\text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}} | |||
Input: | |||
arr (Array): The array to be assigned. | |||
Returns: | |||
Array, assigned array. | |||
Examples: | |||
        >>> w = np.empty((3, 5))
>>> KaimingNormal(w, mode='fan_out', nonlinearity='relu') | |||
""" | |||
def _initialize(self, arr): | |||
fan = _select_fan(arr, self.mode) | |||
std = self.gain / math.sqrt(fan) | |||
np.random.seed(0) | |||
data = np.random.normal(0, std, arr.shape) | |||
_assignment(arr, data) | |||
def default_recurisive_init(custom_cell): | |||
"""default_recurisive_init""" | |||
for _, cell in custom_cell.cells_and_names(): | |||
if isinstance(cell, nn.Conv2d): | |||
cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)), | |||
cell.weight.shape, | |||
cell.weight.dtype) | |||
if cell.bias is not None: | |||
fan_in, _ = _calculate_in_and_out(cell.weight) | |||
bound = 1 / math.sqrt(fan_in) | |||
np.random.seed(0) | |||
cell.bias.default_input = init.initializer(init.Uniform(bound), | |||
cell.bias.shape, | |||
cell.bias.dtype) | |||
elif isinstance(cell, nn.Dense): | |||
cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)), | |||
cell.weight.shape, | |||
cell.weight.dtype) | |||
if cell.bias is not None: | |||
fan_in, _ = _calculate_in_and_out(cell.weight) | |||
bound = 1 / math.sqrt(fan_in) | |||
np.random.seed(0) | |||
cell.bias.default_input = init.initializer(init.Uniform(bound), | |||
cell.bias.shape, | |||
cell.bias.dtype) | |||
elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)): | |||
pass |
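# Usage sketch (not part of the original file; assumes the MindSpore version
# targeted by this repo, where parameters expose `default_input`):
if __name__ == "__main__":
    demo_net = nn.SequentialCell([nn.Conv2d(3, 8, 3), nn.Dense(8, 2)])
    default_recurisive_init(demo_net)
    print([p.name for p in demo_net.trainable_params()])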
@@ -0,0 +1,142 @@ | |||
# Copyright 2020 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
""" | |||
Image classification.
""" | |||
import math | |||
import mindspore.nn as nn | |||
import mindspore.common.dtype as mstype | |||
from mindspore.common import initializer as init | |||
from mindspore.common.initializer import initializer | |||
from .utils.var_init import default_recurisive_init, KaimingNormal | |||
def _make_layer(base, args, batch_norm): | |||
"""Make stage network of VGG.""" | |||
layers = [] | |||
in_channels = 3 | |||
for v in base: | |||
if v == 'M': | |||
layers += [nn.MaxPool2d(kernel_size=2, stride=2)] | |||
else: | |||
weight_shape = (v, in_channels, 3, 3) | |||
weight = initializer('XavierUniform', shape=weight_shape, dtype=mstype.float32).to_tensor() | |||
if args.initialize_mode == "KaimingNormal": | |||
weight = 'normal' | |||
conv2d = nn.Conv2d(in_channels=in_channels, | |||
out_channels=v, | |||
kernel_size=3, | |||
padding=args.padding, | |||
pad_mode=args.pad_mode, | |||
has_bias=args.has_bias, | |||
weight_init=weight) | |||
if batch_norm: | |||
layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU()] | |||
else: | |||
layers += [conv2d, nn.ReLU()] | |||
in_channels = v | |||
return nn.SequentialCell(layers) | |||
class Vgg(nn.Cell): | |||
""" | |||
VGG network definition. | |||
Args: | |||
base (list): Configuration for different layers, mainly the channel number of Conv layer. | |||
        num_classes (int): Number of classes. Default: 1000.
batch_norm (bool): Whether to do the batchnorm. Default: False. | |||
batch_size (int): Batch size. Default: 1. | |||
Returns: | |||
Tensor, infer output tensor. | |||
Examples: | |||
>>> Vgg([64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], | |||
>>> num_classes=1000, batch_norm=False, batch_size=1) | |||
""" | |||
def __init__(self, base, num_classes=1000, batch_norm=False, batch_size=1, args=None, phase="train"): | |||
super(Vgg, self).__init__() | |||
_ = batch_size | |||
self.layers = _make_layer(base, args, batch_norm=batch_norm) | |||
self.flatten = nn.Flatten() | |||
        keep_prob = 0.5
        if not args.has_dropout or phase == "test":
            keep_prob = 1.0
        self.classifier = nn.SequentialCell([
            nn.Dense(512*7*7, 4096),
            nn.ReLU(),
            nn.Dropout(keep_prob),  # nn.Dropout takes a keep probability
            nn.Dense(4096, 4096),
            nn.ReLU(),
            nn.Dropout(keep_prob),
            nn.Dense(4096, num_classes)])
if args.initialize_mode == "KaimingNormal": | |||
default_recurisive_init(self) | |||
self.custom_init_weight() | |||
def construct(self, x): | |||
x = self.layers(x) | |||
x = self.flatten(x) | |||
x = self.classifier(x) | |||
return x | |||
def custom_init_weight(self): | |||
""" | |||
Init the weight of Conv2d and Dense in the net. | |||
""" | |||
for _, cell in self.cells_and_names(): | |||
if isinstance(cell, nn.Conv2d): | |||
cell.weight.default_input = init.initializer( | |||
KaimingNormal(a=math.sqrt(5), mode='fan_out', nonlinearity='relu'), | |||
cell.weight.shape, cell.weight.dtype) | |||
if cell.bias is not None: | |||
cell.bias.default_input = init.initializer( | |||
'zeros', cell.bias.shape, cell.bias.dtype) | |||
elif isinstance(cell, nn.Dense): | |||
cell.weight.default_input = init.initializer( | |||
init.Normal(0.01), cell.weight.shape, cell.weight.dtype) | |||
if cell.bias is not None: | |||
cell.bias.default_input = init.initializer( | |||
'zeros', cell.bias.shape, cell.bias.dtype) | |||
cfg = { | |||
'11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], | |||
'13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], | |||
'16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], | |||
'19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], | |||
} | |||
def vgg16(num_classes=1000, args=None, phase="train"): | |||
""" | |||
    Get VGG16 neural network (with batch normalization when args.batch_norm is True).
Args: | |||
        num_classes (int): Number of classes. Default: 1000.
args(namespace): param for net init. | |||
phase(str): train or test mode. | |||
Returns: | |||
        Cell, cell instance of the VGG16 neural network.
Examples: | |||
>>> vgg16(num_classes=1000, args=args) | |||
""" | |||
net = Vgg(cfg['16'], num_classes=num_classes, args=args, batch_norm=args.batch_norm, phase=phase) | |||
return net |
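# Instantiation sketch (not part of the original file): `args` only needs the
# config fields consumed by `_make_layer` and `Vgg`; the namespace below is
# illustrative.
if __name__ == "__main__":
    from types import SimpleNamespace
    demo_args = SimpleNamespace(pad_mode='same', padding=0, has_bias=False,
                                batch_norm=True, has_dropout=False,
                                initialize_mode='XavierUniform')
    demo_net = vgg16(num_classes=100, args=demo_args)
    print(len(demo_net.trainable_params()))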
@@ -0,0 +1,40 @@ | |||
# Copyright 2020 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
""" | |||
warm up cosine annealing learning rate. | |||
""" | |||
import math | |||
import numpy as np | |||
from .linear_warmup import linear_warmup_lr | |||
def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, t_max, eta_min=0): | |||
"""warm up cosine annealing learning rate.""" | |||
base_lr = lr | |||
warmup_init_lr = 0 | |||
total_steps = int(max_epoch*steps_per_epoch) | |||
warmup_steps = int(warmup_epochs*steps_per_epoch) | |||
lr_each_step = [] | |||
for i in range(total_steps): | |||
last_epoch = i // steps_per_epoch | |||
if i < warmup_steps: | |||
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr) | |||
else: | |||
lr = eta_min + (base_lr - eta_min)*(1. + math.cos(math.pi*last_epoch / t_max)) / 2 | |||
lr_each_step.append(lr) | |||
return np.array(lr_each_step).astype(np.float32) |
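# Quick sanity check (not part of the original file): 2 warmup epochs into a
# 10-epoch cosine schedule; values ramp linearly during warmup, then decay
# towards eta_min.
if __name__ == "__main__":
    schedule = warmup_cosine_annealing_lr(0.1, steps_per_epoch=100, warmup_epochs=2,
                                          max_epoch=10, t_max=10, eta_min=0.0)
    print(len(schedule), schedule[0], schedule[199], schedule[-1])  # 1000 steps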
@@ -0,0 +1,84 @@ | |||
# Copyright 2020 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
""" | |||
warm up step learning rate. | |||
""" | |||
from collections import Counter | |||
import numpy as np | |||
from .linear_warmup import linear_warmup_lr | |||
def lr_steps(global_step, lr_init, lr_max, warmup_epochs, total_epochs, steps_per_epoch): | |||
"""Set learning rate.""" | |||
lr_each_step = [] | |||
total_steps = steps_per_epoch*total_epochs | |||
warmup_steps = steps_per_epoch*warmup_epochs | |||
if warmup_steps != 0: | |||
inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps) | |||
else: | |||
inc_each_step = 0 | |||
for i in range(total_steps): | |||
if i < warmup_steps: | |||
lr_value = float(lr_init) + inc_each_step*float(i) | |||
else: | |||
base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps))) | |||
lr_value = float(lr_max)*base*base | |||
if lr_value < 0.0: | |||
lr_value = 0.0 | |||
lr_each_step.append(lr_value) | |||
current_step = global_step | |||
lr_each_step = np.array(lr_each_step).astype(np.float32) | |||
learning_rate = lr_each_step[current_step:] | |||
return learning_rate | |||
def warmup_step_lr(lr, lr_epochs, steps_per_epoch, warmup_epochs, max_epoch, gamma=0.1): | |||
"""warmup_step_lr""" | |||
base_lr = lr | |||
warmup_init_lr = 0 | |||
total_steps = int(max_epoch*steps_per_epoch) | |||
warmup_steps = int(warmup_epochs*steps_per_epoch) | |||
milestones = lr_epochs | |||
milestones_steps = [] | |||
for milestone in milestones: | |||
milestones_step = milestone*steps_per_epoch | |||
milestones_steps.append(milestones_step) | |||
lr_each_step = [] | |||
lr = base_lr | |||
milestones_steps_counter = Counter(milestones_steps) | |||
for i in range(total_steps): | |||
if i < warmup_steps: | |||
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr) | |||
else: | |||
lr = lr*gamma**milestones_steps_counter[i] | |||
lr_each_step.append(lr) | |||
return np.array(lr_each_step).astype(np.float32) | |||
def multi_step_lr(lr, milestones, steps_per_epoch, max_epoch, gamma=0.1):
    """Decay lr by gamma at each milestone epoch, with no warmup."""
    return warmup_step_lr(lr, milestones, steps_per_epoch, 0, max_epoch, gamma=gamma)
def step_lr(lr, epoch_size, steps_per_epoch, max_epoch, gamma=0.1):
    """Decay lr by gamma every epoch_size epochs."""
    lr_epochs = []
    for i in range(1, max_epoch):
        if i % epoch_size == 0:
            lr_epochs.append(i)
    return multi_step_lr(lr, lr_epochs, steps_per_epoch, max_epoch, gamma=gamma)
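# Usage sketch (not part of the original file): decay by 0.1 at epochs
# 30/60/90 over 100 epochs with no warmup.
if __name__ == "__main__":
    schedule = warmup_step_lr(0.1, [30, 60, 90], steps_per_epoch=10,
                              warmup_epochs=0, max_epoch=100)
    print(schedule[0], schedule[305], schedule[-1])  # ~0.1, ~0.01, ~0.0001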
@@ -32,7 +32,7 @@ def _attack_knn(features, labels, param_grid): | |||
param_grid (dict): Setting of GridSearchCV. | |||
Returns: | |||
-        sklearn.neighbors.KNeighborsClassifier, trained model.
+        sklearn.model_selection.GridSearchCV, trained model.
""" | |||
knn_model = KNeighborsClassifier() | |||
knn_model = GridSearchCV( | |||
@@ -53,9 +53,9 @@ def _attack_lr(features, labels, param_grid): | |||
param_grid (dict): Setting of GridSearchCV. | |||
Returns: | |||
-        sklearn.linear_model.LogisticRegression, trained model.
+        sklearn.model_selection.GridSearchCV, trained model.
""" | |||
-    lr_model = LogisticRegression(C=1.0, penalty="l2")
+    lr_model = LogisticRegression(C=1.0, penalty="l2", max_iter=1000)
lr_model = GridSearchCV( | |||
lr_model, param_grid=param_grid, cv=3, n_jobs=1, iid=False, | |||
verbose=0, | |||
@@ -74,7 +74,7 @@ def _attack_mlpc(features, labels, param_grid): | |||
param_grid (dict): Setting of GridSearchCV. | |||
Returns: | |||
-        sklearn.neural_network.MLPClassifier, trained model.
+        sklearn.model_selection.GridSearchCV, trained model.
""" | |||
mlpc_model = MLPClassifier(random_state=1, max_iter=300) | |||
mlpc_model = GridSearchCV( | |||
@@ -95,7 +95,7 @@ def _attack_rf(features, labels, random_grid): | |||
random_grid (dict): Setting of RandomizedSearchCV. | |||
Returns: | |||
-        sklearn.ensemble.RandomForestClassifier, trained model.
+        sklearn.model_selection.RandomizedSearchCV, trained model.
""" | |||
rf_model = RandomForestClassifier(max_depth=2, random_state=0) | |||
rf_model = RandomizedSearchCV( | |||
@@ -0,0 +1,197 @@ | |||
# Copyright 2020 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
""" | |||
Membership Inference | |||
""" | |||
import numpy as np | |||
import mindspore as ms | |||
from mindspore.train import Model | |||
import mindspore.nn as nn | |||
import mindspore.context as context | |||
from mindspore import Tensor | |||
from mindarmour.diff_privacy.evaluation.attacker import get_attack_model | |||
def _eval_info(pred, truth, option): | |||
""" | |||
Calculate the performance according to pred and truth. | |||
Args: | |||
pred (numpy.ndarray): Predictions for each sample. | |||
truth (numpy.ndarray): Ground truth for each sample. | |||
        option (str): Type of evaluation indicator; possible
            values are 'precision', 'accuracy' and 'recall'.
Returns: | |||
float32, Calculated evaluation results. | |||
Raises: | |||
ValueError, size of parameter pred or truth is 0. | |||
ValueError, value of parameter option must be in ["precision", "accuracy", "recall"]. | |||
""" | |||
    if pred.size == 0 or truth.size == 0:
raise ValueError("Size of pred or truth is 0.") | |||
if option == "accuracy": | |||
count = np.sum(pred == truth) | |||
return count / len(pred) | |||
if option == "precision": | |||
count = np.sum(pred & truth) | |||
if np.sum(pred) == 0: | |||
return -1 | |||
return count / np.sum(pred) | |||
if option == "recall": | |||
count = np.sum(pred & truth) | |||
if np.sum(truth) == 0: | |||
return -1 | |||
return count / np.sum(truth) | |||
raise ValueError("The metric value {} is undefined.".format(option)) | |||
class MembershipInference: | |||
""" | |||
    Evaluation method proposed by Shokri, Stronati, Song and Shmatikov. It is a grey-box attack
    that requires access to the loss or logits of the target model on training samples.
    References: `Reza Shokri, Marco Stronati, Congzheng Song, Vitaly Shmatikov.
    Membership Inference Attacks against Machine Learning Models. 2017.
    arXiv:1610.05820v2 <https://arxiv.org/abs/1610.05820v2>`_
Args: | |||
model (Model): Target model. | |||
Examples: | |||
>>> # ds_train, eval_train are non-overlapping datasets from training dataset. | |||
        >>> # ds_test, eval_test are non-overlapping datasets from test dataset.
>>> model = Model(network=net, loss_fn=loss, optimizer=opt, metrics={'acc', 'loss'}) | |||
>>> inference_model = MembershipInference(model) | |||
>>> config = [{"method": "KNN", "params": {"n_neighbors": [3, 5, 7]}}] | |||
>>> inference_model.train(ds_train, ds_test, config) | |||
>>> metrics = ["precision", "recall", "accuracy"] | |||
>>> result = inference_model.eval(eval_train, eval_test, metrics) | |||
Raises: | |||
TypeError: If type of model is not mindspore.train.Model. | |||
""" | |||
def __init__(self, model): | |||
if not isinstance(model, Model): | |||
raise TypeError("Type of model must be {}, but got {}.".format(type(Model), type(model))) | |||
self.model = model | |||
self.attack_list = [] | |||
def train(self, dataset_train, dataset_test, attack_config): | |||
""" | |||
        Depending on the configuration, use the incoming datasets to train the attack models.
        Save the trained attack models in self.attack_list.
Args: | |||
dataset_train (mindspore.dataset): The training dataset for the target model. | |||
dataset_test (mindspore.dataset): The test set for the target model. | |||
attack_config (list): Parameter setting for the attack model. | |||
Raises: | |||
            ValueError: If the method in attack_config is not in ["lr", "knn", "rf", "mlp"].
""" | |||
features, labels = self._transform(dataset_train, dataset_test) | |||
for config in attack_config: | |||
self.attack_list.append(get_attack_model(features, labels, config)) | |||
def eval(self, dataset_train, dataset_test, metrics): | |||
""" | |||
        Evaluate the privacy leakage of the target model via the trained attack models.
        Evaluation indicators are specified by metrics.
Args: | |||
dataset_train (mindspore.dataset): The training dataset for the target model. | |||
dataset_test (mindspore.dataset): The test dataset for the target model. | |||
            metrics (Union[list, tuple]): Evaluation indicators. Each value in metrics
                must be in ["precision", "accuracy", "recall"].
Returns: | |||
list, Each element contains an evaluation indicator for the attack model. | |||
""" | |||
result = [] | |||
features, labels = self._transform(dataset_train, dataset_test) | |||
for attacker in self.attack_list: | |||
pred = attacker.predict(features) | |||
item = {} | |||
for option in metrics: | |||
item[option] = _eval_info(pred, labels, option) | |||
result.append(item) | |||
return result | |||
def _transform(self, dataset_train, dataset_test): | |||
""" | |||
        Generate the loss_logits features and membership labels, shuffle them, and return both.
Args: | |||
dataset_train: The training set for the target model. | |||
dataset_test: The test set for the target model. | |||
Returns: | |||
            - numpy.ndarray, loss_logits features for each sample. Shape is (N, C),
              where N is the number of samples and C = 1 + dim(logits).
            - numpy.ndarray, labels for each sample. Shape is (N,).
""" | |||
features_train, labels_train = self._generate(dataset_train, 1) | |||
features_test, labels_test = self._generate(dataset_test, 0) | |||
features = np.vstack((features_train, features_test)) | |||
labels = np.hstack((labels_train, labels_test)) | |||
        shuffle_index = np.arange(len(labels))
np.random.shuffle(shuffle_index) | |||
features = features[shuffle_index] | |||
labels = labels[shuffle_index] | |||
return features, labels | |||
def _generate(self, dataset_x, label): | |||
""" | |||
Return a loss_logits features and labels for training attack model. | |||
Args: | |||
            dataset_x (mindspore.dataset): The dataset from which features are generated.
            label (int32): Whether dataset_x comes from the training set of the target model (1) or not (0).
Returns: | |||
            - numpy.ndarray, loss_logits features for each sample. Shape is (N, C),
              where N is the number of samples and C = 1 + dim(logits).
            - numpy.ndarray, labels for each sample. Shape is (N,).
""" | |||
if context.get_context("device_target") != "Ascend": | |||
raise RuntimeError("The target device must be Ascend, " | |||
"but current is {}.".format(context.get_context("device_target"))) | |||
        loss_logits = np.array([])
        loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, is_grad=False, reduction=None)
        for batch in dataset_x.create_dict_iterator():
            batch_data = Tensor(batch['image'], ms.float32)
            batch_labels = Tensor(batch['label'], ms.int32)
            batch_logits = self.model.predict(batch_data)
            batch_loss = loss_fn(batch_logits, batch_labels).asnumpy()
batch_logits = batch_logits.asnumpy() | |||
batch_feature = np.hstack((batch_loss.reshape(-1, 1), batch_logits)) | |||
if loss_logits.size == 0: | |||
loss_logits = batch_feature | |||
else: | |||
loss_logits = np.vstack((loss_logits, batch_feature)) | |||
if label == 1: | |||
labels = np.ones(len(loss_logits), np.int32) | |||
elif label == 0: | |||
labels = np.zeros(len(loss_logits), np.int32) | |||
else: | |||
raise ValueError("The value of label must be 0 or 1, but got {}.".format(label)) | |||
return loss_logits, labels |
@@ -0,0 +1,111 @@ | |||
# Copyright 2020 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
""" | |||
membership inference test | |||
""" | |||
import os | |||
import sys | |||
import pytest | |||
import numpy as np | |||
import mindspore.dataset as ds | |||
from mindspore import nn | |||
from mindspore.train import Model | |||
from mindarmour.diff_privacy.evaluation.membership_inference import MembershipInference | |||
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../")) | |||
from defenses.mock_net import Net | |||
def dataset_generator(batch_size, batches): | |||
"""mock training data.""" | |||
data = np.random.randn(batches*batch_size, 1, 32, 32).astype( | |||
np.float32) | |||
label = np.random.randint(0, 10, batches*batch_size).astype(np.int32) | |||
for i in range(batches): | |||
yield data[i*batch_size:(i + 1)*batch_size],\ | |||
label[i*batch_size:(i + 1)*batch_size] | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.env_onecard | |||
@pytest.mark.component_mindarmour | |||
def test_get_membership_inference_object(): | |||
net = Net() | |||
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) | |||
opt = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) | |||
model = Model(network=net, loss_fn=loss, optimizer=opt) | |||
inference_model = MembershipInference(model) | |||
assert isinstance(inference_model, MembershipInference) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.env_onecard | |||
@pytest.mark.component_mindarmour | |||
def test_membership_inference_object_train(): | |||
net = Net() | |||
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) | |||
opt = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) | |||
model = Model(network=net, loss_fn=loss, optimizer=opt) | |||
inference_model = MembershipInference(model) | |||
assert isinstance(inference_model, MembershipInference) | |||
config = [{ | |||
"method": "KNN", | |||
"params": { | |||
"n_neighbors": [3, 5, 7], | |||
} | |||
}] | |||
batch_size = 16 | |||
batches = 1 | |||
ds_train = ds.GeneratorDataset(dataset_generator(batch_size, batches), | |||
["image", "label"]) | |||
ds_test = ds.GeneratorDataset(dataset_generator(batch_size, batches), | |||
["image", "label"]) | |||
    ds_train.set_dataset_size(batch_size*batches)
    ds_test.set_dataset_size(batch_size*batches)
inference_model.train(ds_train, ds_test, config) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.env_onecard | |||
@pytest.mark.component_mindarmour | |||
def test_membership_inference_eval(): | |||
net = Net() | |||
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) | |||
opt = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) | |||
model = Model(network=net, loss_fn=loss, optimizer=opt) | |||
inference_model = MembershipInference(model) | |||
assert isinstance(inference_model, MembershipInference) | |||
batch_size = 16 | |||
batches = 1 | |||
eval_train = ds.GeneratorDataset(dataset_generator(batch_size, batches), | |||
["image", "label"]) | |||
eval_test = ds.GeneratorDataset(dataset_generator(batch_size, batches), | |||
["image", "label"]) | |||
    eval_train.set_dataset_size(batch_size*batches)
    eval_test.set_dataset_size(batch_size*batches)
metrics = ["precision", "accuracy", "recall"] | |||
inference_model.eval(eval_train, eval_test, metrics) |