|
|
@@ -27,6 +27,48 @@ from mindarmour.fuzzing.image_transform import Contrast, Brightness, Blur, \ |
|
|
|
Noise, Translate, Scale, Shear, Rotate |
|
|
|
from mindarmour.attacks import FastGradientSignMethod, \ |
|
|
|
MomentumDiverseInputIterativeMethod, ProjectedGradientDescent |
|
|
|
from mindarmour.utils.logger import LogUtil |
|
|
|
|
|
|
|
LOGGER = LogUtil.get_instance() |
|
|
|
TAG = 'Fuzzer' |
|
|
|
|
|
|
|
|
|
|
|
def _select_next(initial_seeds): |
|
|
|
""" Randomly select a seed from `initial_seeds`.""" |
|
|
|
seed_num = choice(range(len(initial_seeds))) |
|
|
|
seed = initial_seeds[seed_num] |
|
|
|
del initial_seeds[seed_num] |
|
|
|
return seed, initial_seeds |
|
|
|
|
|
|
|
|
|
|
|
def _coverage_gains(coverages): |
|
|
|
""" Calculate the coverage gains of mutated samples. """ |
|
|
|
gains = [0] + coverages[:-1] |
|
|
|
gains = np.array(coverages) - np.array(gains) |
|
|
|
return gains |
|
|
|
|
|
|
|
|
|
|
|
def _is_trans_valid(seed, mutate_sample): |
|
|
|
""" Check a mutated sample is valid. If the number of changed pixels in |
|
|
|
a seed is less than pixels_change_rate*size(seed), this mutate is valid. |
|
|
|
Else check the infinite norm of seed changes, if the value of the |
|
|
|
infinite norm less than pixel_value_change_rate*255, this mutate is |
|
|
|
valid too. Otherwise the opposite. |
|
|
|
""" |
|
|
|
is_valid = False |
|
|
|
pixels_change_rate = 0.02 |
|
|
|
pixel_value_change_rate = 0.2 |
|
|
|
diff = np.array(seed - mutate_sample).flatten() |
|
|
|
size = np.shape(diff)[0] |
|
|
|
l0_norm = np.linalg.norm(diff, ord=0) |
|
|
|
linf = np.linalg.norm(diff, ord=np.inf) |
|
|
|
if l0_norm > pixels_change_rate*size: |
|
|
|
if linf < 256: |
|
|
|
is_valid = True |
|
|
|
else: |
|
|
|
if linf < pixel_value_change_rate*255: |
|
|
|
is_valid = True |
|
|
|
return is_valid |
|
|
|
|
|
|
|
|
|
|
|
class Fuzzer: |
|
|
@@ -40,71 +82,203 @@ class Fuzzer: |
|
|
|
target_model (Model): Target fuzz model. |
|
|
|
train_dataset (numpy.ndarray): Training dataset used for determining |
|
|
|
the neurons' output boundaries. |
|
|
|
segmented_num (int): The number of segmented sections of neurons' |
|
|
|
output intervals. |
|
|
|
neuron_num (int): The number of testing neurons. |
|
|
|
segmented_num (int): The number of segmented sections of neurons' |
|
|
|
output intervals. Default: 1000. |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> net = Net() |
|
|
|
>>> mutate_config = [{'method': 'Blur', 'params': {'auto_param': True}}, |
|
|
|
>>> {'method': 'Contrast','params': {'factor': 2}}, |
|
|
|
>>> {'method': 'Translate', 'params': {'x_bias': 0.1, 'y_bias': 0.2}}, |
|
|
|
>>> {'method': 'FGSM', 'params': {'eps': 0.1, 'alpha': 0.1}}] |
|
|
|
>>> train_images = np.random.rand(32, 1, 32, 32).astype(np.float32) |
|
|
|
>>> model_fuzz_test = Fuzzer(model, train_images, 1000, 10) |
|
|
|
>>> samples, labels, preds, strategies, report = model_fuzz_test.fuzzing(mutate_config, initial_seeds) |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, target_model, train_dataset, segmented_num, neuron_num): |
|
|
|
self.target_model = check_model('model', target_model, Model) |
|
|
|
self.train_dataset = check_numpy_param('train_dataset', train_dataset) |
|
|
|
self.coverage_metrics = ModelCoverageMetrics(target_model, |
|
|
|
segmented_num, |
|
|
|
neuron_num, train_dataset) |
|
|
|
def __init__(self, target_model, train_dataset, neuron_num, segmented_num=1000): |
|
|
|
self._target_model = check_model('model', target_model, Model) |
|
|
|
train_dataset = check_numpy_param('train_dataset', train_dataset) |
|
|
|
self._coverage_metrics = ModelCoverageMetrics(target_model, |
|
|
|
segmented_num, |
|
|
|
neuron_num, train_dataset) |
|
|
|
# Allowed mutate strategies so far. |
|
|
|
self.strategies = {'Contrast': Contrast, 'Brightness': Brightness, |
|
|
|
'Blur': Blur, 'Noise': Noise, 'Translate': Translate, |
|
|
|
'Scale': Scale, 'Shear': Shear, 'Rotate': Rotate, |
|
|
|
'FGSM': FastGradientSignMethod, |
|
|
|
'PGD': ProjectedGradientDescent, |
|
|
|
'MDIIM': MomentumDiverseInputIterativeMethod} |
|
|
|
self.affine_trans_list = ['Translate', 'Scale', 'Shear', 'Rotate'] |
|
|
|
self.pixel_value_trans_list = ['Contrast', 'Brightness', 'Blur', |
|
|
|
'Noise'] |
|
|
|
self.attacks_list = ['FGSM', 'PGD', 'MDIIM'] |
|
|
|
self.attack_param_checklists = { |
|
|
|
self._strategies = {'Contrast': Contrast, 'Brightness': Brightness, |
|
|
|
'Blur': Blur, 'Noise': Noise, 'Translate': Translate, |
|
|
|
'Scale': Scale, 'Shear': Shear, 'Rotate': Rotate, |
|
|
|
'FGSM': FastGradientSignMethod, |
|
|
|
'PGD': ProjectedGradientDescent, |
|
|
|
'MDIIM': MomentumDiverseInputIterativeMethod} |
|
|
|
self._affine_trans_list = ['Translate', 'Scale', 'Shear', 'Rotate'] |
|
|
|
self._pixel_value_trans_list = ['Contrast', 'Brightness', 'Blur', |
|
|
|
'Noise'] |
|
|
|
self._attacks_list = ['FGSM', 'PGD', 'MDIIM'] |
|
|
|
self._attack_param_checklists = { |
|
|
|
'FGSM': {'params': {'eps': {'dtype': [float, int], 'range': [0, 1]}, |
|
|
|
'alpha': {'dtype': [float, int], |
|
|
|
'range': [0, 1]}, |
|
|
|
'bounds': {'dtype': [list, tuple], |
|
|
|
'range': None}, |
|
|
|
}}, |
|
|
|
'range': None}}}, |
|
|
|
'PGD': {'params': {'eps': {'dtype': [float, int], 'range': [0, 1]}, |
|
|
|
'eps_iter': {'dtype': [float, int], |
|
|
|
'range': [0, 1e5]}, |
|
|
|
'nb_iter': {'dtype': [float, int], |
|
|
|
'range': [0, 1e5]}, |
|
|
|
'bounds': {'dtype': [list, tuple], |
|
|
|
'range': None}, |
|
|
|
}}, |
|
|
|
'range': None}}}, |
|
|
|
'MDIIM': { |
|
|
|
'params': {'eps': {'dtype': [float, int], 'range': [0, 1]}, |
|
|
|
'norm_level': {'dtype': [str], 'range': None}, |
|
|
|
'prob': {'dtype': [float, int], 'range': [0, 1]}, |
|
|
|
'bounds': {'dtype': [list, tuple], 'range': None}, |
|
|
|
}}} |
|
|
|
'bounds': {'dtype': [list, tuple], 'range': None}}}} |
|
|
|
|
|
|
|
def fuzzing(self, mutate_config, initial_seeds, coverage_metric='KMNC', |
|
|
|
eval_metrics='auto', max_iters=10000, mutate_num_per_seed=20): |
|
|
|
""" |
|
|
|
Fuzzing tests for deep neural networks. |
|
|
|
|
|
|
|
Args: |
|
|
|
mutate_config (list): Mutate configs. The format is |
|
|
|
[{'method': 'Blur', 'params': {'auto_param': True}}, {'method': 'Contrast', 'params': {'factor': 2}}]. |
|
|
|
The support methods list is in `self._strategies`, and the params of each |
|
|
|
method must within the range of changeable parameters. |
|
|
|
initial_seeds (numpy.ndarray): Initial seeds used to generate |
|
|
|
mutated samples. |
|
|
|
coverage_metric (str): Model coverage metric of neural networks. |
|
|
|
Default: 'KMNC'. |
|
|
|
eval_metrics (Union[list, tuple, str]): Evaluation metrics. If the type is 'auto', |
|
|
|
it will calculate all the metrics, else if the type is list or tuple, it will |
|
|
|
calculate the metrics specified by user. Default: 'auto'. |
|
|
|
max_iters (int): Max number of select a seed to mutate. |
|
|
|
Default: 10000. |
|
|
|
mutate_num_per_seed (int): The number of mutate times for a seed. |
|
|
|
Default: 20. |
|
|
|
|
|
|
|
Returns: |
|
|
|
- list, mutated samples in fuzzing. |
|
|
|
|
|
|
|
- list, ground truth labels of mutated samples. |
|
|
|
|
|
|
|
- list, preds of mutated samples. |
|
|
|
|
|
|
|
- list, strategies of mutated samples. |
|
|
|
|
|
|
|
- dict, metrics report of fuzzer. |
|
|
|
|
|
|
|
Raises: |
|
|
|
TypeError: If the type of `eval_metrics` is not str, list or tuple. |
|
|
|
TypeError: If the type of metric in list `eval_metrics` is not str. |
|
|
|
ValueError: If `eval_metrics` is not equal to 'auto' when it's type is str. |
|
|
|
ValueError: If metric in list `eval_metrics` is not in ['accuracy', 'attack_success_rate', |
|
|
|
'kmnc', 'nbc', 'snac']. |
|
|
|
""" |
|
|
|
eval_metrics_ = None |
|
|
|
if isinstance(eval_metrics, (list, tuple)): |
|
|
|
eval_metrics_ = [] |
|
|
|
avaliable_metrics = ['accuracy', 'attack_success_rate', 'kmnc', 'nbc', 'snac'] |
|
|
|
for elem in eval_metrics: |
|
|
|
if not isinstance(elem, str): |
|
|
|
msg = 'the type of metric in list `eval_metrics` must be str, but got {}.' \ |
|
|
|
.format(type(elem)) |
|
|
|
LOGGER.error(TAG, msg) |
|
|
|
raise TypeError(msg) |
|
|
|
if elem not in avaliable_metrics: |
|
|
|
msg = 'metric in list `eval_metrics` must be in {}, but got {}.' \ |
|
|
|
.format(avaliable_metrics, elem) |
|
|
|
LOGGER.error(TAG, msg) |
|
|
|
raise ValueError(msg) |
|
|
|
eval_metrics_.append(elem.lower()) |
|
|
|
elif isinstance(eval_metrics, str): |
|
|
|
if eval_metrics != 'auto': |
|
|
|
msg = "the value of `eval_metrics` must be 'auto' if it's type is str, " \ |
|
|
|
"but got {}.".format(eval_metrics) |
|
|
|
LOGGER.error(TAG, msg) |
|
|
|
raise ValueError(msg) |
|
|
|
eval_metrics_ = 'auto' |
|
|
|
else: |
|
|
|
msg = "the type of `eval_metrics` must be str, list or tuple, but got {}." \ |
|
|
|
.format(type(eval_metrics)) |
|
|
|
LOGGER.error(TAG, msg) |
|
|
|
raise TypeError(msg) |
|
|
|
|
|
|
|
# Check whether the mutate_config meet the specification. |
|
|
|
mutates = self._init_mutates(mutate_config) |
|
|
|
seed, initial_seeds = _select_next(initial_seeds) |
|
|
|
fuzz_samples = [] |
|
|
|
gt_labels = [] |
|
|
|
fuzz_preds = [] |
|
|
|
fuzz_strategies = [] |
|
|
|
iter_num = 0 |
|
|
|
while initial_seeds and iter_num < max_iters: |
|
|
|
# Mutate a seed. |
|
|
|
mutate_samples, mutate_strategies = self._metamorphic_mutate(seed, |
|
|
|
mutates, |
|
|
|
mutate_config, |
|
|
|
mutate_num_per_seed) |
|
|
|
# Calculate the coverages and predictions of generated samples. |
|
|
|
coverages, predicts = self._run(mutate_samples, coverage_metric) |
|
|
|
coverage_gains = _coverage_gains(coverages) |
|
|
|
for mutate, cov, pred, strategy in zip(mutate_samples, |
|
|
|
coverage_gains, |
|
|
|
predicts, mutate_strategies): |
|
|
|
fuzz_samples.append(mutate[0]) |
|
|
|
gt_labels.append(mutate[1]) |
|
|
|
fuzz_preds.append(pred) |
|
|
|
fuzz_strategies.append(strategy) |
|
|
|
# if the mutate samples has coverage gains add this samples in |
|
|
|
# the initial seeds to guide new mutates. |
|
|
|
if cov > 0: |
|
|
|
initial_seeds.append(mutate) |
|
|
|
seed, initial_seeds = _select_next(initial_seeds) |
|
|
|
iter_num += 1 |
|
|
|
metrics_report = None |
|
|
|
if eval_metrics_ is not None: |
|
|
|
metrics_report = self._evaluate(fuzz_samples, gt_labels, fuzz_preds, |
|
|
|
fuzz_strategies, eval_metrics_) |
|
|
|
return fuzz_samples, gt_labels, fuzz_preds, fuzz_strategies, metrics_report |
|
|
|
|
|
|
|
def _run(self, mutate_samples, coverage_metric="KNMC"): |
|
|
|
""" Calculate the coverages and predictions of generated samples.""" |
|
|
|
samples = [s[0] for s in mutate_samples] |
|
|
|
samples = np.array(samples) |
|
|
|
coverages = [] |
|
|
|
predictions = self._target_model.predict(Tensor(samples.astype(np.float32))) |
|
|
|
predictions = predictions.asnumpy() |
|
|
|
for index in range(len(samples)): |
|
|
|
mutate = samples[:index + 1] |
|
|
|
self._coverage_metrics.calculate_coverage(mutate.astype(np.float32)) |
|
|
|
if coverage_metric == "KMNC": |
|
|
|
coverages.append(self._coverage_metrics.get_kmnc()) |
|
|
|
if coverage_metric == 'NBC': |
|
|
|
coverages.append(self._coverage_metrics.get_nbc()) |
|
|
|
if coverage_metric == 'SNAC': |
|
|
|
coverages.append(self._coverage_metrics.get_snac()) |
|
|
|
return coverages, predictions |
|
|
|
|
|
|
|
def _check_attack_params(self, method, params): |
|
|
|
"""Check input parameters of attack methods.""" |
|
|
|
allow_params = self.attack_param_checklists[method]['params'].keys() |
|
|
|
for p in params: |
|
|
|
if p not in allow_params: |
|
|
|
allow_params = self._attack_param_checklists[method]['params'].keys() |
|
|
|
for param_name in params: |
|
|
|
if param_name not in allow_params: |
|
|
|
msg = "parameters of {} must in {}".format(method, allow_params) |
|
|
|
raise ValueError(msg) |
|
|
|
if p == 'bounds': |
|
|
|
bounds = check_param_multi_types('bounds', params[p], |
|
|
|
|
|
|
|
param_value = params[param_name] |
|
|
|
if param_name == 'bounds': |
|
|
|
bounds = check_param_multi_types('bounds', param_value, |
|
|
|
[list, tuple]) |
|
|
|
for b in bounds: |
|
|
|
_ = check_param_multi_types('bound', b, [int, float]) |
|
|
|
elif p == 'norm_level': |
|
|
|
_ = check_norm_level(params[p]) |
|
|
|
for bound_value in bounds: |
|
|
|
_ = check_param_multi_types('bound', bound_value, [int, float]) |
|
|
|
elif param_name == 'norm_level': |
|
|
|
_ = check_norm_level(param_value) |
|
|
|
else: |
|
|
|
allow_type = self.attack_param_checklists[method]['params'][p][ |
|
|
|
allow_type = self._attack_param_checklists[method]['params'][param_name][ |
|
|
|
'dtype'] |
|
|
|
allow_range = self.attack_param_checklists[method]['params'][p][ |
|
|
|
allow_range = self._attack_param_checklists[method]['params'][param_name][ |
|
|
|
'range'] |
|
|
|
_ = check_param_multi_types(str(p), params[p], allow_type) |
|
|
|
_ = check_param_in_range(str(p), params[p], allow_range[0], |
|
|
|
_ = check_param_multi_types(str(param_name), param_value, allow_type) |
|
|
|
_ = check_param_in_range(str(param_name), param_value, allow_range[0], |
|
|
|
allow_range[1]) |
|
|
|
|
|
|
|
def _metamorphic_mutate(self, seed, mutates, mutate_config, |
|
|
@@ -117,23 +291,23 @@ class Fuzzer: |
|
|
|
strage = choice(mutate_config) |
|
|
|
# Choose a pixel value based transform method |
|
|
|
if only_pixel_trans: |
|
|
|
while strage['method'] not in self.pixel_value_trans_list: |
|
|
|
while strage['method'] not in self._pixel_value_trans_list: |
|
|
|
strage = choice(mutate_config) |
|
|
|
transform = mutates[strage['method']] |
|
|
|
params = strage['params'] |
|
|
|
method = strage['method'] |
|
|
|
if method in list(self.pixel_value_trans_list + self.affine_trans_list): |
|
|
|
if method in list(self._pixel_value_trans_list + self._affine_trans_list): |
|
|
|
transform.set_params(**params) |
|
|
|
mutate_sample = transform.transform(seed[0]) |
|
|
|
else: |
|
|
|
for p in params: |
|
|
|
transform.__setattr__('_'+str(p), params[p]) |
|
|
|
for param_name in params: |
|
|
|
transform.__setattr__('_' + str(param_name), params[param_name]) |
|
|
|
mutate_sample = transform.generate([seed[0].astype(np.float32)], |
|
|
|
[seed[1]])[0] |
|
|
|
if method not in self.pixel_value_trans_list: |
|
|
|
if method not in self._pixel_value_trans_list: |
|
|
|
only_pixel_trans = 1 |
|
|
|
mutate_sample = [mutate_sample, seed[1], only_pixel_trans] |
|
|
|
if self._is_trans_valid(seed[0], mutate_sample[0]): |
|
|
|
if _is_trans_valid(seed[0], mutate_sample[0]): |
|
|
|
mutate_samples.append(mutate_sample) |
|
|
|
mutate_strategies.append(method) |
|
|
|
if not mutate_samples: |
|
|
@@ -145,29 +319,29 @@ class Fuzzer: |
|
|
|
""" Check whether the mutate_config meet the specification.""" |
|
|
|
has_pixel_trans = False |
|
|
|
for mutate in mutate_config: |
|
|
|
if mutate['method'] in self.pixel_value_trans_list: |
|
|
|
if mutate['method'] in self._pixel_value_trans_list: |
|
|
|
has_pixel_trans = True |
|
|
|
break |
|
|
|
if not has_pixel_trans: |
|
|
|
msg = "mutate methods in mutate_config at lease have one in {}".format( |
|
|
|
self.pixel_value_trans_list) |
|
|
|
self._pixel_value_trans_list) |
|
|
|
raise ValueError(msg) |
|
|
|
mutates = {} |
|
|
|
for mutate in mutate_config: |
|
|
|
method = mutate['method'] |
|
|
|
params = mutate['params'] |
|
|
|
if method not in self.attacks_list: |
|
|
|
mutates[method] = self.strategies[method]() |
|
|
|
if method not in self._attacks_list: |
|
|
|
mutates[method] = self._strategies[method]() |
|
|
|
else: |
|
|
|
self._check_attack_params(method, params) |
|
|
|
network = self.target_model._network |
|
|
|
loss_fn = self.target_model._loss_fn |
|
|
|
mutates[method] = self.strategies[method](network, |
|
|
|
loss_fn=loss_fn) |
|
|
|
network = self._target_model._network |
|
|
|
loss_fn = self._target_model._loss_fn |
|
|
|
mutates[method] = self._strategies[method](network, |
|
|
|
loss_fn=loss_fn) |
|
|
|
return mutates |
|
|
|
|
|
|
|
def evaluate(self, fuzz_samples, gt_labels, fuzz_preds, |
|
|
|
fuzz_strategies): |
|
|
|
def _evaluate(self, fuzz_samples, gt_labels, fuzz_preds, |
|
|
|
fuzz_strategies, metrics): |
|
|
|
""" |
|
|
|
Evaluate generated fuzzing samples in three dimention: accuracy, |
|
|
|
attack success rate and neural coverage. |
|
|
@@ -177,147 +351,40 @@ class Fuzzer: |
|
|
|
gt_labels (numpy.ndarray): Ground Truth of seeds. |
|
|
|
fuzz_preds (numpy.ndarray): Predictions of generated fuzz samples. |
|
|
|
fuzz_strategies (numpy.ndarray): Mutate strategies of fuzz samples. |
|
|
|
metrics (Union[list, tuple, str]): evaluation metrics. |
|
|
|
|
|
|
|
Returns: |
|
|
|
dict, evaluate metrics include accuarcy, attack success rate |
|
|
|
and neural coverage. |
|
|
|
""" |
|
|
|
|
|
|
|
gt_labels = np.asarray(gt_labels) |
|
|
|
fuzz_preds = np.asarray(fuzz_preds) |
|
|
|
temp = np.argmax(gt_labels, axis=1) == np.argmax(fuzz_preds, axis=1) |
|
|
|
acc = np.sum(temp) / np.size(temp) |
|
|
|
|
|
|
|
cond = [elem in self.attacks_list for elem in fuzz_strategies] |
|
|
|
temp = temp[cond] |
|
|
|
attack_success_rate = 1 - np.sum(temp) / np.size(temp) |
|
|
|
|
|
|
|
self.coverage_metrics.calculate_coverage( |
|
|
|
np.array(fuzz_samples).astype(np.float32)) |
|
|
|
kmnc = self.coverage_metrics.get_kmnc() |
|
|
|
nbc = self.coverage_metrics.get_nbc() |
|
|
|
snac = self.coverage_metrics.get_snac() |
|
|
|
|
|
|
|
metrics = {} |
|
|
|
metrics['Accuracy'] = acc |
|
|
|
metrics['Attack_succrss_rate'] = attack_success_rate |
|
|
|
metrics['Neural_coverage_KMNC'] = kmnc |
|
|
|
metrics['Neural_coverage_NBC'] = nbc |
|
|
|
metrics['Neural_coverage_SNAC'] = snac |
|
|
|
return metrics |
|
|
|
metrics_report = {} |
|
|
|
if metrics == 'auto' or 'accuracy' in metrics: |
|
|
|
gt_labels = np.asarray(gt_labels) |
|
|
|
fuzz_preds = np.asarray(fuzz_preds) |
|
|
|
acc = np.sum(temp) / np.size(temp) |
|
|
|
metrics_report['Accuracy'] = acc |
|
|
|
|
|
|
|
def fuzzing(self, mutate_config, initial_seeds, coverage_metric='KMNC', |
|
|
|
eval_metric=True, max_iters=10000, mutate_num_per_seed=20): |
|
|
|
""" |
|
|
|
Fuzzing tests for deep neural networks. |
|
|
|
if metrics == 'auto' or 'attack_success_rate' in metrics: |
|
|
|
cond = [elem in self._attacks_list for elem in fuzz_strategies] |
|
|
|
temp = temp[cond] |
|
|
|
attack_success_rate = 1 - np.sum(temp) / np.size(temp) |
|
|
|
metrics_report['Attack_success_rate'] = attack_success_rate |
|
|
|
|
|
|
|
Args: |
|
|
|
mutate_config (list): Mutate configs. The format is |
|
|
|
[{'method': 'Blur', |
|
|
|
'params': {'auto_param': True}}, |
|
|
|
{'method': 'Contrast', |
|
|
|
'params': {'factor': 2}}, |
|
|
|
...]. The support methods list is in `self.strategies`, |
|
|
|
The params of each method must within the range of changeable |
|
|
|
parameters. |
|
|
|
initial_seeds (numpy.ndarray): Initial seeds used to generate |
|
|
|
mutated samples. |
|
|
|
coverage_metric (str): Model coverage metric of neural networks. |
|
|
|
Default: 'KMNC'. |
|
|
|
eval_metric (bool): Whether to evaluate the generated fuzz samples. |
|
|
|
Default: True. |
|
|
|
max_iters (int): Max number of select a seed to mutate. |
|
|
|
Default: 10000. |
|
|
|
mutate_num_per_seed (int): The number of mutate times for a seed. |
|
|
|
Default: 20. |
|
|
|
if metrics == 'auto' or 'kmnc' in metrics or 'nbc' in metrics or 'snac' in metrics: |
|
|
|
self._coverage_metrics.calculate_coverage( |
|
|
|
np.array(fuzz_samples).astype(np.float32)) |
|
|
|
|
|
|
|
Returns: |
|
|
|
list, mutated samples. |
|
|
|
""" |
|
|
|
# Check whether the mutate_config meet the specification. |
|
|
|
mutates = self._init_mutates(mutate_config) |
|
|
|
seed, initial_seeds = self._select_next(initial_seeds) |
|
|
|
fuzz_samples = [] |
|
|
|
gt_labels = [] |
|
|
|
fuzz_preds = [] |
|
|
|
fuzz_strategies = [] |
|
|
|
iter_num = 0 |
|
|
|
while initial_seeds and iter_num < max_iters: |
|
|
|
# Mutate a seed. |
|
|
|
mutate_samples, mutate_strategies = self._metamorphic_mutate(seed, |
|
|
|
mutates, |
|
|
|
mutate_config, |
|
|
|
mutate_num_per_seed) |
|
|
|
# Calculate the coverages and predictions of generated samples. |
|
|
|
coverages, predicts = self._run(mutate_samples, coverage_metric) |
|
|
|
coverage_gains = self._coverage_gains(coverages) |
|
|
|
for mutate, cov, pred, strategy in zip(mutate_samples, |
|
|
|
coverage_gains, |
|
|
|
predicts, mutate_strategies): |
|
|
|
fuzz_samples.append(mutate[0]) |
|
|
|
gt_labels.append(mutate[1]) |
|
|
|
fuzz_preds.append(pred) |
|
|
|
fuzz_strategies.append(strategy) |
|
|
|
# if the mutate samples has coverage gains add this samples in |
|
|
|
# the initial seeds to guide new mutates. |
|
|
|
if cov > 0: |
|
|
|
initial_seeds.append(mutate) |
|
|
|
seed, initial_seeds = self._select_next(initial_seeds) |
|
|
|
iter_num += 1 |
|
|
|
metrics = None |
|
|
|
if eval_metric: |
|
|
|
metrics = self.evaluate(fuzz_samples, gt_labels, fuzz_preds, |
|
|
|
fuzz_strategies) |
|
|
|
return fuzz_samples, gt_labels, fuzz_preds, fuzz_strategies, metrics |
|
|
|
|
|
|
|
def _coverage_gains(self, coverages): |
|
|
|
""" Calculate the coverage gains of mutated samples. """ |
|
|
|
gains = [0] + coverages[:-1] |
|
|
|
gains = np.array(coverages) - np.array(gains) |
|
|
|
return gains |
|
|
|
if metrics == 'auto' or 'kmnc' in metrics: |
|
|
|
kmnc = self._coverage_metrics.get_kmnc() |
|
|
|
metrics_report['Neural_coverage_KMNC'] = kmnc |
|
|
|
|
|
|
|
def _run(self, mutate_samples, coverage_metric="KNMC"): |
|
|
|
""" Calculate the coverages and predictions of generated samples.""" |
|
|
|
samples = [s[0] for s in mutate_samples] |
|
|
|
samples = np.array(samples) |
|
|
|
coverages = [] |
|
|
|
predictions = self.target_model.predict(Tensor(samples.astype(np.float32))) |
|
|
|
predictions = predictions.asnumpy() |
|
|
|
for index in range(len(samples)): |
|
|
|
mutate = samples[:index + 1] |
|
|
|
self.coverage_metrics.calculate_coverage(mutate.astype(np.float32)) |
|
|
|
if coverage_metric == "KMNC": |
|
|
|
coverages.append(self.coverage_metrics.get_kmnc()) |
|
|
|
if coverage_metric == 'NBC': |
|
|
|
coverages.append(self.coverage_metrics.get_nbc()) |
|
|
|
if coverage_metric == 'SNAC': |
|
|
|
coverages.append(self.coverage_metrics.get_snac()) |
|
|
|
return coverages, predictions |
|
|
|
if metrics == 'auto' or 'nbc' in metrics: |
|
|
|
nbc = self._coverage_metrics.get_nbc() |
|
|
|
metrics_report['Neural_coverage_NBC'] = nbc |
|
|
|
|
|
|
|
def _select_next(self, initial_seeds): |
|
|
|
"""Randomly select a seed from `initial_seeds`.""" |
|
|
|
seed_num = choice(range(len(initial_seeds))) |
|
|
|
seed = initial_seeds[seed_num] |
|
|
|
del initial_seeds[seed_num] |
|
|
|
return seed, initial_seeds |
|
|
|
|
|
|
|
def _is_trans_valid(self, seed, mutate_sample): |
|
|
|
""" Check a mutated sample is valid. If the number of changed pixels in |
|
|
|
a seed is less than pixels_change_rate*size(seed), this mutate is valid. |
|
|
|
Else check the infinite norm of seed changes, if the value of the |
|
|
|
infinite norm less than pixel_value_change_rate*255, this mutate is |
|
|
|
valid too. Otherwise the opposite.""" |
|
|
|
is_valid = False |
|
|
|
pixels_change_rate = 0.02 |
|
|
|
pixel_value_change_rate = 0.2 |
|
|
|
diff = np.array(seed - mutate_sample).flatten() |
|
|
|
size = np.shape(diff)[0] |
|
|
|
l0 = np.linalg.norm(diff, ord=0) |
|
|
|
linf = np.linalg.norm(diff, ord=np.inf) |
|
|
|
if l0 > pixels_change_rate*size: |
|
|
|
if linf < 256: |
|
|
|
is_valid = True |
|
|
|
else: |
|
|
|
if linf < pixel_value_change_rate*255: |
|
|
|
is_valid = True |
|
|
|
return is_valid |
|
|
|
if metrics == 'auto' or 'snac' in metrics: |
|
|
|
snac = self._coverage_metrics.get_snac() |
|
|
|
metrics_report['Neural_coverage_SNAC'] = snac |
|
|
|
|
|
|
|
return metrics_report |