@@ -19,7 +19,7 @@ from mindspore import context | |||
from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
from lenet5_net import LeNet5 | |||
from mindarmour.fuzzing.fuzzing import Fuzzing | |||
from mindarmour.fuzzing.fuzzing import Fuzzer | |||
from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics | |||
from mindarmour.utils.logger import LogUtil | |||
@@ -38,11 +38,20 @@ def test_lenet_mnist_fuzzing(): | |||
load_dict = load_checkpoint(ckpt_name) | |||
load_param_into_net(net, load_dict) | |||
model = Model(net) | |||
mutate_config = [{'method': 'Blur', | |||
'params': {'auto_param': True}}, | |||
{'method': 'Contrast', | |||
'params': {'factor': 2}}, | |||
{'method': 'Translate', | |||
'params': {'x_bias': 0.1, 'y_bias': 0.2}}, | |||
{'method': 'FGSM', | |||
'params': {'eps': 0.1, 'alpha': 0.1}} | |||
] | |||
# get training data | |||
data_list = "./MNIST_unzip/train" | |||
batch_size = 32 | |||
ds = generate_mnist_dataset(data_list, batch_size, sparse=True) | |||
ds = generate_mnist_dataset(data_list, batch_size, sparse=False) | |||
train_images = [] | |||
for data in ds.create_tuple_iterator(): | |||
images = data[0].astype(np.float32) | |||
@@ -56,7 +65,7 @@ def test_lenet_mnist_fuzzing(): | |||
# get test data | |||
data_list = "./MNIST_unzip/test" | |||
batch_size = 32 | |||
ds = generate_mnist_dataset(data_list, batch_size, sparse=True) | |||
ds = generate_mnist_dataset(data_list, batch_size, sparse=False) | |||
test_images = [] | |||
test_labels = [] | |||
for data in ds.create_tuple_iterator(): | |||
@@ -70,19 +79,20 @@ def test_lenet_mnist_fuzzing(): | |||
# make initial seeds | |||
for img, label in zip(test_images, test_labels): | |||
initial_seeds.append([img, label]) | |||
initial_seeds.append([img, label, 0]) | |||
initial_seeds = initial_seeds[:100] | |||
model_coverage_test.test_adequacy_coverage_calculate(np.array(test_images[:100]).astype(np.float32)) | |||
LOGGER.info(TAG, 'KMNC of this test is : %s', model_coverage_test.get_kmnc()) | |||
model_coverage_test.calculate_coverage( | |||
np.array(test_images[:100]).astype(np.float32)) | |||
LOGGER.info(TAG, 'KMNC of this test is : %s', | |||
model_coverage_test.get_kmnc()) | |||
model_fuzz_test = Fuzzing(initial_seeds, model, train_images, 20) | |||
failed_tests = model_fuzz_test.fuzzing() | |||
if failed_tests: | |||
model_coverage_test.test_adequacy_coverage_calculate(np.array(failed_tests).astype(np.float32)) | |||
LOGGER.info(TAG, 'KMNC of this test is : %s', model_coverage_test.get_kmnc()) | |||
else: | |||
LOGGER.info(TAG, 'Fuzzing test identifies none failed test') | |||
model_fuzz_test = Fuzzer(model, train_images, 1000, 10) | |||
_, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds, | |||
eval_metric=True) | |||
if metrics: | |||
for key in metrics: | |||
LOGGER.info(TAG, key + ': %s', metrics[key]) | |||
if __name__ == '__main__': | |||
@@ -227,8 +227,8 @@ class BasicIterativeMethod(IterativeGradientMethod): | |||
clip_min, clip_max = self._bounds | |||
clip_diff = clip_max - clip_min | |||
for _ in range(self._nb_iter): | |||
if 'self.prob' in globals(): | |||
d_inputs = _transform_inputs(inputs, self.prob) | |||
if 'self._prob' in globals(): | |||
d_inputs = _transform_inputs(inputs, self._prob) | |||
else: | |||
d_inputs = inputs | |||
adv_x = self._attack.generate(d_inputs, labels) | |||
@@ -238,8 +238,8 @@ class BasicIterativeMethod(IterativeGradientMethod): | |||
inputs = adv_x | |||
else: | |||
for _ in range(self._nb_iter): | |||
if 'self.prob' in globals(): | |||
d_inputs = _transform_inputs(inputs, self.prob) | |||
if 'self._prob' in globals(): | |||
d_inputs = _transform_inputs(inputs, self._prob) | |||
else: | |||
d_inputs = inputs | |||
adv_x = self._attack.generate(d_inputs, labels) | |||
@@ -311,8 +311,8 @@ class MomentumIterativeMethod(IterativeGradientMethod): | |||
clip_min, clip_max = self._bounds | |||
clip_diff = clip_max - clip_min | |||
for _ in range(self._nb_iter): | |||
if 'self.prob' in globals(): | |||
d_inputs = _transform_inputs(inputs, self.prob) | |||
if 'self._prob' in globals(): | |||
d_inputs = _transform_inputs(inputs, self._prob) | |||
else: | |||
d_inputs = inputs | |||
gradient = self._gradient(d_inputs, labels) | |||
@@ -325,8 +325,8 @@ class MomentumIterativeMethod(IterativeGradientMethod): | |||
inputs = adv_x | |||
else: | |||
for _ in range(self._nb_iter): | |||
if 'self.prob' in globals(): | |||
d_inputs = _transform_inputs(inputs, self.prob) | |||
if 'self._prob' in globals(): | |||
d_inputs = _transform_inputs(inputs, self._prob) | |||
else: | |||
d_inputs = inputs | |||
gradient = self._gradient(d_inputs, labels) | |||
@@ -476,7 +476,7 @@ class DiverseInputIterativeMethod(BasicIterativeMethod): | |||
is_targeted=is_targeted, | |||
nb_iter=nb_iter, | |||
loss_fn=loss_fn) | |||
self.prob = check_param_type('prob', prob, float) | |||
self._prob = check_param_type('prob', prob, float) | |||
class MomentumDiverseInputIterativeMethod(MomentumIterativeMethod): | |||
@@ -511,7 +511,7 @@ class MomentumDiverseInputIterativeMethod(MomentumIterativeMethod): | |||
is_targeted=is_targeted, | |||
norm_level=norm_level, | |||
loss_fn=loss_fn) | |||
self.prob = check_param_type('prob', prob, float) | |||
self._prob = check_param_type('prob', prob, float) | |||
def _transform_inputs(inputs, prob, low=29, high=33, full_aug=False): | |||
@@ -22,9 +22,11 @@ from mindspore import Tensor | |||
from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics | |||
from mindarmour.utils._check_param import check_model, check_numpy_param, \ | |||
check_int_positive | |||
from mindarmour.fuzzing.image_transform import Contrast, Brightness, Blur, Noise, \ | |||
Translate, Scale, Shear, Rotate | |||
check_param_multi_types, check_norm_level, check_param_in_range | |||
from mindarmour.fuzzing.image_transform import Contrast, Brightness, Blur, \ | |||
Noise, Translate, Scale, Shear, Rotate | |||
from mindarmour.attacks import FastGradientSignMethod, \ | |||
MomentumDiverseInputIterativeMethod, ProjectedGradientDescent | |||
class Fuzzer: | |||
@@ -35,129 +37,280 @@ class Fuzzer: | |||
Neural Networks <https://dl.acm.org/doi/10.1145/3293882.3330579>`_ | |||
Args: | |||
initial_seeds (list): Initial fuzzing seed, format: [[image, label], | |||
[image, label], ...]. | |||
target_model (Model): Target fuzz model. | |||
train_dataset (numpy.ndarray): Training dataset used for determining | |||
the neurons' output boundaries. | |||
const_k (int): The number of mutate tests for a seed. | |||
mode (str): Image mode used in image transform, 'L' means grey graph. | |||
Default: 'L'. | |||
max_seed_num (int): The initial seeds max value. Default: 1000 | |||
segmented_num (int): The number of segmented sections of neurons' | |||
output intervals. | |||
neuron_num (int): The number of testing neurons. | |||
""" | |||
def __init__(self, initial_seeds, target_model, train_dataset, const_K, | |||
mode='L', max_seed_num=1000): | |||
self.initial_seeds = initial_seeds | |||
def __init__(self, target_model, train_dataset, segmented_num, neuron_num): | |||
self.target_model = check_model('model', target_model, Model) | |||
self.train_dataset = check_numpy_param('train_dataset', train_dataset) | |||
self.const_k = check_int_positive('const_k', const_K) | |||
self.mode = mode | |||
self.max_seed_num = check_int_positive('max_seed_num', max_seed_num) | |||
self.coverage_metrics = ModelCoverageMetrics(target_model, 1000, 10, | |||
train_dataset) | |||
def _image_value_expand(self, image): | |||
return image*255 | |||
def _image_value_compress(self, image): | |||
return image / 255 | |||
def _metamorphic_mutate(self, seed, try_num=50): | |||
if self.mode == 'L': | |||
seed = seed[0] | |||
info = [seed, seed] | |||
mutate_tests = [] | |||
pixel_value_trans = ['Contrast', 'Brightness', 'Blur', 'Noise'] | |||
affine_trans = ['Translate', 'Scale', 'Shear', 'Rotate'] | |||
strages = {'Contrast': Contrast, 'Brightness': Brightness, 'Blur': Blur, | |||
'Noise': Noise, | |||
'Translate': Translate, 'Scale': Scale, 'Shear': Shear, | |||
'Rotate': Rotate} | |||
for _ in range(self.const_k): | |||
for _ in range(try_num): | |||
if (info[0] == info[1]).all(): | |||
trans_strage = self._random_pick_mutate(affine_trans, | |||
pixel_value_trans) | |||
else: | |||
trans_strage = self._random_pick_mutate(pixel_value_trans, | |||
[]) | |||
transform = strages[trans_strage]( | |||
self._image_value_expand(seed), self.mode) | |||
transform.set_params(auto_param=True) | |||
mutate_test = transform.transform() | |||
mutate_test = np.expand_dims( | |||
self._image_value_compress(mutate_test), 0) | |||
if self._is_trans_valid(seed, mutate_test): | |||
if trans_strage in affine_trans: | |||
info[1] = mutate_test | |||
mutate_tests.append(mutate_test) | |||
if not mutate_tests: | |||
mutate_tests.append(seed) | |||
return np.array(mutate_tests) | |||
def fuzzing(self, coverage_metric='KMNC'): | |||
self.coverage_metrics = ModelCoverageMetrics(target_model, | |||
segmented_num, | |||
neuron_num, train_dataset) | |||
# Allowed mutate strategies so far. | |||
self.strategies = {'Contrast': Contrast, 'Brightness': Brightness, | |||
'Blur': Blur, 'Noise': Noise, 'Translate': Translate, | |||
'Scale': Scale, 'Shear': Shear, 'Rotate': Rotate, | |||
'FGSM': FastGradientSignMethod, | |||
'PGD': ProjectedGradientDescent, | |||
'MDIIM': MomentumDiverseInputIterativeMethod} | |||
self.affine_trans_list = ['Translate', 'Scale', 'Shear', 'Rotate'] | |||
self.pixel_value_trans_list = ['Contrast', 'Brightness', 'Blur', | |||
'Noise'] | |||
self.attacks_list = ['FGSM', 'PGD', 'MDIIM'] | |||
self.attack_param_checklists = { | |||
'FGSM': {'params': {'eps': {'dtype': [float, int], 'range': [0, 1]}, | |||
'alpha': {'dtype': [float, int], | |||
'range': [0, 1]}, | |||
'bounds': {'dtype': [list, tuple], | |||
'range': None}, | |||
}}, | |||
'PGD': {'params': {'eps': {'dtype': [float, int], 'range': [0, 1]}, | |||
'eps_iter': {'dtype': [float, int], | |||
'range': [0, 1e5]}, | |||
'nb_iter': {'dtype': [float, int], | |||
'range': [0, 1e5]}, | |||
'bounds': {'dtype': [list, tuple], | |||
'range': None}, | |||
}}, | |||
'MDIIM': { | |||
'params': {'eps': {'dtype': [float, int], 'range': [0, 1]}, | |||
'norm_level': {'dtype': [str], 'range': None}, | |||
'prob': {'dtype': [float, int], 'range': [0, 1]}, | |||
'bounds': {'dtype': [list, tuple], 'range': None}, | |||
}}} | |||
def _check_attack_params(self, method, params): | |||
"""Check input parameters of attack methods.""" | |||
allow_params = self.attack_param_checklists[method]['params'].keys() | |||
for p in params: | |||
if p not in allow_params: | |||
msg = "parameters of {} must in {}".format(method, allow_params) | |||
raise ValueError(msg) | |||
if p == 'bounds': | |||
bounds = check_param_multi_types('bounds', params[p], | |||
[list, tuple]) | |||
for b in bounds: | |||
_ = check_param_multi_types('bound', b, [int, float]) | |||
elif p == 'norm_level': | |||
_ = check_norm_level(params[p]) | |||
else: | |||
allow_type = self.attack_param_checklists[method]['params'][p][ | |||
'dtype'] | |||
allow_range = self.attack_param_checklists[method]['params'][p][ | |||
'range'] | |||
_ = check_param_multi_types(str(p), params[p], allow_type) | |||
_ = check_param_in_range(str(p), params[p], allow_range[0], | |||
allow_range[1]) | |||
def _metamorphic_mutate(self, seed, mutates, mutate_config, | |||
mutate_num_per_seed): | |||
"""Mutate a seed using strategies random selected from mutate_config.""" | |||
mutate_samples = [] | |||
mutate_strategies = [] | |||
only_pixel_trans = seed[2] | |||
for _ in range(mutate_num_per_seed): | |||
strage = choice(mutate_config) | |||
# Choose a pixel value based transform method | |||
if only_pixel_trans: | |||
while strage['method'] not in self.pixel_value_trans_list: | |||
strage = choice(mutate_config) | |||
transform = mutates[strage['method']] | |||
params = strage['params'] | |||
method = strage['method'] | |||
if method in list(self.pixel_value_trans_list + self.affine_trans_list): | |||
transform.set_params(**params) | |||
mutate_sample = transform.transform(seed[0]) | |||
else: | |||
for p in params: | |||
transform.__setattr__('_'+str(p), params[p]) | |||
mutate_sample = transform.generate([seed[0].astype(np.float32)], | |||
[seed[1]])[0] | |||
if method not in self.pixel_value_trans_list: | |||
only_pixel_trans = 1 | |||
mutate_sample = [mutate_sample, seed[1], only_pixel_trans] | |||
if self._is_trans_valid(seed[0], mutate_sample[0]): | |||
mutate_samples.append(mutate_sample) | |||
mutate_strategies.append(method) | |||
if not mutate_samples: | |||
mutate_samples.append(seed) | |||
mutate_strategies.append(None) | |||
return np.array(mutate_samples), mutate_strategies | |||
def _init_mutates(self, mutate_config): | |||
""" Check whether the mutate_config meet the specification.""" | |||
has_pixel_trans = False | |||
for mutate in mutate_config: | |||
if mutate['method'] in self.pixel_value_trans_list: | |||
has_pixel_trans = True | |||
break | |||
if not has_pixel_trans: | |||
msg = "mutate methods in mutate_config at lease have one in {}".format( | |||
self.pixel_value_trans_list) | |||
raise ValueError(msg) | |||
mutates = {} | |||
for mutate in mutate_config: | |||
method = mutate['method'] | |||
params = mutate['params'] | |||
if method not in self.attacks_list: | |||
mutates[method] = self.strategies[method]() | |||
else: | |||
self._check_attack_params(method, params) | |||
network = self.target_model._network | |||
loss_fn = self.target_model._loss_fn | |||
mutates[method] = self.strategies[method](network, | |||
loss_fn=loss_fn) | |||
return mutates | |||
def evaluate(self, fuzz_samples, gt_labels, fuzz_preds, | |||
fuzz_strategies): | |||
""" | |||
Evaluate generated fuzzing samples in three dimention: accuracy, | |||
attack success rate and neural coverage. | |||
Args: | |||
fuzz_samples (numpy.ndarray): Generated fuzzing samples according to seeds. | |||
gt_labels (numpy.ndarray): Ground Truth of seeds. | |||
fuzz_preds (numpy.ndarray): Predictions of generated fuzz samples. | |||
fuzz_strategies (numpy.ndarray): Mutate strategies of fuzz samples. | |||
Returns: | |||
dict, evaluate metrics include accuarcy, attack success rate | |||
and neural coverage. | |||
""" | |||
gt_labels = np.asarray(gt_labels) | |||
fuzz_preds = np.asarray(fuzz_preds) | |||
temp = np.argmax(gt_labels, axis=1) == np.argmax(fuzz_preds, axis=1) | |||
acc = np.sum(temp) / np.size(temp) | |||
cond = [elem in self.attacks_list for elem in fuzz_strategies] | |||
temp = temp[cond] | |||
attack_success_rate = 1 - np.sum(temp) / np.size(temp) | |||
self.coverage_metrics.calculate_coverage( | |||
np.array(fuzz_samples).astype(np.float32)) | |||
kmnc = self.coverage_metrics.get_kmnc() | |||
nbc = self.coverage_metrics.get_nbc() | |||
snac = self.coverage_metrics.get_snac() | |||
metrics = {} | |||
metrics['Accuracy'] = acc | |||
metrics['Attack_succrss_rate'] = attack_success_rate | |||
metrics['Neural_coverage_KMNC'] = kmnc | |||
metrics['Neural_coverage_NBC'] = nbc | |||
metrics['Neural_coverage_SNAC'] = snac | |||
return metrics | |||
def fuzzing(self, mutate_config, initial_seeds, coverage_metric='KMNC', | |||
eval_metric=True, max_iters=10000, mutate_num_per_seed=20): | |||
""" | |||
Fuzzing tests for deep neural networks. | |||
Args: | |||
mutate_config (list): Mutate configs. The format is | |||
[{'method': 'Blur', | |||
'params': {'auto_param': True}}, | |||
{'method': 'Contrast', | |||
'params': {'factor': 2}}, | |||
...]. The support methods list is in `self.strategies`, | |||
The params of each method must within the range of changeable | |||
parameters. | |||
initial_seeds (numpy.ndarray): Initial seeds used to generate | |||
mutated samples. | |||
coverage_metric (str): Model coverage metric of neural networks. | |||
Default: 'KMNC'. | |||
eval_metric (bool): Whether to evaluate the generated fuzz samples. | |||
Default: True. | |||
max_iters (int): Max number of select a seed to mutate. | |||
Default: 10000. | |||
mutate_num_per_seed (int): The number of mutate times for a seed. | |||
Default: 20. | |||
Returns: | |||
list, mutated tests mis-predicted by target DNN model. | |||
list, mutated samples. | |||
""" | |||
seed = self._select_next() | |||
failed_tests = [] | |||
seed_num = 0 | |||
while seed and seed_num < self.max_seed_num: | |||
mutate_tests = self._metamorphic_mutate(seed[0]) | |||
coverages, predicts = self._run(mutate_tests, coverage_metric) | |||
# Check whether the mutate_config meet the specification. | |||
mutates = self._init_mutates(mutate_config) | |||
seed, initial_seeds = self._select_next(initial_seeds) | |||
fuzz_samples = [] | |||
gt_labels = [] | |||
fuzz_preds = [] | |||
fuzz_strategies = [] | |||
iter_num = 0 | |||
while initial_seeds and iter_num < max_iters: | |||
# Mutate a seed. | |||
mutate_samples, mutate_strategies = self._metamorphic_mutate(seed, | |||
mutates, | |||
mutate_config, | |||
mutate_num_per_seed) | |||
# Calculate the coverages and predictions of generated samples. | |||
coverages, predicts = self._run(mutate_samples, coverage_metric) | |||
coverage_gains = self._coverage_gains(coverages) | |||
for mutate, cov, res in zip(mutate_tests, coverage_gains, predicts): | |||
if np.argmax(seed[1]) != np.argmax(res): | |||
failed_tests.append(mutate) | |||
continue | |||
for mutate, cov, pred, strategy in zip(mutate_samples, | |||
coverage_gains, | |||
predicts, mutate_strategies): | |||
fuzz_samples.append(mutate[0]) | |||
gt_labels.append(mutate[1]) | |||
fuzz_preds.append(pred) | |||
fuzz_strategies.append(strategy) | |||
# if the mutate samples has coverage gains add this samples in | |||
# the initial seeds to guide new mutates. | |||
if cov > 0: | |||
self.initial_seeds.append([mutate, seed[1]]) | |||
seed = self._select_next() | |||
seed_num += 1 | |||
return failed_tests | |||
initial_seeds.append(mutate) | |||
seed, initial_seeds = self._select_next(initial_seeds) | |||
iter_num += 1 | |||
metrics = None | |||
if eval_metric: | |||
metrics = self.evaluate(fuzz_samples, gt_labels, fuzz_preds, | |||
fuzz_strategies) | |||
return fuzz_samples, gt_labels, fuzz_preds, fuzz_strategies, metrics | |||
def _coverage_gains(self, coverages): | |||
""" Calculate the coverage gains of mutated samples. """ | |||
gains = [0] + coverages[:-1] | |||
gains = np.array(coverages) - np.array(gains) | |||
return gains | |||
def _run(self, mutate_tests, coverage_metric="KNMC"): | |||
def _run(self, mutate_samples, coverage_metric="KNMC"): | |||
""" Calculate the coverages and predictions of generated samples.""" | |||
samples = [s[0] for s in mutate_samples] | |||
samples = np.array(samples) | |||
coverages = [] | |||
result = self.target_model.predict( | |||
Tensor(mutate_tests.astype(np.float32))) | |||
result = result.asnumpy() | |||
for index in range(len(mutate_tests)): | |||
mutate = np.expand_dims(mutate_tests[index], 0) | |||
self.coverage_metrics.model_coverage_test( | |||
mutate.astype(np.float32), batch_size=1) | |||
predictions = self.target_model.predict(Tensor(samples.astype(np.float32))) | |||
predictions = predictions.asnumpy() | |||
for index in range(len(samples)): | |||
mutate = samples[:index + 1] | |||
self.coverage_metrics.calculate_coverage(mutate.astype(np.float32)) | |||
if coverage_metric == "KMNC": | |||
coverages.append(self.coverage_metrics.get_kmnc()) | |||
if coverage_metric == 'NBC': | |||
coverages.append(self.coverage_metrics.get_nbc()) | |||
if coverage_metric == 'SNAC': | |||
coverages.append(self.coverage_metrics.get_snac()) | |||
return coverages, predictions | |||
return coverages, result | |||
def _select_next(self): | |||
seed = choice(self.initial_seeds) | |||
return seed | |||
def _select_next(self, initial_seeds): | |||
"""Randomly select a seed from `initial_seeds`.""" | |||
seed_num = choice(range(len(initial_seeds))) | |||
seed = initial_seeds[seed_num] | |||
del initial_seeds[seed_num] | |||
return seed, initial_seeds | |||
def _random_pick_mutate(self, affine_trans_list, pixel_value_trans_list): | |||
strage = choice(affine_trans_list + pixel_value_trans_list) | |||
return strage | |||
def _is_trans_valid(self, seed, mutate_test): | |||
def _is_trans_valid(self, seed, mutate_sample): | |||
""" Check a mutated sample is valid. If the number of changed pixels in | |||
a seed is less than pixels_change_rate*size(seed), this mutate is valid. | |||
Else check the infinite norm of seed changes, if the value of the | |||
infinite norm less than pixel_value_change_rate*255, this mutate is | |||
valid too. Otherwise the opposite.""" | |||
is_valid = False | |||
pixels_change_rate = 0.02 | |||
pixel_value_change_rate = 0.2 | |||
diff = np.array(seed - mutate_test).flatten() | |||
diff = np.array(seed - mutate_sample).flatten() | |||
size = np.shape(diff)[0] | |||
l0 = np.linalg.norm(diff, ord=0) | |||
linf = np.linalg.norm(diff, ord=np.inf) | |||
@@ -167,5 +320,4 @@ class Fuzzer: | |||
else: | |||
if linf < pixel_value_change_rate*255: | |||
is_valid = True | |||
return is_valid |
@@ -88,7 +88,8 @@ def is_rgb(img): | |||
Bool, True if input is RGB. | |||
""" | |||
if is_numpy(img): | |||
if len(np.shape(img)) == 3: | |||
img_shape = np.shape(img) | |||
if len(np.shape(img)) == 3 and (img_shape[0] == 3 or img_shape[2] == 3): | |||
return True | |||
return False | |||
raise TypeError('img should be Numpy array. Got {}'.format(type(img))) | |||
@@ -127,6 +128,7 @@ class ImageTransform: | |||
of the image is not normalized , it will be normalized between 0 to 1.""" | |||
rgb = is_rgb(image) | |||
chw = False | |||
gray3dim = False | |||
normalized = is_normalized(image) | |||
if rgb: | |||
chw = is_chw(image) | |||
@@ -135,12 +137,16 @@ class ImageTransform: | |||
else: | |||
image = image | |||
else: | |||
image = image | |||
if len(np.shape(image)) == 3: | |||
gray3dim = True | |||
image = image[0] | |||
else: | |||
image = image | |||
if normalized: | |||
image = np.uint8(image*255) | |||
return rgb, chw, normalized, image | |||
return rgb, chw, normalized, gray3dim, image | |||
def _original_format(self, image, chw, normalized): | |||
def _original_format(self, image, chw, normalized, gray3dim): | |||
""" Return transformed image with original format. """ | |||
if not is_numpy(image): | |||
image = np.array(image) | |||
@@ -148,6 +154,8 @@ class ImageTransform: | |||
image = hwc_to_chw(image) | |||
if normalized: | |||
image = image / 255 | |||
if gray3dim: | |||
image = np.expand_dims(image, 0) | |||
return image | |||
def transform(self, image): | |||
@@ -191,11 +199,12 @@ class Contrast(ImageTransform): | |||
Returns: | |||
numpy.ndarray, transformed image. | |||
""" | |||
_, chw, normalized, image = self._check(image) | |||
_, chw, normalized, gray3dim, image = self._check(image) | |||
image = to_pil(image) | |||
img_contrast = ImageEnhance.Contrast(image) | |||
trans_image = img_contrast.enhance(self.factor) | |||
trans_image = self._original_format(trans_image, chw, normalized) | |||
trans_image = self._original_format(trans_image, chw, normalized, | |||
gray3dim) | |||
return trans_image | |||
@@ -237,11 +246,12 @@ class Brightness(ImageTransform): | |||
Returns: | |||
numpy.ndarray, transformed image. | |||
""" | |||
_, chw, normalized, image = self._check(image) | |||
_, chw, normalized, gray3dim, image = self._check(image) | |||
image = to_pil(image) | |||
img_contrast = ImageEnhance.Brightness(image) | |||
trans_image = img_contrast.enhance(self.factor) | |||
trans_image = self._original_format(trans_image, chw, normalized) | |||
trans_image = self._original_format(trans_image, chw, normalized, | |||
gray3dim) | |||
return trans_image | |||
@@ -280,10 +290,11 @@ class Blur(ImageTransform): | |||
Returns: | |||
numpy.ndarray, transformed image. | |||
""" | |||
_, chw, normalized, image = self._check(image) | |||
_, chw, normalized, gray3dim, image = self._check(image) | |||
image = to_pil(image) | |||
trans_image = image.filter(ImageFilter.GaussianBlur(radius=self.radius)) | |||
trans_image = self._original_format(trans_image, chw, normalized) | |||
trans_image = self._original_format(trans_image, chw, normalized, | |||
gray3dim) | |||
return trans_image | |||
@@ -324,12 +335,13 @@ class Noise(ImageTransform): | |||
Returns: | |||
numpy.ndarray, transformed image. | |||
""" | |||
_, chw, normalized, image = self._check(image) | |||
_, chw, normalized, gray3dim, image = self._check(image) | |||
noise = np.random.uniform(low=-1, high=1, size=np.shape(image)) | |||
trans_image = np.copy(image) | |||
trans_image[noise < -self.factor] = 0 | |||
trans_image[noise > self.factor] = 1 | |||
trans_image = self._original_format(trans_image, chw, normalized) | |||
trans_image = self._original_format(trans_image, chw, normalized, | |||
gray3dim) | |||
return trans_image | |||
@@ -375,7 +387,7 @@ class Translate(ImageTransform): | |||
Returns: | |||
numpy.ndarray, transformed image. | |||
""" | |||
_, chw, normalized, image = self._check(image) | |||
_, chw, normalized, gray3dim, image = self._check(image) | |||
img = to_pil(image) | |||
if self.auto_param: | |||
image_shape = np.shape(image) | |||
@@ -383,7 +395,8 @@ class Translate(ImageTransform): | |||
self.y_bias = image_shape[1]*self.y_bias | |||
trans_image = img.transform(img.size, Image.AFFINE, | |||
(1, 0, self.x_bias, 0, 1, self.y_bias)) | |||
trans_image = self._original_format(trans_image, chw, normalized) | |||
trans_image = self._original_format(trans_image, chw, normalized, | |||
gray3dim) | |||
return trans_image | |||
@@ -431,7 +444,7 @@ class Scale(ImageTransform): | |||
Returns: | |||
numpy.ndarray, transformed image. | |||
""" | |||
rgb, chw, normalized, image = self._check(image) | |||
rgb, chw, normalized, gray3dim, image = self._check(image) | |||
if rgb: | |||
h, w, _ = np.shape(image) | |||
else: | |||
@@ -442,7 +455,8 @@ class Scale(ImageTransform): | |||
trans_image = img.transform(img.size, Image.AFFINE, | |||
(self.factor_x, 0, move_x_centor, | |||
0, self.factor_y, move_y_centor)) | |||
trans_image = self._original_format(trans_image, chw, normalized) | |||
trans_image = self._original_format(trans_image, chw, normalized, | |||
gray3dim) | |||
return trans_image | |||
@@ -500,7 +514,7 @@ class Shear(ImageTransform): | |||
Returns: | |||
numpy.ndarray, transformed image. | |||
""" | |||
rgb, chw, normalized, image = self._check(image) | |||
rgb, chw, normalized, gray3dim, image = self._check(image) | |||
img = to_pil(image) | |||
if rgb: | |||
h, w, _ = np.shape(image) | |||
@@ -523,7 +537,8 @@ class Shear(ImageTransform): | |||
trans_image = img.transform(img.size, Image.AFFINE, | |||
(scale, scale*self.factor_x, move_x_cen, | |||
scale*self.factor_y, scale, move_y_cen)) | |||
trans_image = self._original_format(trans_image, chw, normalized) | |||
trans_image = self._original_format(trans_image, chw, normalized, | |||
gray3dim) | |||
return trans_image | |||
@@ -562,8 +577,9 @@ class Rotate(ImageTransform): | |||
Returns: | |||
numpy.ndarray, transformed image. | |||
""" | |||
_, chw, normalized, image = self._check(image) | |||
_, chw, normalized, gray3dim, image = self._check(image) | |||
img = to_pil(image) | |||
trans_image = img.rotate(self.angle, expand=True) | |||
trans_image = self._original_format(trans_image, chw, normalized) | |||
trans_image = self._original_format(trans_image, chw, normalized, | |||
gray3dim) | |||
return trans_image |
@@ -0,0 +1,172 @@ | |||
# Copyright 2019 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
""" | |||
Model-fuzz coverage test. | |||
""" | |||
import numpy as np | |||
import pytest | |||
from mindspore import context | |||
from mindspore import nn | |||
from mindspore.common.initializer import TruncatedNormal | |||
from mindspore.ops import operations as P | |||
from mindspore.train import Model | |||
from mindarmour.fuzzing.fuzzing import Fuzzer | |||
from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics | |||
from mindarmour.utils.logger import LogUtil | |||
LOGGER = LogUtil.get_instance() | |||
TAG = 'Fuzzing test' | |||
LOGGER.set_level('INFO') | |||
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0): | |||
weight = weight_variable() | |||
return nn.Conv2d(in_channels, out_channels, | |||
kernel_size=kernel_size, stride=stride, padding=padding, | |||
weight_init=weight, has_bias=False, pad_mode="valid") | |||
def fc_with_initialize(input_channels, out_channels): | |||
weight = weight_variable() | |||
bias = weight_variable() | |||
return nn.Dense(input_channels, out_channels, weight, bias) | |||
def weight_variable(): | |||
return TruncatedNormal(0.02) | |||
class Net(nn.Cell): | |||
""" | |||
Lenet network | |||
""" | |||
def __init__(self): | |||
super(Net, self).__init__() | |||
self.conv1 = conv(1, 6, 5) | |||
self.conv2 = conv(6, 16, 5) | |||
self.fc1 = fc_with_initialize(16*5*5, 120) | |||
self.fc2 = fc_with_initialize(120, 84) | |||
self.fc3 = fc_with_initialize(84, 10) | |||
self.relu = nn.ReLU() | |||
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) | |||
self.reshape = P.Reshape() | |||
def construct(self, x): | |||
x = self.conv1(x) | |||
x = self.relu(x) | |||
x = self.max_pool2d(x) | |||
x = self.conv2(x) | |||
x = self.relu(x) | |||
x = self.max_pool2d(x) | |||
x = self.reshape(x, (-1, 16*5*5)) | |||
x = self.fc1(x) | |||
x = self.relu(x) | |||
x = self.fc2(x) | |||
x = self.relu(x) | |||
x = self.fc3(x) | |||
return x | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.env_onecard | |||
@pytest.mark.component_mindarmour | |||
def test_fuzzing_ascend(): | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
# load network | |||
net = Net() | |||
model = Model(net) | |||
batch_size = 8 | |||
num_classe = 10 | |||
mutate_config = [{'method': 'Blur', | |||
'params': {'auto_param': True}}, | |||
{'method': 'Contrast', | |||
'params': {'factor': 2}}, | |||
{'method': 'Translate', | |||
'params': {'x_bias': 0.1, 'y_bias': 0.2}}, | |||
{'method': 'FGSM', | |||
'params': {'eps': 0.1, 'alpha': 0.1}} | |||
] | |||
# initialize fuzz test with training dataset | |||
train_images = np.random.rand(32, 1, 32, 32).astype(np.float32) | |||
model_coverage_test = ModelCoverageMetrics(model, 1000, 10, train_images) | |||
# fuzz test with original test data | |||
# get test data | |||
test_images = np.random.rand(batch_size, 1, 32, 32).astype(np.float32) | |||
test_labels = np.random.randint(num_classe, size=batch_size).astype(np.int32) | |||
test_labels = (np.eye(num_classe)[test_labels]).astype(np.float32) | |||
initial_seeds = [] | |||
# make initial seeds | |||
for img, label in zip(test_images, test_labels): | |||
initial_seeds.append([img, label, 0]) | |||
initial_seeds = initial_seeds[:100] | |||
model_coverage_test.calculate_coverage( | |||
np.array(test_images[:100]).astype(np.float32)) | |||
LOGGER.info(TAG, 'KMNC of this test is : %s', | |||
model_coverage_test.get_kmnc()) | |||
model_fuzz_test = Fuzzer(model, train_images, 1000, 10) | |||
_, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds) | |||
print(metrics) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_x86_cpu | |||
@pytest.mark.env_onecard | |||
@pytest.mark.component_mindarmour | |||
def test_fuzzing_cpu(): | |||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||
# load network | |||
net = Net() | |||
model = Model(net) | |||
batch_size = 8 | |||
num_classe = 10 | |||
mutate_config = [{'method': 'Blur', | |||
'params': {'auto_param': True}}, | |||
{'method': 'Contrast', | |||
'params': {'factor': 2}}, | |||
{'method': 'Translate', | |||
'params': {'x_bias': 0.1, 'y_bias': 0.2}}, | |||
{'method': 'FGSM', | |||
'params': {'eps': 0.1, 'alpha': 0.1}} | |||
] | |||
# initialize fuzz test with training dataset | |||
train_images = np.random.rand(32, 1, 32, 32).astype(np.float32) | |||
model_coverage_test = ModelCoverageMetrics(model, 1000, 10, train_images) | |||
# fuzz test with original test data | |||
# get test data | |||
test_images = np.random.rand(batch_size, 1, 32, 32).astype(np.float32) | |||
test_labels = np.random.randint(num_classe, size=batch_size).astype(np.int32) | |||
test_labels = (np.eye(num_classe)[test_labels]).astype(np.float32) | |||
initial_seeds = [] | |||
# make initial seeds | |||
for img, label in zip(test_images, test_labels): | |||
initial_seeds.append([img, label, 0]) | |||
initial_seeds = initial_seeds[:100] | |||
model_coverage_test.calculate_coverage( | |||
np.array(test_images[:100]).astype(np.float32)) | |||
LOGGER.info(TAG, 'KMNC of this test is : %s', | |||
model_coverage_test.get_kmnc()) | |||
model_fuzz_test = Fuzzer(model, train_images, 1000, 10) | |||
_, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds) | |||
print(metrics) |