Browse Source

!46 fix review bugs in fuzzing and mechanism

Merge pull request !46 from ZhidanLiu/master
tags/v0.6.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
98b88e5575
3 changed files with 32 additions and 28 deletions
  1. +1
    -1
      example/mnist_demo/lenet5_mnist_fuzzing.py
  2. +16
    -13
      mindarmour/diff_privacy/mechanisms/mechanisms.py
  3. +15
    -14
      mindarmour/fuzzing/fuzzing.py

+ 1
- 1
example/mnist_demo/lenet5_mnist_fuzzing.py View File

@@ -70,7 +70,7 @@ def test_lenet_mnist_fuzzing():

# make initial seeds
for img, label in zip(test_images, test_labels):
initial_seeds.append([img, label, 0])
initial_seeds.append([img, label])

initial_seeds = initial_seeds[:100]
model_coverage_test.test_adequacy_coverage_calculate(np.array(test_images[:100]).astype(np.float32))


+ 16
- 13
mindarmour/diff_privacy/mechanisms/mechanisms.py View File

@@ -14,6 +14,8 @@
"""
Noise Mechanisms.
"""
from abc import abstractmethod

from mindspore import Tensor
from mindspore.nn import Cell
from mindspore.ops import operations as P
@@ -23,8 +25,11 @@ from mindspore.common import dtype as mstype

from mindarmour.utils._check_param import check_param_type
from mindarmour.utils._check_param import check_value_positive
from mindarmour.utils._check_param import check_value_non_negative
from mindarmour.utils._check_param import check_param_in_range
from mindarmour.utils.logger import LogUtil

LOGGER = LogUtil.get_instance()
TAG = 'Defense'


class MechanismsFactory:
@@ -99,6 +104,7 @@ class Mechanisms(Cell):
Basic class of noise generated mechanism.
"""

@abstractmethod
def construct(self, gradients):
"""
Construct function.
@@ -115,8 +121,9 @@ class GaussianRandom(Mechanisms):
initial_noise_multiplier(float): Ratio of the standard deviation of
Gaussian noise divided by the norm_bound, which will be used to
calculate privacy spent. Default: 1.5.
mean(float): Average value of random noise. Default: 0.0.
seed(int): Original random seed. Default: 0.
seed(int): Original random seed, if seed=0 random normal will use secure
random number. IF seed!=0 random normal will generate values using
given seed. Default: 0.

Returns:
Tensor, generated noise with shape like given gradients.
@@ -130,16 +137,14 @@ class GaussianRandom(Mechanisms):
>>> print(res)
"""

def __init__(self, norm_bound=0.5, initial_noise_multiplier=1.5, mean=0.0, seed=0):
def __init__(self, norm_bound=0.5, initial_noise_multiplier=1.5, seed=0):
super(GaussianRandom, self).__init__()
self._norm_bound = check_value_positive('norm_bound', norm_bound)
self._norm_bound = Tensor(norm_bound, mstype.float32)
self._initial_noise_multiplier = check_value_positive('initial_noise_multiplier',
initial_noise_multiplier)
self._initial_noise_multiplier = Tensor(initial_noise_multiplier, mstype.float32)
mean = check_param_type('mean', mean, float)
mean = check_value_non_negative('mean', mean)
self._mean = Tensor(mean, mstype.float32)
self._mean = Tensor(0, mstype.float32)
self._normal = P.Normal(seed=seed)

def construct(self, gradients):
@@ -160,8 +165,8 @@ class GaussianRandom(Mechanisms):

class AdaGaussianRandom(Mechanisms):
"""
Adaptive Gaussian noise generated mechanism. Noise would be decayed with training. Decay mode could be 'Time'
mode or 'Step' mode.
Adaptive Gaussian noise generated mechanism. Noise would be decayed with
training. Decay mode could be 'Time' mode or 'Step' mode.

Args:
norm_bound(float): Clipping bound for the l2 norm of the gradients.
@@ -192,7 +197,7 @@ class AdaGaussianRandom(Mechanisms):
>>> print(res)
"""

def __init__(self, norm_bound=1.0, initial_noise_multiplier=1.5, mean=0.0,
def __init__(self, norm_bound=1.0, initial_noise_multiplier=1.5,
noise_decay_rate=6e-4, decay_policy='Time', seed=0):
super(AdaGaussianRandom, self).__init__()
norm_bound = check_value_positive('norm_bound', norm_bound)
@@ -205,9 +210,7 @@ class AdaGaussianRandom(Mechanisms):
name='initial_noise_multiplier')
self._noise_multiplier = Parameter(initial_noise_multiplier,
name='noise_multiplier')
mean = check_param_type('mean', mean, float)
mean = check_value_non_negative('mean', mean)
self._mean = Tensor(mean, mstype.float32)
self._mean = Tensor(0, mstype.float32)
noise_decay_rate = check_param_type('noise_decay_rate', noise_decay_rate, float)
check_param_in_range('noise_decay_rate', noise_decay_rate, 0.0, 1.0)
self._noise_decay_rate = Tensor(noise_decay_rate, mstype.float32)


+ 15
- 14
mindarmour/fuzzing/fuzzing.py View File

@@ -35,10 +35,10 @@ class Fuzzing:
Neural Networks <https://dl.acm.org/doi/10.1145/3293882.3330579>`_

Args:
initial_seeds (list): Initial fuzzing seed, format: [[image, label, 0],
[image, label, 0], ...].
initial_seeds (list): Initial fuzzing seed, format: [[image, label],
[image, label], ...].
target_model (Model): Target fuzz model.
train_dataset (numpy.ndarray): Training dataset used for determine
train_dataset (numpy.ndarray): Training dataset used for determining
the neurons' output boundaries.
const_k (int): The number of mutate tests for a seed.
mode (str): Image mode used in image transform, 'L' means grey graph.
@@ -68,8 +68,8 @@ class Fuzzing:
seed = seed[0]
info = [seed, seed]
mutate_tests = []
affine_trans = ['Contrast', 'Brightness', 'Blur', 'Noise']
pixel_value_trans = ['Translate', 'Scale', 'Shear', 'Rotate']
pixel_value_trans = ['Contrast', 'Brightness', 'Blur', 'Noise']
affine_trans = ['Translate', 'Scale', 'Shear', 'Rotate']
strages = {'Contrast': Contrast, 'Brightness': Brightness, 'Blur': Blur,
'Noise': Noise,
'Translate': Translate, 'Scale': Scale, 'Shear': Shear,
@@ -80,7 +80,8 @@ class Fuzzing:
trans_strage = self._random_pick_mutate(affine_trans,
pixel_value_trans)
else:
trans_strage = self._random_pick_mutate(affine_trans, [])
trans_strage = self._random_pick_mutate(pixel_value_trans,
[])
transform = strages[trans_strage](
self._image_value_expand(seed), self.mode)
transform.random_param()
@@ -105,21 +106,21 @@ class Fuzzing:
Default: 'KMNC'.

Returns:
list, mutated tests mis-predicted by target dnn model.
list, mutated tests mis-predicted by target DNN model.
"""
seed = self._select_next()
failed_tests = []
seed_num = 0
while seed and seed_num < self.max_seed_num:
mutate_tests = self._metamorphic_mutate(seed[0])
coverages, results = self._run(mutate_tests, coverage_metric)
coverages, predicts = self._run(mutate_tests, coverage_metric)
coverage_gains = self._coverage_gains(coverages)
for mutate, cov, res in zip(mutate_tests, coverage_gains, results):
for mutate, cov, res in zip(mutate_tests, coverage_gains, predicts):
if np.argmax(seed[1]) != np.argmax(res):
failed_tests.append(mutate)
continue
if cov > 0:
self.initial_seeds.append([mutate, seed[1], 0])
self.initial_seeds.append([mutate, seed[1]])
seed = self._select_next()
seed_num += 1

@@ -154,17 +155,17 @@ class Fuzzing:

def _is_trans_valid(self, seed, mutate_test):
is_valid = False
alpha = 0.02
beta = 0.2
pixels_change_rate = 0.02
pixel_value_change_rate = 0.2
diff = np.array(seed - mutate_test).flatten()
size = np.shape(diff)[0]
l0 = np.linalg.norm(diff, ord=0)
linf = np.linalg.norm(diff, ord=np.inf)
if l0 > alpha*size:
if l0 > pixels_change_rate*size:
if linf < 256:
is_valid = True
else:
if linf < beta*255:
if linf < pixel_value_change_rate*255:
is_valid = True

return is_valid

Loading…
Cancel
Save