Browse Source

modify api description of fuzzer and NES

tags/v1.3.0
ZhidanLiu 4 years ago
parent
commit
9d391bbb4a
6 changed files with 146 additions and 207 deletions
  1. +5
    -4
      examples/ai_fuzzer/lenet5_mnist_fuzzing.py
  2. +61
    -97
      mindarmour/adv_robustness/attacks/black/natural_evolutionary_strategy.py
  3. +39
    -78
      mindarmour/fuzz_testing/fuzzing.py
  4. +26
    -22
      mindarmour/fuzz_testing/model_coverage_metrics.py
  5. +7
    -2
      tests/ut/python/fuzzing/test_coverage_metrics.py
  6. +8
    -4
      tests/ut/python/fuzzing/test_fuzzer.py

+ 5
- 4
examples/ai_fuzzer/lenet5_mnist_fuzzing.py View File

@@ -63,9 +63,11 @@ def test_lenet_mnist_fuzzing():
images = data[0].astype(np.float32)
train_images.append(images)
train_images = np.concatenate(train_images, axis=0)
neuron_num = 10
segmented_num = 1000

# initialize fuzz test with training dataset
model_coverage_test = ModelCoverageMetrics(model, 10, 1000, train_images)
model_coverage_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images)

# fuzz test with original test data
# get test data
@@ -93,9 +95,8 @@ def test_lenet_mnist_fuzzing():
LOGGER.info(TAG, 'KMNC of this test is : %s',
model_coverage_test.get_kmnc())

model_fuzz_test = Fuzzer(model, train_images, 10, 1000)
_, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds,
eval_metrics='auto')
model_fuzz_test = Fuzzer(model, train_images, neuron_num, segmented_num)
_, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds, eval_metrics='auto')
if metrics:
for key in metrics:
LOGGER.info(TAG, key + ': %s', metrics[key])


+ 61
- 97
mindarmour/adv_robustness/attacks/black/natural_evolutionary_strategy.py View File

@@ -36,48 +36,45 @@ def _bound(image, epislon):

class NES(Attack):
"""
The class is an implementation of the Natural Evolutionary Strategies Attack,
including three settings: Query-Limited setting, Partial-Information setting
and Label-Only setting.
The class is an implementation of the Natural Evolutionary Strategies Attack
Method. NES uses natural evolutionary strategies to estimate gradients to
improve query efficiency. NES covers three settings: Query-Limited setting,
Partial-Information setting and Label-Only setting. In the query-limit
setting, the attack has a limited number of queries to the target model but
access to the probabilities of all classes. In the partial-info setting,
the attack only has access to the probabilities for top-k classes.
In the label-only setting, the attack only has access to a list of k inferred
labels ordered by their predicted probabilities. In the Partial-Information
setting and Label-Only setting, NES do target attack so user need to use
set_target_images method to set target images of target classes.

References: `Andrew Ilyas, Logan Engstrom, Anish Athalye, and Jessy Lin.
Black-box adversarial attacks with limited queries and information. In
ICML, July 2018 <https://arxiv.org/abs/1804.08598>`_

Args:
model (BlackModel): Target model.
scene (str): Scene in 'Label_Only', 'Partial_Info' or
'Query_Limit'.
max_queries (int): Maximum query numbers to generate an adversarial
example. Default: 500000.
top_k (int): For Partial-Info or Label-Only setting, indicating how
much (Top-k) information is available for the attacker. For
Query-Limited setting, this input should be set as -1. Default: -1.
model (BlackModel): Target model to be attacked.
scene (str): Scene in 'Label_Only', 'Partial_Info' or 'Query_Limit'.
max_queries (int): Maximum query numbers to generate an adversarial example. Default: 500000.
top_k (int): For Partial-Info or Label-Only setting, indicating how much (Top-k) information is
available for the attacker. For Query-Limited setting, this input should be set as -1. Default: -1.
num_class (int): Number of classes in dataset. Default: 10.
batch_size (int): Batch size. Default: 96.
epsilon (float): Maximum perturbation allowed in attack. Default: 0.3.
samples_per_draw (int): Number of samples draw in antithetic sampling.
Default: 96.
samples_per_draw (int): Number of samples draw in antithetic sampling. Default: 96.
momentum (float): Momentum. Default: 0.9.
learning_rate (float): Learning rate. Default: 1e-2.
max_lr (float): Max Learning rate. Default: 1e-2.
min_lr (float): Min Learning rate. Default: 5e-5.
sigma (float): Step size of random noise. Default: 1e-3.
plateau_length (int): Length of plateau used in Annealing algorithm.
Default: 20.
plateau_drop (float): Drop of plateau used in Annealing algorithm.
Default: 2.0.
plateau_length (int): Length of plateau used in Annealing algorithm. Default: 20.
plateau_drop (float): Drop of plateau used in Annealing algorithm. Default: 2.0.
adv_thresh (float): Threshold of adversarial. Default: 0.15.
zero_iters (int): Number of points to use for the proxy score.
Default: 10.
starting_eps (float): Starting epsilon used in Label-Only setting.
Default: 1.0.
starting_delta_eps (float): Delta epsilon used in Label-Only setting.
Default: 0.5.
label_only_sigma (float): Sigma used in Label-Only setting.
Default: 1e-3.
conservative (int): Conservation used in epsilon decay, it will
increase if no convergence. Default: 2.
zero_iters (int): Number of points to use for the proxy score. Default: 10.
starting_eps (float): Starting epsilon used in Label-Only setting. Default: 1.0.
starting_delta_eps (float): Delta epsilon used in Label-Only setting. Default: 0.5.
label_only_sigma (float): Sigma used in Label-Only setting. Default: 1e-3.
conservative (int): Conservation used in epsilon decay, it will increase if no convergence. Default: 2.
sparse (bool): If True, input labels are sparse-encoded. If False,
input labels are one-hot-encoded. Default: True.

@@ -94,13 +91,10 @@ class NES(Attack):
>>> tag, adv, queries = nes_instance.generate([initial_img], [target_class])
"""

def __init__(self, model, scene, max_queries=10000, top_k=-1, num_class=10,
batch_size=128, epsilon=0.3, samples_per_draw=128,
momentum=0.9, learning_rate=1e-3, max_lr=5e-2, min_lr=5e-4,
sigma=1e-3, plateau_length=20, plateau_drop=2.0,
adv_thresh=0.25, zero_iters=10, starting_eps=1.0,
starting_delta_eps=0.5, label_only_sigma=1e-3, conservative=2,
sparse=True):
def __init__(self, model, scene, max_queries=10000, top_k=-1, num_class=10, batch_size=128, epsilon=0.3,
samples_per_draw=128, momentum=0.9, learning_rate=1e-3, max_lr=5e-2, min_lr=5e-4, sigma=1e-3,
plateau_length=20, plateau_drop=2.0, adv_thresh=0.25, zero_iters=10, starting_eps=1.0,
starting_delta_eps=0.5, label_only_sigma=1e-3, conservative=2, sparse=True):
super(NES, self).__init__()
self._model = check_model('model', model, BlackModel)
self._scene = scene
@@ -108,17 +102,14 @@ class NES(Attack):
self._max_queries = check_int_positive('max_queries', max_queries)
self._num_class = check_int_positive('num_class', num_class)
self._batch_size = check_int_positive('batch_size', batch_size)
self._samples_per_draw = check_int_positive('samples_per_draw',
samples_per_draw)
self._samples_per_draw = check_int_positive('samples_per_draw', samples_per_draw)
self._goal_epsilon = check_value_positive('epsilon', epsilon)
self._momentum = check_value_positive('momentum', momentum)
self._learning_rate = check_value_positive('learning_rate',
learning_rate)
self._learning_rate = check_value_positive('learning_rate', learning_rate)
self._max_lr = check_value_positive('max_lr', max_lr)
self._min_lr = check_value_positive('min_lr', min_lr)
self._sigma = check_value_positive('sigma', sigma)
self._plateau_length = check_int_positive('plateau_length',
plateau_length)
self._plateau_length = check_int_positive('plateau_length', plateau_length)
self._plateau_drop = check_value_positive('plateau_drop', plateau_drop)
# partial information arguments
self._k = top_k
@@ -126,10 +117,8 @@ class NES(Attack):
# label only arguments
self._zero_iters = check_int_positive('zero_iters', zero_iters)
self._starting_eps = check_value_positive('starting_eps', starting_eps)
self._starting_delta_eps = check_value_positive('starting_delta_eps',
starting_delta_eps)
self._label_only_sigma = check_value_positive('label_only_sigma',
label_only_sigma)
self._starting_delta_eps = check_value_positive('starting_delta_eps', starting_delta_eps)
self._label_only_sigma = check_value_positive('label_only_sigma', label_only_sigma)
self._conservative = check_int_positive('conservative', conservative)
self._sparse = check_param_type('sparse', sparse, bool)
self.target_imgs = None
@@ -152,39 +141,32 @@ class NES(Attack):
- numpy.ndarray, query times for each sample.

Raises:
ValueError: If the top_k less than 0 in Label-Only or Partial-Info
setting.
ValueError: If the target_imgs is None in Label-Only or
Partial-Info setting.
ValueError: If scene is not in ['Label_Only', 'Partial_Info',
'Query_Limit']
ValueError: If the top_k less than 0 in Label-Only or Partial-Info setting.
ValueError: If the target_imgs is None in Label-Only or Partial-Info setting.
ValueError: If scene is not in ['Label_Only', 'Partial_Info', 'Query_Limit']

Examples:
>>> advs = attack.generate([[0.2, 0.3, 0.4], [0.3, 0.3, 0.2]],
>>> [1, 2])
"""
inputs, labels = check_pair_numpy_param('inputs', inputs,
'labels', labels)
inputs, labels = check_pair_numpy_param('inputs', inputs, 'labels', labels)
if not self._sparse:
labels = np.argmax(labels, axis=1)

if self._scene == 'Label_Only' or self._scene == 'Partial_Info':
if self._k < 0:
msg = "In 'Label_Only' or 'Partial_Info' mode, " \
"'top_k' must more than 0."
if self._k < 1:
msg = "In 'Label_Only' or 'Partial_Info' mode, 'top_k' must more than 0."
LOGGER.error(TAG, msg)
raise ValueError(msg)
if self.target_imgs is None:
msg = "In 'Label_Only' or 'Partial_Info' mode, " \
"'target_imgs' must be set."
msg = "In 'Label_Only' or 'Partial_Info' mode, 'target_imgs' must be set."
LOGGER.error(TAG, msg)
raise ValueError(msg)

elif self._scene == 'Query_Limit':
self._k = self._num_class
else:
msg = "scene must be string in 'Label_Only', " \
"'Partial_Info' or 'Query_Limit' "
msg = "scene must be string in 'Label_Only', 'Partial_Info' or 'Query_Limit' "
LOGGER.error(TAG, msg)
raise ValueError(msg)

@@ -201,7 +183,7 @@ class NES(Attack):

def set_target_images(self, target_images):
"""
Set target samples for target attack.
Set target samples for target attack in the Partial-Info setting or Label-Only setting.

Args:
target_images (numpy.ndarray): Target samples for target attack.
@@ -253,8 +235,8 @@ class NES(Attack):
return True, adv, num_queries

# antithetic sampling noise
noise_pos = np.random.normal(
size=(self._batch_size // 2,) + origin_image.shape)
size = (self._batch_size // 2,) + origin_image.shape
noise_pos = np.random.normal(size=size)
noise = np.concatenate((noise_pos, -noise_pos), axis=0)
eval_points = adv + self._sigma*noise

@@ -274,8 +256,7 @@ class NES(Attack):
while current_lr >= self._min_lr:
# in partial information only or label only setting
if self._scene == 'Label_Only' or self._scene == 'Partial_Info':
proposed_epsilon = max(self._epsilon - prop_delta_eps,
goal_epsilon)
proposed_epsilon = max(self._epsilon - prop_delta_eps, goal_epsilon)
lower, upper = _bound(origin_image, proposed_epsilon)
proposed_adv = adv - current_lr*np.sign(gradient)
proposed_adv = np.clip(proposed_adv, lower, upper)
@@ -288,23 +269,19 @@ class NES(Attack):
delta_epsilon = max(prop_delta_eps, 0.1)
last_ls = []
adv = proposed_adv
self._epsilon = max(
self._epsilon - prop_delta_eps / self._conservative,
goal_epsilon)
self._epsilon = self._epsilon - prop_delta_eps / self._conservative
self._epsilon = max(self._epsilon, goal_epsilon)
break
elif current_lr >= self._min_lr*2:
current_lr = current_lr / 2
LOGGER.debug(TAG, "backtracking learning rate to %.3f",
current_lr)
LOGGER.debug(TAG, "backtracking learning rate to %.3f", current_lr)
else:
prop_delta_eps = prop_delta_eps / 2
if prop_delta_eps < 2e-3:
LOGGER.debug(TAG, "Did not converge.")
return False, adv, num_queries
current_lr = self._max_lr
LOGGER.debug(TAG,
"backtracking epsilon to %.3f",
self._epsilon - prop_delta_eps)
LOGGER.debug(TAG, "backtracking epsilon to %.3f", self._epsilon - prop_delta_eps)

# update the number of queries
if self._scene == 'Label_Only':
@@ -323,12 +300,10 @@ class NES(Attack):

def _plateau_annealing(self, last_loss):
last_loss = last_loss[-self._plateau_length:]
if last_loss[-1] > last_loss[0] and len(
last_loss) == self._plateau_length:
if last_loss[-1] > last_loss[0] and len(last_loss) == self._plateau_length:
if self._max_lr > self._min_lr:
LOGGER.debug(TAG, "Annealing max learning rate.")
self._max_lr = max(self._max_lr / self._plateau_drop,
self._min_lr)
self._max_lr = max(self._max_lr / self._plateau_drop, self._min_lr)
last_loss = []
return last_loss

@@ -346,8 +321,7 @@ class NES(Attack):
Loss in Query-Limit setting.
"""
LOGGER.debug(TAG, 'enter the function _query_limit_loss().')
loss = self._softmax_cross_entropy_with_logit(
self._model.predict(eval_points))
loss = self._softmax_cross_entropy_with_logit(self._model.predict(eval_points))

return loss, noise

@@ -359,8 +333,7 @@ class NES(Attack):
logit = self._model.predict(eval_points)
loss = np.sort(softmax(logit, axis=1))[:, -self._k:]
inds = np.argsort(logit)[:, -self._k:]
good_loss = np.where(np.equal(inds, self.target_class), loss,
np.zeros(np.shape(inds)))
good_loss = np.where(np.equal(inds, self.target_class), loss, np.zeros(np.shape(inds)))
good_loss = np.max(good_loss, axis=1)
losses = -np.log(good_loss)
return losses, noise
@@ -370,22 +343,16 @@ class NES(Attack):
Loss in Label-Only setting.
"""
LOGGER.debug(TAG, 'enter the function _label_only_loss().')
tiled_points = np.tile(np.expand_dims(eval_points, 0),
[self._zero_iters,
*[1]*len(eval_points.shape)])
noised_eval_im = tiled_points \
+ np.random.randn(self._zero_iters,
self._batch_size,
*origin_image.shape) \
*self._label_only_sigma
noised_eval_im = np.reshape(noised_eval_im, (
self._zero_iters*self._batch_size, *origin_image.shape))
tiled_points = np.tile(np.expand_dims(eval_points, 0), [self._zero_iters, *[1]*len(eval_points.shape)])
noised_eval_im = tiled_points + np.random.randn(self._zero_iters,
self._batch_size,
*origin_image.shape)*self._label_only_sigma
noised_eval_im = np.reshape(noised_eval_im, (self._zero_iters*self._batch_size, *origin_image.shape))
logits = self._model.predict(noised_eval_im)
inds = np.argsort(logits)[:, -self._k:]
real_inds = np.reshape(inds, (self._zero_iters, self._batch_size, -1))
rank_range = np.arange(1, self._k + 1, 1, dtype=np.float32)
tiled_rank_range = np.tile(np.reshape(rank_range, (1, 1, self._k)),
[self._zero_iters, self._batch_size, 1])
tiled_rank_range = np.tile(np.reshape(rank_range, (1, 1, self._k)), [self._zero_iters, self._batch_size, 1])
batches_in = np.where(np.equal(real_inds, self.target_class),
tiled_rank_range,
np.zeros(np.shape(tiled_rank_range)))
@@ -408,16 +375,13 @@ class NES(Attack):
grads = []
for _ in range(self._samples_per_draw // self._batch_size):
if self._scene == 'Label_Only':
loss, np_noise = self._label_only_loss(origin_image,
eval_points,
noise)
loss, np_noise = self._label_only_loss(origin_image, eval_points, noise)
elif self._scene == 'Partial_Info':
loss, np_noise = self._partial_info_loss(eval_points, noise)
else:
loss, np_noise = self._query_limit_loss(eval_points, noise)
# only support three channel images
losses_tiled = np.tile(np.reshape(loss, (-1, 1, 1, 1)),
(1,) + origin_image.shape)
losses_tiled = np.tile(np.reshape(loss, (-1, 1, 1, 1)), (1,) + origin_image.shape)
grad = np.mean(losses_tiled*np_noise, axis=0) / self._sigma

grads.append(grad)


+ 39
- 78
mindarmour/fuzz_testing/fuzzing.py View File

@@ -44,14 +44,15 @@ def _select_next(initial_seeds):


def _coverage_gains(coverages):
""" Calculate the coverage gains of mutated samples. """
""" Calculate the coverage gains of mutated samples."""
gains = [0] + coverages[:-1]
gains = np.array(coverages) - np.array(gains)
return gains


def _is_trans_valid(seed, mutate_sample):
""" Check a mutated sample is valid. If the number of changed pixels in
"""
Check a mutated sample is valid. If the number of changed pixels in
a seed is less than pixels_change_rate*size(seed), this mutate is valid.
Else check the infinite norm of seed changes, if the value of the
infinite norm less than pixel_value_change_rate*255, this mutate is
@@ -80,21 +81,18 @@ def _check_eval_metrics(eval_metrics):
available_metrics = ['accuracy', 'attack_success_rate', 'kmnc', 'nbc', 'snac']
for elem in eval_metrics:
if elem not in available_metrics:
msg = 'metric in list `eval_metrics` must be in {}, but ' \
'got {}.'.format(available_metrics, elem)
msg = 'metric in list `eval_metrics` must be in {}, but got {}.'.format(available_metrics, elem)
LOGGER.error(TAG, msg)
raise ValueError(msg)
eval_metrics_.append(elem.lower())
elif isinstance(eval_metrics, str):
if eval_metrics != 'auto':
msg = "the value of `eval_metrics` must be 'auto' if it's type " \
"is str, but got {}.".format(eval_metrics)
msg = "the value of `eval_metrics` must be 'auto' if it's type is str, but got {}.".format(eval_metrics)
LOGGER.error(TAG, msg)
raise ValueError(msg)
eval_metrics_ = 'auto'
else:
msg = "the type of `eval_metrics` must be str, list or tuple, " \
"but got {}.".format(type(eval_metrics))
msg = "the type of `eval_metrics` must be str, list or tuple, but got {}.".format(type(eval_metrics))
LOGGER.error(TAG, msg)
raise TypeError(msg)
return eval_metrics_
@@ -109,11 +107,9 @@ class Fuzzer:

Args:
target_model (Model): Target fuzz model.
train_dataset (numpy.ndarray): Training dataset used for determining
the neurons' output boundaries.
train_dataset (numpy.ndarray): Training dataset used for determining the neurons' output boundaries.
neuron_num (int): The number of testing neurons.
segmented_num (int): The number of segmented sections of neurons'
output intervals. Default: 1000.
segmented_num (int): The number of segmented sections of neurons' output intervals. Default: 1000.

Examples:
>>> net = Net()
@@ -126,7 +122,9 @@ class Fuzzer:
>>> {'method': 'FGSM',
>>> 'params': {'eps': [0.1, 0.2, 0.3], 'alpha': [0.1]}}]
>>> train_images = np.random.rand(32, 1, 32, 32).astype(np.float32)
>>> model_fuzz_test = Fuzzer(model, train_images, 10, 1000)
>>> neuron_num = 10
>>> segmented_num = 1000
>>> model_fuzz_test = Fuzzer(model, train_images, neuron_num, segmented_num)
>>> samples, labels, preds, strategies, report = model_fuzz_test.fuzz_testing(mutate_config, initial_seeds)
"""

@@ -134,10 +132,7 @@ class Fuzzer:
segmented_num=1000):
self._target_model = check_model('model', target_model, Model)
train_dataset = check_numpy_param('train_dataset', train_dataset)
self._coverage_metrics = ModelCoverageMetrics(target_model,
neuron_num,
segmented_num,
train_dataset)
self._coverage_metrics = ModelCoverageMetrics(target_model, neuron_num, segmented_num, train_dataset)
# Allowed mutate strategies so far.
self._strategies = {'Contrast': Contrast,
'Brightness': Brightness,
@@ -151,30 +146,19 @@ class Fuzzer:
'PGD': ProjectedGradientDescent,
'MDIIM': MomentumDiverseInputIterativeMethod}
self._affine_trans_list = ['Translate', 'Scale', 'Shear', 'Rotate']
self._pixel_value_trans_list = ['Contrast', 'Brightness', 'Blur',
'Noise']
self._pixel_value_trans_list = ['Contrast', 'Brightness', 'Blur', 'Noise']
self._attacks_list = ['FGSM', 'PGD', 'MDIIM']
self._attack_param_checklists = {
'FGSM': {'eps': {'dtype': [float],
'range': [0, 1]},
'alpha': {'dtype': [float],
'range': [0, 1]},
'FGSM': {'eps': {'dtype': [float], 'range': [0, 1]},
'alpha': {'dtype': [float], 'range': [0, 1]},
'bounds': {'dtype': [tuple]}},
'PGD': {'eps': {'dtype': [float],
'range': [0, 1]},
'eps_iter': {
'dtype': [float],
'range': [0, 1]},
'nb_iter': {'dtype': [int],
'range': [0, 100000]},
'PGD': {'eps': {'dtype': [float], 'range': [0, 1]},
'eps_iter': {'dtype': [float], 'range': [0, 1]},
'nb_iter': {'dtype': [int], 'range': [0, 100000]},
'bounds': {'dtype': [tuple]}},
'MDIIM': {'eps': {'dtype': [float],
'range': [0, 1]},
'norm_level': {'dtype': [str, int],
'range': [1, 2, '1', '2', 'l1', 'l2',
'inf', 'np.inf']},
'prob': {'dtype': [float],
'range': [0, 1]},
'MDIIM': {'eps': {'dtype': [float], 'range': [0, 1]},
'norm_level': {'dtype': [str, int], 'range': [1, 2, '1', '2', 'l1', 'l2', 'inf', 'np.inf']},
'prob': {'dtype': [float], 'range': [0, 1]},
'bounds': {'dtype': [tuple]}}}

def fuzzing(self, mutate_config, initial_seeds, coverage_metric='KMNC',
@@ -239,13 +223,11 @@ class Fuzzer:
# Check parameters.
eval_metrics_ = _check_eval_metrics(eval_metrics)
if coverage_metric not in ['KMNC', 'NBC', 'SNAC']:
msg = "coverage_metric must be in ['KMNC', 'NBC', 'SNAC'], " \
"but got {}.".format(coverage_metric)
msg = "coverage_metric must be in ['KMNC', 'NBC', 'SNAC'], but got {}.".format(coverage_metric)
LOGGER.error(TAG, msg)
raise ValueError(msg)
max_iters = check_int_positive('max_iters', max_iters)
mutate_num_per_seed = check_int_positive('mutate_num_per_seed',
mutate_num_per_seed)
mutate_num_per_seed = check_int_positive('mutate_num_per_seed', mutate_num_per_seed)
mutate_config = self._check_mutate_config(mutate_config)
mutates = self._init_mutates(mutate_config)

@@ -276,39 +258,30 @@ class Fuzzer:
mutate_config,
mutate_num_per_seed)
# Calculate the coverages and predictions of generated samples.
coverages, predicts = self._get_coverages_and_predict(mutate_samples,
coverage_metric)
coverages, predicts = self._get_coverages_and_predict(mutate_samples, coverage_metric)
coverage_gains = _coverage_gains(coverages)
for mutate, cov, pred, strategy in zip(mutate_samples,
coverage_gains,
predicts, mutate_strategies):
for mutate, cov, pred, strategy in zip(mutate_samples, coverage_gains, predicts, mutate_strategies):
fuzz_samples.append(mutate[0])
true_labels.append(mutate[1])
fuzz_preds.append(pred)
fuzz_strategies.append(strategy)
# if the mutate samples has coverage gains add this samples in
# the initial seeds to guide new mutates.
# the initial_seeds to guide new mutates.
if cov > 0:
initial_seeds.append(mutate)
seed, initial_seeds = _select_next(initial_seeds)
iter_num += 1
metrics_report = None
if eval_metrics_ is not None:
metrics_report = self._evaluate(fuzz_samples,
true_labels,
fuzz_preds,
fuzz_strategies,
eval_metrics_)
metrics_report = self._evaluate(fuzz_samples, true_labels, fuzz_preds, fuzz_strategies, eval_metrics_)
return fuzz_samples, true_labels, fuzz_preds, fuzz_strategies, metrics_report

def _get_coverages_and_predict(self, mutate_samples,
coverage_metric="KNMC"):
def _get_coverages_and_predict(self, mutate_samples, coverage_metric="KNMC"):
""" Calculate the coverages and predictions of generated samples."""
samples = [s[0] for s in mutate_samples]
samples = np.array(samples)
coverages = []
predictions = self._target_model.predict(
Tensor(samples.astype(np.float32)))
predictions = self._target_model.predict(Tensor(samples.astype(np.float32)))
predictions = predictions.asnumpy()
for index in range(len(samples)):
mutate = samples[:index + 1]
@@ -349,10 +322,8 @@ class Fuzzer:
mutate_sample = transform.transform(seed[0])
else:
for param_name in selected_param:
transform.__setattr__('_' + str(param_name),
selected_param[param_name])
mutate_sample = transform.generate(np.array([seed[0].astype(np.float32)]),
np.array([seed[1]]))[0]
transform.__setattr__('_' + str(param_name), selected_param[param_name])
mutate_sample = transform.generate(np.array([seed[0].astype(np.float32)]), np.array([seed[1]]))[0]
if method not in self._pixel_value_trans_list:
only_pixel_trans = 1
mutate_sample = [mutate_sample, seed[1], only_pixel_trans]
@@ -372,8 +343,7 @@ class Fuzzer:
for config in mutate_config:
check_param_type("config", config, dict)
if set(config.keys()) != {'method', 'params'}:
msg = "The key of each config must be in ('method', 'params'), " \
"but got {}.".format(set(config.keys()))
msg = "The key of each config must be in ('method', 'params'), but got {}.".format(set(config.keys()))
LOGGER.error(TAG, msg)
raise KeyError(msg)

@@ -382,8 +352,7 @@ class Fuzzer:

# Method must be in the optional range.
if method not in self._strategies.keys():
msg = "Config methods must be in {}, but got {}." \
.format(self._strategies.keys(), method)
msg = "Config methods must be in {}, but got {}.".format(self._strategies.keys(), method)
LOGGER.error(TAG, msg)
raise ValueError(msg)

@@ -401,8 +370,7 @@ class Fuzzer:
# Methods in `metate_config` should at least have one in the type of
# pixel value based transform methods.
if not has_pixel_trans:
msg = "mutate methods in mutate_config should at least have one " \
"in {}".format(self._pixel_value_trans_list)
msg = "mutate methods in mutate_config should at least have one in {}".format(self._pixel_value_trans_list)
raise ValueError(msg)

return mutate_config
@@ -424,11 +392,9 @@ class Fuzzer:
'but got its length as{}'.format(len(bounds))
raise ValueError(msg)
for bound_value in bounds:
_ = check_param_multi_types('bound', bound_value,
[int, float])
_ = check_param_multi_types('bound', bound_value, [int, float])
if bounds[0] >= bounds[1]:
msg = "upper bound must more than lower bound, " \
"but upper bound got {}, lower bound " \
msg = "upper bound must more than lower bound, but upper bound got {}, lower bound " \
"got {}".format(bounds[0], bounds[1])
raise ValueError(msg)
elif param_name == 'norm_level':
@@ -437,10 +403,7 @@ class Fuzzer:
allow_type = self._attack_param_checklists[method][param_name]['dtype']
allow_range = self._attack_param_checklists[method][param_name]['range']
_ = check_param_multi_types(str(param_name), param_value, allow_type)
_ = check_param_in_range(str(param_name),
param_value,
allow_range[0],
allow_range[1])
_ = check_param_in_range(str(param_name), param_value, allow_range[0], allow_range[1])

def _init_mutates(self, mutate_config):
""" Check whether the mutate_config meet the specification."""
@@ -454,8 +417,7 @@ class Fuzzer:
loss_fn = self._target_model._loss_fn
if loss_fn is None:
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
mutates[method] = self._strategies[method](network,
loss_fn=loss_fn)
mutates[method] = self._strategies[method](network, loss_fn=loss_fn)
return mutates

def _evaluate(self, fuzz_samples, true_labels, fuzz_preds,
@@ -497,8 +459,7 @@ class Fuzzer:
metrics_report['Attack_success_rate'] = attack_success_rate

if metrics == 'auto' or 'kmnc' in metrics or 'nbc' in metrics or 'snac' in metrics:
self._coverage_metrics.calculate_coverage(
np.array(fuzz_samples).astype(np.float32))
self._coverage_metrics.calculate_coverage(np.array(fuzz_samples).astype(np.float32))

if metrics == 'auto' or 'kmnc' in metrics:
kmnc = self._coverage_metrics.get_kmnc()


+ 26
- 22
mindarmour/fuzz_testing/model_coverage_metrics.py View File

@@ -56,7 +56,9 @@ class ModelCoverageMetrics:
>>> train_images = np.random.random((10000, 1, 32, 32)).astype(np.float32)
>>> test_images = np.random.random((5000, 1, 32, 32)).astype(np.float32)
>>> model = Model(net)
>>> model_fuzz_test = ModelCoverageMetrics(model, 10, 1000, train_images)
>>> neuron_num = 10
>>> segmented_num = 1000
>>> model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images)
>>> model_fuzz_test.calculate_coverage(test_images)
>>> print('KMNC of this test is : %s', model_fuzz_test.get_kmnc())
>>> print('NBC of this test is : %s', model_fuzz_test.get_nbc())
@@ -68,16 +70,14 @@ class ModelCoverageMetrics:
self._segmented_num = check_int_positive('segmented_num', segmented_num)
self._neuron_num = check_int_positive('neuron_num', neuron_num)
if self._neuron_num > 1e+9:
msg = 'neuron_num should be less than 1e+10, otherwise a MemoryError' \
'would occur'
msg = 'neuron_num should be less than 1e+10, otherwise a MemoryError would occur'
LOGGER.error(TAG, msg)
raise ValueError(msg)
train_dataset = check_numpy_param('train_dataset', train_dataset)
self._lower_bounds = [np.inf]*self._neuron_num
self._upper_bounds = [-np.inf]*self._neuron_num
self._var = [0]*self._neuron_num
self._main_section_hits = [[0 for _ in range(self._segmented_num)] for _ in
range(self._neuron_num)]
self._main_section_hits = [[0 for _ in range(self._segmented_num)] for _ in range(self._neuron_num)]
self._lower_corner_hits = [0]*self._neuron_num
self._upper_corner_hits = [0]*self._neuron_num
self._bounds_get(train_dataset)
@@ -99,19 +99,16 @@ class ModelCoverageMetrics:
inputs = train_dataset[i*batch_size: (i + 1)*batch_size]
output = self._model.predict(Tensor(inputs)).asnumpy()
output_mat.append(output)
lower_compare_array = np.concatenate(
[output, np.array([self._lower_bounds])], axis=0)
lower_compare_array = np.concatenate([output, np.array([self._lower_bounds])], axis=0)
self._lower_bounds = np.min(lower_compare_array, axis=0)
upper_compare_array = np.concatenate(
[output, np.array([self._upper_bounds])], axis=0)
upper_compare_array = np.concatenate([output, np.array([self._upper_bounds])], axis=0)
self._upper_bounds = np.max(upper_compare_array, axis=0)
if batches == 0:
output = self._model.predict(Tensor(train_dataset)).asnumpy()
self._lower_bounds = np.min(output, axis=0)
self._upper_bounds = np.max(output, axis=0)
output_mat.append(output)
self._var = np.std(np.concatenate(np.array(output_mat), axis=0),
axis=0)
self._var = np.std(np.concatenate(np.array(output_mat), axis=0), axis=0)

def _sections_hits_count(self, dataset, intervals):
"""
@@ -119,8 +116,7 @@ class ModelCoverageMetrics:

Args:
dataset (numpy.ndarray): Testing data.
intervals (list[float]): Segmentation intervals of neurons'
outputs.
intervals (list[float]): Segmentation intervals of neurons' outputs.
"""
dataset = check_numpy_param('dataset', dataset)
batch_output = self._model.predict(Tensor(dataset)).asnumpy()
@@ -142,11 +138,12 @@ class ModelCoverageMetrics:
dataset (numpy.ndarray): Data for fuzz test.
bias_coefficient (Union[int, float]): The coefficient used
for changing the neurons' output boundaries. Default: 0.
batch_size (int): The number of samples in a predict batch.
Default: 32.
batch_size (int): The number of samples in a predict batch. Default: 32.

Examples:
>>> model_fuzz_test = ModelCoverageMetrics(model, 10000, 10, train_images)
>>> neuron_num = 10
>>> segmented_num = 1000
>>> model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images)
>>> model_fuzz_test.calculate_coverage(test_images)
"""

@@ -158,12 +155,13 @@ class ModelCoverageMetrics:
intervals = (self._upper_bounds - self._lower_bounds) / self._segmented_num
batches = dataset.shape[0] // batch_size
for i in range(batches):
self._sections_hits_count(
dataset[i*batch_size: (i + 1)*batch_size], intervals)
self._sections_hits_count(dataset[i*batch_size: (i + 1)*batch_size], intervals)

def get_kmnc(self):
"""
Get the metric of 'k-multisection neuron coverage'.
Get the metric of 'k-multisection neuron coverage'. KMNC measures how
thoroughly the given set of test inputs covers the range of neurons
output values derived from training dataset.

Returns:
float, the metric of 'k-multisection neuron coverage'.
@@ -176,7 +174,11 @@ class ModelCoverageMetrics:

def get_nbc(self):
"""
Get the metric of 'neuron boundary coverage'.
Get the metric of 'neuron boundary coverage' :math` NBC =(|UpperCornerNeuron|
+ |LowerCornerNeuron|)/(2*|N|)`, where :math`|N|` is the number of neurons,
NBC refers to the proportion of neurons whose neurons output value in
the test dataset exceeds the upper and lower bounds of the corresponding
neurons output value in the training dataset.

Returns:
float, the metric of 'neuron boundary coverage'.
@@ -184,13 +186,15 @@ class ModelCoverageMetrics:
Examples:
>>> model_fuzz_test.get_nbc()
"""
nbc = (np.sum(self._lower_corner_hits) + np.sum(
self._upper_corner_hits)) / (2*self._neuron_num)
nbc = (np.sum(self._lower_corner_hits) + np.sum(self._upper_corner_hits)) / (2*self._neuron_num)
return nbc

def get_snac(self):
"""
Get the metric of 'strong neuron activation coverage'.
:math` SNAC =|UpperCornerNeuron| / |N|`. SNAC refers to the proportion
of neurons whose neurons output value in the test set exceeds the upper
bounds of the corresponding neurons output value in the training set.

Returns:
float, the metric of 'strong neuron activation coverage'.


+ 7
- 2
tests/ut/python/fuzzing/test_coverage_metrics.py View File

@@ -69,8 +69,10 @@ def test_lenet_mnist_coverage_cpu():
model = Model(net)

# initialize fuzz test with training dataset
neuron_num = 10
segmented_num = 1000
training_data = (np.random.random((10000, 10))*20).astype(np.float32)
model_fuzz_test = ModelCoverageMetrics(model, 10, 1000, training_data)
model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, training_data)

# fuzz test with original test data
# get test data
@@ -103,8 +105,11 @@ def test_lenet_mnist_coverage_ascend():
model = Model(net)

# initialize fuzz test with training dataset
neuron_num = 10
segmented_num = 1000
training_data = (np.random.random((10000, 10))*20).astype(np.float32)
model_fuzz_test = ModelCoverageMetrics(model, 10, 1000, training_data)

model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, training_data,)

# fuzz test with original test data
# get test data


+ 8
- 4
tests/ut/python/fuzzing/test_fuzzer.py View File

@@ -101,8 +101,10 @@ def test_fuzzing_ascend():
'params': {'eps': [0.1, 0.2, 0.3], 'alpha': [0.1]}}
]
# initialize fuzz test with training dataset
neuron_num = 10
segmented_num = 1000
train_images = np.random.rand(32, 1, 32, 32).astype(np.float32)
model_coverage_test = ModelCoverageMetrics(model, 10, 1000, train_images)
model_coverage_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images)

# fuzz test with original test data
# get test data
@@ -121,7 +123,7 @@ def test_fuzzing_ascend():
LOGGER.info(TAG, 'KMNC of this test is : %s',
model_coverage_test.get_kmnc())

model_fuzz_test = Fuzzer(model, train_images, 10, 1000)
model_fuzz_test = Fuzzer(model, train_images, neuron_num, segmented_num)
_, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds)
print(metrics)

@@ -137,6 +139,8 @@ def test_fuzzing_cpu():
model = Model(net)
batch_size = 8
num_classe = 10
neuron_num = 10
segmented_num = 1000
mutate_config = [{'method': 'Blur',
'params': {'auto_param': [True]}},
{'method': 'Contrast',
@@ -148,7 +152,7 @@ def test_fuzzing_cpu():
]
# initialize fuzz test with training dataset
train_images = np.random.rand(32, 1, 32, 32).astype(np.float32)
model_coverage_test = ModelCoverageMetrics(model, 10, 1000, train_images)
model_coverage_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images)

# fuzz test with original test data
# get test data
@@ -167,6 +171,6 @@ def test_fuzzing_cpu():
LOGGER.info(TAG, 'KMNC of this test is : %s',
model_coverage_test.get_kmnc())

model_fuzz_test = Fuzzer(model, train_images, 10, 1000)
model_fuzz_test = Fuzzer(model, train_images, neuron_num, segmented_num)
_, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds)
print(metrics)

Loading…
Cancel
Save