@@ -63,9 +63,11 @@ def test_lenet_mnist_fuzzing(): | |||
images = data[0].astype(np.float32) | |||
train_images.append(images) | |||
train_images = np.concatenate(train_images, axis=0) | |||
neuron_num = 10 | |||
segmented_num = 1000 | |||
# initialize fuzz test with training dataset | |||
model_coverage_test = ModelCoverageMetrics(model, 10, 1000, train_images) | |||
model_coverage_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images) | |||
# fuzz test with original test data | |||
# get test data | |||
@@ -93,9 +95,8 @@ def test_lenet_mnist_fuzzing(): | |||
LOGGER.info(TAG, 'KMNC of this test is : %s', | |||
model_coverage_test.get_kmnc()) | |||
model_fuzz_test = Fuzzer(model, train_images, 10, 1000) | |||
_, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds, | |||
eval_metrics='auto') | |||
model_fuzz_test = Fuzzer(model, train_images, neuron_num, segmented_num) | |||
_, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds, eval_metrics='auto') | |||
if metrics: | |||
for key in metrics: | |||
LOGGER.info(TAG, key + ': %s', metrics[key]) | |||
@@ -36,48 +36,45 @@ def _bound(image, epislon): | |||
class NES(Attack): | |||
""" | |||
The class is an implementation of the Natural Evolutionary Strategies Attack, | |||
including three settings: Query-Limited setting, Partial-Information setting | |||
and Label-Only setting. | |||
The class is an implementation of the Natural Evolutionary Strategies Attack | |||
Method. NES uses natural evolutionary strategies to estimate gradients to | |||
improve query efficiency. NES covers three settings: Query-Limited setting, | |||
Partial-Information setting and Label-Only setting. In the Query-Limited | |||
setting, the attack has a limited number of queries to the target model but | |||
has access to the probabilities of all classes. In the Partial-Information | |||
setting, the attack only has access to the probabilities of the top-k classes. | |||
In the Label-Only setting, the attack only has access to a list of k inferred | |||
labels ordered by their predicted probabilities. In the Partial-Information | |||
setting and Label-Only setting, NES performs a targeted attack, so users need | |||
to call set_target_images to provide target images of the target classes. | |||
References: `Andrew Ilyas, Logan Engstrom, Anish Athalye, and Jessy Lin. | |||
Black-box adversarial attacks with limited queries and information. In | |||
ICML, July 2018 <https://arxiv.org/abs/1804.08598>`_ | |||
Args: | |||
model (BlackModel): Target model. | |||
scene (str): Scene in 'Label_Only', 'Partial_Info' or | |||
'Query_Limit'. | |||
max_queries (int): Maximum query numbers to generate an adversarial | |||
example. Default: 500000. | |||
top_k (int): For Partial-Info or Label-Only setting, indicating how | |||
much (Top-k) information is available for the attacker. For | |||
Query-Limited setting, this input should be set as -1. Default: -1. | |||
model (BlackModel): Target model to be attacked. | |||
scene (str): Scene in 'Label_Only', 'Partial_Info' or 'Query_Limit'. | |||
max_queries (int): Maximum number of queries to generate an adversarial example. Default: 10000. | |||
top_k (int): For Partial-Info or Label-Only setting, indicating how much (Top-k) information is | |||
available for the attacker. For Query-Limited setting, this input should be set as -1. Default: -1. | |||
num_class (int): Number of classes in dataset. Default: 10. | |||
batch_size (int): Batch size. Default: 128. | |||
epsilon (float): Maximum perturbation allowed in attack. Default: 0.3. | |||
samples_per_draw (int): Number of samples draw in antithetic sampling. | |||
Default: 96. | |||
samples_per_draw (int): Number of samples drawn in antithetic sampling. Default: 128. | |||
momentum (float): Momentum. Default: 0.9. | |||
learning_rate (float): Learning rate. Default: 1e-3. | |||
max_lr (float): Max Learning rate. Default: 5e-2. | |||
min_lr (float): Min Learning rate. Default: 5e-4. | |||
sigma (float): Step size of random noise. Default: 1e-3. | |||
plateau_length (int): Length of plateau used in Annealing algorithm. | |||
Default: 20. | |||
plateau_drop (float): Drop of plateau used in Annealing algorithm. | |||
Default: 2.0. | |||
plateau_length (int): Length of plateau used in Annealing algorithm. Default: 20. | |||
plateau_drop (float): Drop of plateau used in Annealing algorithm. Default: 2.0. | |||
adv_thresh (float): Threshold of adversarial. Default: 0.25. | |||
zero_iters (int): Number of points to use for the proxy score. | |||
Default: 10. | |||
starting_eps (float): Starting epsilon used in Label-Only setting. | |||
Default: 1.0. | |||
starting_delta_eps (float): Delta epsilon used in Label-Only setting. | |||
Default: 0.5. | |||
label_only_sigma (float): Sigma used in Label-Only setting. | |||
Default: 1e-3. | |||
conservative (int): Conservation used in epsilon decay, it will | |||
increase if no convergence. Default: 2. | |||
zero_iters (int): Number of points to use for the proxy score. Default: 10. | |||
starting_eps (float): Starting epsilon used in Label-Only setting. Default: 1.0. | |||
starting_delta_eps (float): Delta epsilon used in Label-Only setting. Default: 0.5. | |||
label_only_sigma (float): Sigma used in Label-Only setting. Default: 1e-3. | |||
conservative (int): Conservation used in epsilon decay, it will increase if no convergence. Default: 2. | |||
sparse (bool): If True, input labels are sparse-encoded. If False, | |||
input labels are one-hot-encoded. Default: True. | |||
@@ -94,13 +91,10 @@ class NES(Attack): | |||
>>> tag, adv, queries = nes_instance.generate([initial_img], [target_class]) | |||
""" | |||
def __init__(self, model, scene, max_queries=10000, top_k=-1, num_class=10, | |||
batch_size=128, epsilon=0.3, samples_per_draw=128, | |||
momentum=0.9, learning_rate=1e-3, max_lr=5e-2, min_lr=5e-4, | |||
sigma=1e-3, plateau_length=20, plateau_drop=2.0, | |||
adv_thresh=0.25, zero_iters=10, starting_eps=1.0, | |||
starting_delta_eps=0.5, label_only_sigma=1e-3, conservative=2, | |||
sparse=True): | |||
def __init__(self, model, scene, max_queries=10000, top_k=-1, num_class=10, batch_size=128, epsilon=0.3, | |||
samples_per_draw=128, momentum=0.9, learning_rate=1e-3, max_lr=5e-2, min_lr=5e-4, sigma=1e-3, | |||
plateau_length=20, plateau_drop=2.0, adv_thresh=0.25, zero_iters=10, starting_eps=1.0, | |||
starting_delta_eps=0.5, label_only_sigma=1e-3, conservative=2, sparse=True): | |||
super(NES, self).__init__() | |||
self._model = check_model('model', model, BlackModel) | |||
self._scene = scene | |||
@@ -108,17 +102,14 @@ class NES(Attack): | |||
self._max_queries = check_int_positive('max_queries', max_queries) | |||
self._num_class = check_int_positive('num_class', num_class) | |||
self._batch_size = check_int_positive('batch_size', batch_size) | |||
self._samples_per_draw = check_int_positive('samples_per_draw', | |||
samples_per_draw) | |||
self._samples_per_draw = check_int_positive('samples_per_draw', samples_per_draw) | |||
self._goal_epsilon = check_value_positive('epsilon', epsilon) | |||
self._momentum = check_value_positive('momentum', momentum) | |||
self._learning_rate = check_value_positive('learning_rate', | |||
learning_rate) | |||
self._learning_rate = check_value_positive('learning_rate', learning_rate) | |||
self._max_lr = check_value_positive('max_lr', max_lr) | |||
self._min_lr = check_value_positive('min_lr', min_lr) | |||
self._sigma = check_value_positive('sigma', sigma) | |||
self._plateau_length = check_int_positive('plateau_length', | |||
plateau_length) | |||
self._plateau_length = check_int_positive('plateau_length', plateau_length) | |||
self._plateau_drop = check_value_positive('plateau_drop', plateau_drop) | |||
# partial information arguments | |||
self._k = top_k | |||
@@ -126,10 +117,8 @@ class NES(Attack): | |||
# label only arguments | |||
self._zero_iters = check_int_positive('zero_iters', zero_iters) | |||
self._starting_eps = check_value_positive('starting_eps', starting_eps) | |||
self._starting_delta_eps = check_value_positive('starting_delta_eps', | |||
starting_delta_eps) | |||
self._label_only_sigma = check_value_positive('label_only_sigma', | |||
label_only_sigma) | |||
self._starting_delta_eps = check_value_positive('starting_delta_eps', starting_delta_eps) | |||
self._label_only_sigma = check_value_positive('label_only_sigma', label_only_sigma) | |||
self._conservative = check_int_positive('conservative', conservative) | |||
self._sparse = check_param_type('sparse', sparse, bool) | |||
self.target_imgs = None | |||
@@ -152,39 +141,32 @@ class NES(Attack): | |||
- numpy.ndarray, query times for each sample. | |||
Raises: | |||
ValueError: If the top_k less than 0 in Label-Only or Partial-Info | |||
setting. | |||
ValueError: If the target_imgs is None in Label-Only or | |||
Partial-Info setting. | |||
ValueError: If scene is not in ['Label_Only', 'Partial_Info', | |||
'Query_Limit'] | |||
ValueError: If top_k is less than 1 in Label-Only or Partial-Info setting. | |||
ValueError: If target_imgs is None in Label-Only or Partial-Info setting. | |||
ValueError: If scene is not in ['Label_Only', 'Partial_Info', 'Query_Limit']. | |||
Examples: | |||
>>> advs = attack.generate([[0.2, 0.3, 0.4], [0.3, 0.3, 0.2]], | |||
>>> [1, 2]) | |||
""" | |||
inputs, labels = check_pair_numpy_param('inputs', inputs, | |||
'labels', labels) | |||
inputs, labels = check_pair_numpy_param('inputs', inputs, 'labels', labels) | |||
if not self._sparse: | |||
labels = np.argmax(labels, axis=1) | |||
if self._scene == 'Label_Only' or self._scene == 'Partial_Info': | |||
if self._k < 0: | |||
msg = "In 'Label_Only' or 'Partial_Info' mode, " \ | |||
"'top_k' must more than 0." | |||
if self._k < 1: | |||
msg = "In 'Label_Only' or 'Partial_Info' mode, 'top_k' must more than 0." | |||
LOGGER.error(TAG, msg) | |||
raise ValueError(msg) | |||
if self.target_imgs is None: | |||
msg = "In 'Label_Only' or 'Partial_Info' mode, " \ | |||
"'target_imgs' must be set." | |||
msg = "In 'Label_Only' or 'Partial_Info' mode, 'target_imgs' must be set." | |||
LOGGER.error(TAG, msg) | |||
raise ValueError(msg) | |||
elif self._scene == 'Query_Limit': | |||
self._k = self._num_class | |||
else: | |||
msg = "scene must be string in 'Label_Only', " \ | |||
"'Partial_Info' or 'Query_Limit' " | |||
msg = "scene must be string in 'Label_Only', 'Partial_Info' or 'Query_Limit' " | |||
LOGGER.error(TAG, msg) | |||
raise ValueError(msg) | |||
@@ -201,7 +183,7 @@ class NES(Attack): | |||
def set_target_images(self, target_images): | |||
""" | |||
Set target samples for target attack. | |||
Set target samples for target attack in the Partial-Info setting or Label-Only setting. | |||
Args: | |||
target_images (numpy.ndarray): Target samples for target attack. | |||
@@ -253,8 +235,8 @@ class NES(Attack): | |||
return True, adv, num_queries | |||
# antithetic sampling noise | |||
noise_pos = np.random.normal( | |||
size=(self._batch_size // 2,) + origin_image.shape) | |||
size = (self._batch_size // 2,) + origin_image.shape | |||
noise_pos = np.random.normal(size=size) | |||
noise = np.concatenate((noise_pos, -noise_pos), axis=0) | |||
eval_points = adv + self._sigma*noise | |||
@@ -274,8 +256,7 @@ class NES(Attack): | |||
while current_lr >= self._min_lr: | |||
# in partial information only or label only setting | |||
if self._scene == 'Label_Only' or self._scene == 'Partial_Info': | |||
proposed_epsilon = max(self._epsilon - prop_delta_eps, | |||
goal_epsilon) | |||
proposed_epsilon = max(self._epsilon - prop_delta_eps, goal_epsilon) | |||
lower, upper = _bound(origin_image, proposed_epsilon) | |||
proposed_adv = adv - current_lr*np.sign(gradient) | |||
proposed_adv = np.clip(proposed_adv, lower, upper) | |||
@@ -288,23 +269,19 @@ class NES(Attack): | |||
delta_epsilon = max(prop_delta_eps, 0.1) | |||
last_ls = [] | |||
adv = proposed_adv | |||
self._epsilon = max( | |||
self._epsilon - prop_delta_eps / self._conservative, | |||
goal_epsilon) | |||
self._epsilon = self._epsilon - prop_delta_eps / self._conservative | |||
self._epsilon = max(self._epsilon, goal_epsilon) | |||
break | |||
elif current_lr >= self._min_lr*2: | |||
current_lr = current_lr / 2 | |||
LOGGER.debug(TAG, "backtracking learning rate to %.3f", | |||
current_lr) | |||
LOGGER.debug(TAG, "backtracking learning rate to %.3f", current_lr) | |||
else: | |||
prop_delta_eps = prop_delta_eps / 2 | |||
if prop_delta_eps < 2e-3: | |||
LOGGER.debug(TAG, "Did not converge.") | |||
return False, adv, num_queries | |||
current_lr = self._max_lr | |||
LOGGER.debug(TAG, | |||
"backtracking epsilon to %.3f", | |||
self._epsilon - prop_delta_eps) | |||
LOGGER.debug(TAG, "backtracking epsilon to %.3f", self._epsilon - prop_delta_eps) | |||
# update the number of queries | |||
if self._scene == 'Label_Only': | |||
@@ -323,12 +300,10 @@ class NES(Attack): | |||
def _plateau_annealing(self, last_loss):
    """
    Anneal the maximum learning rate when the recent losses have plateaued.

    Only the last `self._plateau_length` losses are considered. When that
    window is full and its newest loss is greater than its oldest loss,
    `self._max_lr` is divided by `self._plateau_drop` (never going below
    `self._min_lr`) and the loss history is cleared.

    Args:
        last_loss (list): Recent loss values, oldest first. May be empty.

    Returns:
        list, the trimmed loss history, or an empty list if a plateau
        was detected.
    """
    window = last_loss[-self._plateau_length:]
    # Test the window length before indexing: with an empty history the
    # original order (`window[-1] > window[0]` first) raised IndexError.
    if len(window) == self._plateau_length and window[-1] > window[0]:
        if self._max_lr > self._min_lr:
            LOGGER.debug(TAG, "Annealing max learning rate.")
            self._max_lr = max(self._max_lr / self._plateau_drop, self._min_lr)
        # NOTE(review): indentation was lost in this view; the history reset
        # is placed at the plateau-detected level (not only when max_lr is
        # actually lowered) -- confirm against the repository file.
        window = []
    return window
@@ -346,8 +321,7 @@ class NES(Attack): | |||
Loss in Query-Limit setting. | |||
""" | |||
LOGGER.debug(TAG, 'enter the function _query_limit_loss().') | |||
loss = self._softmax_cross_entropy_with_logit( | |||
self._model.predict(eval_points)) | |||
loss = self._softmax_cross_entropy_with_logit(self._model.predict(eval_points)) | |||
return loss, noise | |||
@@ -359,8 +333,7 @@ class NES(Attack): | |||
logit = self._model.predict(eval_points) | |||
loss = np.sort(softmax(logit, axis=1))[:, -self._k:] | |||
inds = np.argsort(logit)[:, -self._k:] | |||
good_loss = np.where(np.equal(inds, self.target_class), loss, | |||
np.zeros(np.shape(inds))) | |||
good_loss = np.where(np.equal(inds, self.target_class), loss, np.zeros(np.shape(inds))) | |||
good_loss = np.max(good_loss, axis=1) | |||
losses = -np.log(good_loss) | |||
return losses, noise | |||
@@ -370,22 +343,16 @@ class NES(Attack): | |||
Loss in Label-Only setting. | |||
""" | |||
LOGGER.debug(TAG, 'enter the function _label_only_loss().') | |||
tiled_points = np.tile(np.expand_dims(eval_points, 0), | |||
[self._zero_iters, | |||
*[1]*len(eval_points.shape)]) | |||
noised_eval_im = tiled_points \ | |||
+ np.random.randn(self._zero_iters, | |||
self._batch_size, | |||
*origin_image.shape) \ | |||
*self._label_only_sigma | |||
noised_eval_im = np.reshape(noised_eval_im, ( | |||
self._zero_iters*self._batch_size, *origin_image.shape)) | |||
tiled_points = np.tile(np.expand_dims(eval_points, 0), [self._zero_iters, *[1]*len(eval_points.shape)]) | |||
noised_eval_im = tiled_points + np.random.randn(self._zero_iters, | |||
self._batch_size, | |||
*origin_image.shape)*self._label_only_sigma | |||
noised_eval_im = np.reshape(noised_eval_im, (self._zero_iters*self._batch_size, *origin_image.shape)) | |||
logits = self._model.predict(noised_eval_im) | |||
inds = np.argsort(logits)[:, -self._k:] | |||
real_inds = np.reshape(inds, (self._zero_iters, self._batch_size, -1)) | |||
rank_range = np.arange(1, self._k + 1, 1, dtype=np.float32) | |||
tiled_rank_range = np.tile(np.reshape(rank_range, (1, 1, self._k)), | |||
[self._zero_iters, self._batch_size, 1]) | |||
tiled_rank_range = np.tile(np.reshape(rank_range, (1, 1, self._k)), [self._zero_iters, self._batch_size, 1]) | |||
batches_in = np.where(np.equal(real_inds, self.target_class), | |||
tiled_rank_range, | |||
np.zeros(np.shape(tiled_rank_range))) | |||
@@ -408,16 +375,13 @@ class NES(Attack): | |||
grads = [] | |||
for _ in range(self._samples_per_draw // self._batch_size): | |||
if self._scene == 'Label_Only': | |||
loss, np_noise = self._label_only_loss(origin_image, | |||
eval_points, | |||
noise) | |||
loss, np_noise = self._label_only_loss(origin_image, eval_points, noise) | |||
elif self._scene == 'Partial_Info': | |||
loss, np_noise = self._partial_info_loss(eval_points, noise) | |||
else: | |||
loss, np_noise = self._query_limit_loss(eval_points, noise) | |||
# only support three channel images | |||
losses_tiled = np.tile(np.reshape(loss, (-1, 1, 1, 1)), | |||
(1,) + origin_image.shape) | |||
losses_tiled = np.tile(np.reshape(loss, (-1, 1, 1, 1)), (1,) + origin_image.shape) | |||
grad = np.mean(losses_tiled*np_noise, axis=0) / self._sigma | |||
grads.append(grad) | |||
@@ -44,14 +44,15 @@ def _select_next(initial_seeds): | |||
def _coverage_gains(coverages):
    """
    Calculate the coverage gains of mutated samples.

    The gain of a sample is its coverage value minus the coverage value of
    the previous sample; the first sample is compared against 0.

    Args:
        coverages (Union[list, tuple, numpy.ndarray]): One-dimensional
            sequence of coverage values, one per mutated sample.

    Returns:
        numpy.ndarray, the gain of each sample, same length as `coverages`.
    """
    # np.diff with prepend=0 computes coverages[i] - coverages[i - 1] with an
    # implicit leading 0. Unlike `[0] + coverages[:-1]`, it also works for
    # tuples (which raised TypeError) and for ndarrays (where `+` broadcast-
    # added instead of concatenating, producing wrong gains).
    return np.diff(np.asarray(coverages), prepend=0)
def _is_trans_valid(seed, mutate_sample): | |||
""" Check a mutated sample is valid. If the number of changed pixels in | |||
""" | |||
Check a mutated sample is valid. If the number of changed pixels in | |||
a seed is less than pixels_change_rate*size(seed), this mutate is valid. | |||
Else check the infinite norm of seed changes, if the value of the | |||
infinite norm less than pixel_value_change_rate*255, this mutate is | |||
@@ -80,21 +81,18 @@ def _check_eval_metrics(eval_metrics): | |||
available_metrics = ['accuracy', 'attack_success_rate', 'kmnc', 'nbc', 'snac'] | |||
for elem in eval_metrics: | |||
if elem not in available_metrics: | |||
msg = 'metric in list `eval_metrics` must be in {}, but ' \ | |||
'got {}.'.format(available_metrics, elem) | |||
msg = 'metric in list `eval_metrics` must be in {}, but got {}.'.format(available_metrics, elem) | |||
LOGGER.error(TAG, msg) | |||
raise ValueError(msg) | |||
eval_metrics_.append(elem.lower()) | |||
elif isinstance(eval_metrics, str): | |||
if eval_metrics != 'auto': | |||
msg = "the value of `eval_metrics` must be 'auto' if it's type " \ | |||
"is str, but got {}.".format(eval_metrics) | |||
msg = "the value of `eval_metrics` must be 'auto' if it's type is str, but got {}.".format(eval_metrics) | |||
LOGGER.error(TAG, msg) | |||
raise ValueError(msg) | |||
eval_metrics_ = 'auto' | |||
else: | |||
msg = "the type of `eval_metrics` must be str, list or tuple, " \ | |||
"but got {}.".format(type(eval_metrics)) | |||
msg = "the type of `eval_metrics` must be str, list or tuple, but got {}.".format(type(eval_metrics)) | |||
LOGGER.error(TAG, msg) | |||
raise TypeError(msg) | |||
return eval_metrics_ | |||
@@ -109,11 +107,9 @@ class Fuzzer: | |||
Args: | |||
target_model (Model): Target fuzz model. | |||
train_dataset (numpy.ndarray): Training dataset used for determining | |||
the neurons' output boundaries. | |||
train_dataset (numpy.ndarray): Training dataset used for determining the neurons' output boundaries. | |||
neuron_num (int): The number of testing neurons. | |||
segmented_num (int): The number of segmented sections of neurons' | |||
output intervals. Default: 1000. | |||
segmented_num (int): The number of segmented sections of neurons' output intervals. Default: 1000. | |||
Examples: | |||
>>> net = Net() | |||
@@ -126,7 +122,9 @@ class Fuzzer: | |||
>>> {'method': 'FGSM', | |||
>>> 'params': {'eps': [0.1, 0.2, 0.3], 'alpha': [0.1]}}] | |||
>>> train_images = np.random.rand(32, 1, 32, 32).astype(np.float32) | |||
>>> model_fuzz_test = Fuzzer(model, train_images, 10, 1000) | |||
>>> neuron_num = 10 | |||
>>> segmented_num = 1000 | |||
>>> model_fuzz_test = Fuzzer(model, train_images, neuron_num, segmented_num) | |||
>>> samples, labels, preds, strategies, report = model_fuzz_test.fuzzing(mutate_config, initial_seeds) | |||
""" | |||
@@ -134,10 +132,7 @@ class Fuzzer: | |||
segmented_num=1000): | |||
self._target_model = check_model('model', target_model, Model) | |||
train_dataset = check_numpy_param('train_dataset', train_dataset) | |||
self._coverage_metrics = ModelCoverageMetrics(target_model, | |||
neuron_num, | |||
segmented_num, | |||
train_dataset) | |||
self._coverage_metrics = ModelCoverageMetrics(target_model, neuron_num, segmented_num, train_dataset) | |||
# Allowed mutate strategies so far. | |||
self._strategies = {'Contrast': Contrast, | |||
'Brightness': Brightness, | |||
@@ -151,30 +146,19 @@ class Fuzzer: | |||
'PGD': ProjectedGradientDescent, | |||
'MDIIM': MomentumDiverseInputIterativeMethod} | |||
self._affine_trans_list = ['Translate', 'Scale', 'Shear', 'Rotate'] | |||
self._pixel_value_trans_list = ['Contrast', 'Brightness', 'Blur', | |||
'Noise'] | |||
self._pixel_value_trans_list = ['Contrast', 'Brightness', 'Blur', 'Noise'] | |||
self._attacks_list = ['FGSM', 'PGD', 'MDIIM'] | |||
self._attack_param_checklists = { | |||
'FGSM': {'eps': {'dtype': [float], | |||
'range': [0, 1]}, | |||
'alpha': {'dtype': [float], | |||
'range': [0, 1]}, | |||
'FGSM': {'eps': {'dtype': [float], 'range': [0, 1]}, | |||
'alpha': {'dtype': [float], 'range': [0, 1]}, | |||
'bounds': {'dtype': [tuple]}}, | |||
'PGD': {'eps': {'dtype': [float], | |||
'range': [0, 1]}, | |||
'eps_iter': { | |||
'dtype': [float], | |||
'range': [0, 1]}, | |||
'nb_iter': {'dtype': [int], | |||
'range': [0, 100000]}, | |||
'PGD': {'eps': {'dtype': [float], 'range': [0, 1]}, | |||
'eps_iter': {'dtype': [float], 'range': [0, 1]}, | |||
'nb_iter': {'dtype': [int], 'range': [0, 100000]}, | |||
'bounds': {'dtype': [tuple]}}, | |||
'MDIIM': {'eps': {'dtype': [float], | |||
'range': [0, 1]}, | |||
'norm_level': {'dtype': [str, int], | |||
'range': [1, 2, '1', '2', 'l1', 'l2', | |||
'inf', 'np.inf']}, | |||
'prob': {'dtype': [float], | |||
'range': [0, 1]}, | |||
'MDIIM': {'eps': {'dtype': [float], 'range': [0, 1]}, | |||
'norm_level': {'dtype': [str, int], 'range': [1, 2, '1', '2', 'l1', 'l2', 'inf', 'np.inf']}, | |||
'prob': {'dtype': [float], 'range': [0, 1]}, | |||
'bounds': {'dtype': [tuple]}}} | |||
def fuzzing(self, mutate_config, initial_seeds, coverage_metric='KMNC', | |||
@@ -239,13 +223,11 @@ class Fuzzer: | |||
# Check parameters. | |||
eval_metrics_ = _check_eval_metrics(eval_metrics) | |||
if coverage_metric not in ['KMNC', 'NBC', 'SNAC']: | |||
msg = "coverage_metric must be in ['KMNC', 'NBC', 'SNAC'], " \ | |||
"but got {}.".format(coverage_metric) | |||
msg = "coverage_metric must be in ['KMNC', 'NBC', 'SNAC'], but got {}.".format(coverage_metric) | |||
LOGGER.error(TAG, msg) | |||
raise ValueError(msg) | |||
max_iters = check_int_positive('max_iters', max_iters) | |||
mutate_num_per_seed = check_int_positive('mutate_num_per_seed', | |||
mutate_num_per_seed) | |||
mutate_num_per_seed = check_int_positive('mutate_num_per_seed', mutate_num_per_seed) | |||
mutate_config = self._check_mutate_config(mutate_config) | |||
mutates = self._init_mutates(mutate_config) | |||
@@ -276,39 +258,30 @@ class Fuzzer: | |||
mutate_config, | |||
mutate_num_per_seed) | |||
# Calculate the coverages and predictions of generated samples. | |||
coverages, predicts = self._get_coverages_and_predict(mutate_samples, | |||
coverage_metric) | |||
coverages, predicts = self._get_coverages_and_predict(mutate_samples, coverage_metric) | |||
coverage_gains = _coverage_gains(coverages) | |||
for mutate, cov, pred, strategy in zip(mutate_samples, | |||
coverage_gains, | |||
predicts, mutate_strategies): | |||
for mutate, cov, pred, strategy in zip(mutate_samples, coverage_gains, predicts, mutate_strategies): | |||
fuzz_samples.append(mutate[0]) | |||
true_labels.append(mutate[1]) | |||
fuzz_preds.append(pred) | |||
fuzz_strategies.append(strategy) | |||
# if the mutated sample has coverage gains, add this sample to | |||
# the initial seeds to guide new mutates. | |||
# the initial_seeds to guide new mutates. | |||
if cov > 0: | |||
initial_seeds.append(mutate) | |||
seed, initial_seeds = _select_next(initial_seeds) | |||
iter_num += 1 | |||
metrics_report = None | |||
if eval_metrics_ is not None: | |||
metrics_report = self._evaluate(fuzz_samples, | |||
true_labels, | |||
fuzz_preds, | |||
fuzz_strategies, | |||
eval_metrics_) | |||
metrics_report = self._evaluate(fuzz_samples, true_labels, fuzz_preds, fuzz_strategies, eval_metrics_) | |||
return fuzz_samples, true_labels, fuzz_preds, fuzz_strategies, metrics_report | |||
def _get_coverages_and_predict(self, mutate_samples, | |||
coverage_metric="KNMC"): | |||
def _get_coverages_and_predict(self, mutate_samples, coverage_metric="KNMC"): | |||
""" Calculate the coverages and predictions of generated samples.""" | |||
samples = [s[0] for s in mutate_samples] | |||
samples = np.array(samples) | |||
coverages = [] | |||
predictions = self._target_model.predict( | |||
Tensor(samples.astype(np.float32))) | |||
predictions = self._target_model.predict(Tensor(samples.astype(np.float32))) | |||
predictions = predictions.asnumpy() | |||
for index in range(len(samples)): | |||
mutate = samples[:index + 1] | |||
@@ -349,10 +322,8 @@ class Fuzzer: | |||
mutate_sample = transform.transform(seed[0]) | |||
else: | |||
for param_name in selected_param: | |||
transform.__setattr__('_' + str(param_name), | |||
selected_param[param_name]) | |||
mutate_sample = transform.generate(np.array([seed[0].astype(np.float32)]), | |||
np.array([seed[1]]))[0] | |||
transform.__setattr__('_' + str(param_name), selected_param[param_name]) | |||
mutate_sample = transform.generate(np.array([seed[0].astype(np.float32)]), np.array([seed[1]]))[0] | |||
if method not in self._pixel_value_trans_list: | |||
only_pixel_trans = 1 | |||
mutate_sample = [mutate_sample, seed[1], only_pixel_trans] | |||
@@ -372,8 +343,7 @@ class Fuzzer: | |||
for config in mutate_config: | |||
check_param_type("config", config, dict) | |||
if set(config.keys()) != {'method', 'params'}: | |||
msg = "The key of each config must be in ('method', 'params'), " \ | |||
"but got {}.".format(set(config.keys())) | |||
msg = "The key of each config must be in ('method', 'params'), but got {}.".format(set(config.keys())) | |||
LOGGER.error(TAG, msg) | |||
raise KeyError(msg) | |||
@@ -382,8 +352,7 @@ class Fuzzer: | |||
# Method must be in the optional range. | |||
if method not in self._strategies.keys(): | |||
msg = "Config methods must be in {}, but got {}." \ | |||
.format(self._strategies.keys(), method) | |||
msg = "Config methods must be in {}, but got {}.".format(self._strategies.keys(), method) | |||
LOGGER.error(TAG, msg) | |||
raise ValueError(msg) | |||
@@ -401,8 +370,7 @@ class Fuzzer: | |||
# Methods in `mutate_config` should include at least one | |||
# pixel value based transform method. | |||
if not has_pixel_trans: | |||
msg = "mutate methods in mutate_config should at least have one " \ | |||
"in {}".format(self._pixel_value_trans_list) | |||
msg = "mutate methods in mutate_config should at least have one in {}".format(self._pixel_value_trans_list) | |||
raise ValueError(msg) | |||
return mutate_config | |||
@@ -424,11 +392,9 @@ class Fuzzer: | |||
'but got its length as{}'.format(len(bounds)) | |||
raise ValueError(msg) | |||
for bound_value in bounds: | |||
_ = check_param_multi_types('bound', bound_value, | |||
[int, float]) | |||
_ = check_param_multi_types('bound', bound_value, [int, float]) | |||
if bounds[0] >= bounds[1]: | |||
msg = "upper bound must more than lower bound, " \ | |||
"but upper bound got {}, lower bound " \ | |||
msg = "upper bound must more than lower bound, but upper bound got {}, lower bound " \ | |||
"got {}".format(bounds[0], bounds[1]) | |||
raise ValueError(msg) | |||
elif param_name == 'norm_level': | |||
@@ -437,10 +403,7 @@ class Fuzzer: | |||
allow_type = self._attack_param_checklists[method][param_name]['dtype'] | |||
allow_range = self._attack_param_checklists[method][param_name]['range'] | |||
_ = check_param_multi_types(str(param_name), param_value, allow_type) | |||
_ = check_param_in_range(str(param_name), | |||
param_value, | |||
allow_range[0], | |||
allow_range[1]) | |||
_ = check_param_in_range(str(param_name), param_value, allow_range[0], allow_range[1]) | |||
def _init_mutates(self, mutate_config): | |||
""" Initialize the mutate strategy objects and return them keyed by method name.""" | |||
@@ -454,8 +417,7 @@ class Fuzzer: | |||
loss_fn = self._target_model._loss_fn | |||
if loss_fn is None: | |||
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False) | |||
mutates[method] = self._strategies[method](network, | |||
loss_fn=loss_fn) | |||
mutates[method] = self._strategies[method](network, loss_fn=loss_fn) | |||
return mutates | |||
def _evaluate(self, fuzz_samples, true_labels, fuzz_preds, | |||
@@ -497,8 +459,7 @@ class Fuzzer: | |||
metrics_report['Attack_success_rate'] = attack_success_rate | |||
if metrics == 'auto' or 'kmnc' in metrics or 'nbc' in metrics or 'snac' in metrics: | |||
self._coverage_metrics.calculate_coverage( | |||
np.array(fuzz_samples).astype(np.float32)) | |||
self._coverage_metrics.calculate_coverage(np.array(fuzz_samples).astype(np.float32)) | |||
if metrics == 'auto' or 'kmnc' in metrics: | |||
kmnc = self._coverage_metrics.get_kmnc() | |||
@@ -56,7 +56,9 @@ class ModelCoverageMetrics: | |||
>>> train_images = np.random.random((10000, 1, 32, 32)).astype(np.float32) | |||
>>> test_images = np.random.random((5000, 1, 32, 32)).astype(np.float32) | |||
>>> model = Model(net) | |||
>>> model_fuzz_test = ModelCoverageMetrics(model, 10, 1000, train_images) | |||
>>> neuron_num = 10 | |||
>>> segmented_num = 1000 | |||
>>> model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images) | |||
>>> model_fuzz_test.calculate_coverage(test_images) | |||
>>> print('KMNC of this test is : %s', model_fuzz_test.get_kmnc()) | |||
>>> print('NBC of this test is : %s', model_fuzz_test.get_nbc()) | |||
@@ -68,16 +70,14 @@ class ModelCoverageMetrics: | |||
self._segmented_num = check_int_positive('segmented_num', segmented_num) | |||
self._neuron_num = check_int_positive('neuron_num', neuron_num) | |||
if self._neuron_num > 1e+9: | |||
msg = 'neuron_num should be less than 1e+10, otherwise a MemoryError' \ | |||
'would occur' | |||
msg = 'neuron_num should be less than 1e+10, otherwise a MemoryError would occur' | |||
LOGGER.error(TAG, msg) | |||
raise ValueError(msg) | |||
train_dataset = check_numpy_param('train_dataset', train_dataset) | |||
self._lower_bounds = [np.inf]*self._neuron_num | |||
self._upper_bounds = [-np.inf]*self._neuron_num | |||
self._var = [0]*self._neuron_num | |||
self._main_section_hits = [[0 for _ in range(self._segmented_num)] for _ in | |||
range(self._neuron_num)] | |||
self._main_section_hits = [[0 for _ in range(self._segmented_num)] for _ in range(self._neuron_num)] | |||
self._lower_corner_hits = [0]*self._neuron_num | |||
self._upper_corner_hits = [0]*self._neuron_num | |||
self._bounds_get(train_dataset) | |||
@@ -99,19 +99,16 @@ class ModelCoverageMetrics: | |||
inputs = train_dataset[i*batch_size: (i + 1)*batch_size] | |||
output = self._model.predict(Tensor(inputs)).asnumpy() | |||
output_mat.append(output) | |||
lower_compare_array = np.concatenate( | |||
[output, np.array([self._lower_bounds])], axis=0) | |||
lower_compare_array = np.concatenate([output, np.array([self._lower_bounds])], axis=0) | |||
self._lower_bounds = np.min(lower_compare_array, axis=0) | |||
upper_compare_array = np.concatenate( | |||
[output, np.array([self._upper_bounds])], axis=0) | |||
upper_compare_array = np.concatenate([output, np.array([self._upper_bounds])], axis=0) | |||
self._upper_bounds = np.max(upper_compare_array, axis=0) | |||
if batches == 0: | |||
output = self._model.predict(Tensor(train_dataset)).asnumpy() | |||
self._lower_bounds = np.min(output, axis=0) | |||
self._upper_bounds = np.max(output, axis=0) | |||
output_mat.append(output) | |||
self._var = np.std(np.concatenate(np.array(output_mat), axis=0), | |||
axis=0) | |||
self._var = np.std(np.concatenate(np.array(output_mat), axis=0), axis=0) | |||
def _sections_hits_count(self, dataset, intervals): | |||
""" | |||
@@ -119,8 +116,7 @@ class ModelCoverageMetrics: | |||
Args: | |||
dataset (numpy.ndarray): Testing data. | |||
intervals (list[float]): Segmentation intervals of neurons' | |||
outputs. | |||
intervals (list[float]): Segmentation intervals of neurons' outputs. | |||
""" | |||
dataset = check_numpy_param('dataset', dataset) | |||
batch_output = self._model.predict(Tensor(dataset)).asnumpy() | |||
@@ -142,11 +138,12 @@ class ModelCoverageMetrics: | |||
dataset (numpy.ndarray): Data for fuzz test. | |||
bias_coefficient (Union[int, float]): The coefficient used | |||
for changing the neurons' output boundaries. Default: 0. | |||
batch_size (int): The number of samples in a predict batch. | |||
Default: 32. | |||
batch_size (int): The number of samples in a predict batch. Default: 32. | |||
Examples: | |||
>>> model_fuzz_test = ModelCoverageMetrics(model, 10000, 10, train_images) | |||
>>> neuron_num = 10 | |||
>>> segmented_num = 1000 | |||
>>> model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images) | |||
>>> model_fuzz_test.calculate_coverage(test_images) | |||
""" | |||
@@ -158,12 +155,13 @@ class ModelCoverageMetrics: | |||
intervals = (self._upper_bounds - self._lower_bounds) / self._segmented_num | |||
batches = dataset.shape[0] // batch_size | |||
for i in range(batches): | |||
self._sections_hits_count( | |||
dataset[i*batch_size: (i + 1)*batch_size], intervals) | |||
self._sections_hits_count(dataset[i*batch_size: (i + 1)*batch_size], intervals) | |||
def get_kmnc(self): | |||
""" | |||
Get the metric of 'k-multisection neuron coverage'. | |||
Get the metric of 'k-multisection neuron coverage'. KMNC measures how | |||
thoroughly the given set of test inputs covers the range of neurons' | |||
output values derived from the training dataset. | |||
Returns: | |||
float, the metric of 'k-multisection neuron coverage'. | |||
@@ -176,7 +174,11 @@ class ModelCoverageMetrics: | |||
def get_nbc(self): | |||
""" | |||
Get the metric of 'neuron boundary coverage'. | |||
Get the metric of 'neuron boundary coverage' :math:`NBC = (|UpperCornerNeuron| | |||
+ |LowerCornerNeuron|)/(2*|N|)`, where :math:`|N|` is the number of neurons. | |||
NBC refers to the proportion of neurons whose output values in | |||
the test dataset fall outside the upper and lower bounds of the corresponding | |||
neurons' output values in the training dataset. | |||
Returns: | |||
float, the metric of 'neuron boundary coverage'. | |||
@@ -184,13 +186,15 @@ class ModelCoverageMetrics: | |||
Examples: | |||
>>> model_fuzz_test.get_nbc() | |||
""" | |||
nbc = (np.sum(self._lower_corner_hits) + np.sum( | |||
self._upper_corner_hits)) / (2*self._neuron_num) | |||
nbc = (np.sum(self._lower_corner_hits) + np.sum(self._upper_corner_hits)) / (2*self._neuron_num) | |||
return nbc | |||
def get_snac(self): | |||
""" | |||
Get the metric of 'strong neuron activation coverage'. | |||
:math:`SNAC = |UpperCornerNeuron| / |N|`. SNAC refers to the proportion | |||
of neurons whose output values in the test set exceed the upper | |||
bounds of the corresponding neurons' output values in the training set. | |||
Returns: | |||
float, the metric of 'strong neuron activation coverage'. | |||
@@ -69,8 +69,10 @@ def test_lenet_mnist_coverage_cpu(): | |||
model = Model(net) | |||
# initialize fuzz test with training dataset | |||
neuron_num = 10 | |||
segmented_num = 1000 | |||
training_data = (np.random.random((10000, 10))*20).astype(np.float32) | |||
model_fuzz_test = ModelCoverageMetrics(model, 10, 1000, training_data) | |||
model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, training_data) | |||
# fuzz test with original test data | |||
# get test data | |||
@@ -103,8 +105,11 @@ def test_lenet_mnist_coverage_ascend(): | |||
model = Model(net) | |||
# initialize fuzz test with training dataset | |||
neuron_num = 10 | |||
segmented_num = 1000 | |||
training_data = (np.random.random((10000, 10))*20).astype(np.float32) | |||
model_fuzz_test = ModelCoverageMetrics(model, 10, 1000, training_data) | |||
model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, training_data,) | |||
# fuzz test with original test data | |||
# get test data | |||
@@ -101,8 +101,10 @@ def test_fuzzing_ascend(): | |||
'params': {'eps': [0.1, 0.2, 0.3], 'alpha': [0.1]}} | |||
] | |||
# initialize fuzz test with training dataset | |||
neuron_num = 10 | |||
segmented_num = 1000 | |||
train_images = np.random.rand(32, 1, 32, 32).astype(np.float32) | |||
model_coverage_test = ModelCoverageMetrics(model, 10, 1000, train_images) | |||
model_coverage_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images) | |||
# fuzz test with original test data | |||
# get test data | |||
@@ -121,7 +123,7 @@ def test_fuzzing_ascend(): | |||
LOGGER.info(TAG, 'KMNC of this test is : %s', | |||
model_coverage_test.get_kmnc()) | |||
model_fuzz_test = Fuzzer(model, train_images, 10, 1000) | |||
model_fuzz_test = Fuzzer(model, train_images, neuron_num, segmented_num) | |||
_, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds) | |||
print(metrics) | |||
@@ -137,6 +139,8 @@ def test_fuzzing_cpu(): | |||
model = Model(net) | |||
batch_size = 8 | |||
num_classe = 10 | |||
neuron_num = 10 | |||
segmented_num = 1000 | |||
mutate_config = [{'method': 'Blur', | |||
'params': {'auto_param': [True]}}, | |||
{'method': 'Contrast', | |||
@@ -148,7 +152,7 @@ def test_fuzzing_cpu(): | |||
] | |||
# initialize fuzz test with training dataset | |||
train_images = np.random.rand(32, 1, 32, 32).astype(np.float32) | |||
model_coverage_test = ModelCoverageMetrics(model, 10, 1000, train_images) | |||
model_coverage_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images) | |||
# fuzz test with original test data | |||
# get test data | |||
@@ -167,6 +171,6 @@ def test_fuzzing_cpu(): | |||
LOGGER.info(TAG, 'KMNC of this test is : %s', | |||
model_coverage_test.get_kmnc()) | |||
model_fuzz_test = Fuzzer(model, train_images, 10, 1000) | |||
model_fuzz_test = Fuzzer(model, train_images, neuron_num, segmented_num) | |||
_, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds) | |||
print(metrics) |