@@ -31,7 +31,7 @@ mnist_cfg = edict({ | |||||
'device_target': 'Ascend', # device used | 'device_target': 'Ascend', # device used | ||||
'data_path': './MNIST_unzip', # the path of training and testing data set | 'data_path': './MNIST_unzip', # the path of training and testing data set | ||||
'dataset_sink_mode': False, # whether deliver all training data to device one time | 'dataset_sink_mode': False, # whether deliver all training data to device one time | ||||
'micro_batches': 32, # the number of small batches split from an original batch | |||||
'micro_batches': 16, # the number of small batches split from an original batch | |||||
'l2_norm_bound': 1.0, # the clip bound of the gradients of model's training parameters | 'l2_norm_bound': 1.0, # the clip bound of the gradients of model's training parameters | ||||
'initial_noise_multiplier': 1.5, # the initial multiplication coefficient of the noise added to training | 'initial_noise_multiplier': 1.5, # the initial multiplication coefficient of the noise added to training | ||||
# parameters' gradients | # parameters' gradients | ||||
@@ -124,8 +124,7 @@ if __name__ == "__main__": | |||||
rdp_monitor = PrivacyMonitorFactory.create('rdp', | rdp_monitor = PrivacyMonitorFactory.create('rdp', | ||||
num_samples=60000, | num_samples=60000, | ||||
batch_size=cfg.batch_size, | batch_size=cfg.batch_size, | ||||
initial_noise_multiplier=cfg.initial_noise_multiplier* | |||||
cfg.l2_norm_bound, | |||||
initial_noise_multiplier=cfg.initial_noise_multiplier, | |||||
per_print_times=50) | per_print_times=50) | ||||
# Create the DP model for training. | # Create the DP model for training. | ||||
@@ -54,7 +54,7 @@ def test_ad(): | |||||
net = Net() | net = Net() | ||||
loss_fn = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=sparse) | loss_fn = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=sparse) | ||||
optimizer = Momentum(learning_rate=Tensor(np.array([0.001], np.float32)), | optimizer = Momentum(learning_rate=Tensor(np.array([0.001], np.float32)), | ||||
momentum=Tensor(np.array([0.9], np.float32)), | |||||
momentum=0.9, | |||||
params=net.trainable_params()) | params=net.trainable_params()) | ||||
ad_defense = AdversarialDefense(net, loss_fn=loss_fn, optimizer=optimizer) | ad_defense = AdversarialDefense(net, loss_fn=loss_fn, optimizer=optimizer) | ||||
@@ -18,13 +18,14 @@ import pytest | |||||
import numpy as np | import numpy as np | ||||
from mindspore import nn | from mindspore import nn | ||||
from mindspore.model_zoo.lenet import LeNet5 | |||||
from mindspore import context | from mindspore import context | ||||
import mindspore.dataset as ds | import mindspore.dataset as ds | ||||
from mindarmour.diff_privacy import DPOptimizerClassFactory | from mindarmour.diff_privacy import DPOptimizerClassFactory | ||||
from mindarmour.diff_privacy import DPModel | from mindarmour.diff_privacy import DPModel | ||||
from test_network import LeNet5 | |||||
def dataset_generator(batch_size, batches): | def dataset_generator(batch_size, batches): | ||||
data = np.random.random((batches * batch_size, 1, 32, 32)).astype(np.float32) | data = np.random.random((batches * batch_size, 1, 32, 32)).astype(np.float32) | ||||
@@ -21,11 +21,12 @@ import mindspore.nn as nn | |||||
import mindspore.dataset as ds | import mindspore.dataset as ds | ||||
from mindspore.train import Model | from mindspore.train import Model | ||||
import mindspore.context as context | import mindspore.context as context | ||||
from mindspore.model_zoo.lenet import LeNet5 | |||||
from mindarmour.diff_privacy import PrivacyMonitorFactory | from mindarmour.diff_privacy import PrivacyMonitorFactory | ||||
from mindarmour.utils.logger import LogUtil | from mindarmour.utils.logger import LogUtil | ||||
from test_network import LeNet5 | |||||
LOGGER = LogUtil.get_instance() | LOGGER = LogUtil.get_instance() | ||||
TAG = 'DP-Monitor Test' | TAG = 'DP-Monitor Test' | ||||
@@ -0,0 +1,63 @@ | |||||
# Copyright 2019 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
from mindspore import nn | |||||
from mindspore.common.initializer import TruncatedNormal | |||||
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0): | |||||
weight = weight_variable() | |||||
return nn.Conv2d(in_channels, out_channels, | |||||
kernel_size=kernel_size, stride=stride, padding=padding, | |||||
weight_init=weight, has_bias=False, pad_mode="valid") | |||||
def fc_with_initialize(input_channels, out_channels): | |||||
weight = weight_variable() | |||||
bias = weight_variable() | |||||
return nn.Dense(input_channels, out_channels, weight, bias) | |||||
def weight_variable(): | |||||
return TruncatedNormal(0.05) | |||||
class LeNet5(nn.Cell): | |||||
""" | |||||
Lenet network | |||||
""" | |||||
def __init__(self): | |||||
super(LeNet5, self).__init__() | |||||
self.conv1 = conv(1, 6, 5) | |||||
self.conv2 = conv(6, 16, 5) | |||||
self.fc1 = fc_with_initialize(16*5*5, 120) | |||||
self.fc2 = fc_with_initialize(120, 84) | |||||
self.fc3 = fc_with_initialize(84, 10) | |||||
self.relu = nn.ReLU() | |||||
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) | |||||
self.flatten = nn.Flatten() | |||||
def construct(self, x): | |||||
x = self.conv1(x) | |||||
x = self.relu(x) | |||||
x = self.max_pool2d(x) | |||||
x = self.conv2(x) | |||||
x = self.relu(x) | |||||
x = self.max_pool2d(x) | |||||
x = self.flatten(x) | |||||
x = self.fc1(x) | |||||
x = self.relu(x) | |||||
x = self.fc2(x) | |||||
x = self.relu(x) | |||||
x = self.fc3(x) | |||||
return x |
@@ -15,11 +15,11 @@ import pytest | |||||
from mindspore import nn | from mindspore import nn | ||||
from mindspore import context | from mindspore import context | ||||
from mindspore.model_zoo.lenet import LeNet5 | |||||
from mindspore.train.model import Model | from mindspore.train.model import Model | ||||
from mindarmour.diff_privacy import DPOptimizerClassFactory | from mindarmour.diff_privacy import DPOptimizerClassFactory | ||||
from test_network import LeNet5 | |||||
@pytest.mark.level0 | @pytest.mark.level0 | ||||
@pytest.mark.platform_arm_ascend_training | @pytest.mark.platform_arm_ascend_training | ||||