| @@ -19,7 +19,7 @@ from mindspore import context | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from lenet5_net import LeNet5 | |||
| from mindarmour.fuzzing.fuzzing import Fuzzing | |||
| from mindarmour.fuzzing.fuzzing import Fuzzer | |||
| from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics | |||
| from mindarmour.utils.logger import LogUtil | |||
| @@ -38,11 +38,20 @@ def test_lenet_mnist_fuzzing(): | |||
| load_dict = load_checkpoint(ckpt_name) | |||
| load_param_into_net(net, load_dict) | |||
| model = Model(net) | |||
| mutate_config = [{'method': 'Blur', | |||
| 'params': {'auto_param': True}}, | |||
| {'method': 'Contrast', | |||
| 'params': {'factor': 2}}, | |||
| {'method': 'Translate', | |||
| 'params': {'x_bias': 0.1, 'y_bias': 0.2}}, | |||
| {'method': 'FGSM', | |||
| 'params': {'eps': 0.1, 'alpha': 0.1}} | |||
| ] | |||
| # get training data | |||
| data_list = "./MNIST_unzip/train" | |||
| batch_size = 32 | |||
| ds = generate_mnist_dataset(data_list, batch_size, sparse=True) | |||
| ds = generate_mnist_dataset(data_list, batch_size, sparse=False) | |||
| train_images = [] | |||
| for data in ds.create_tuple_iterator(): | |||
| images = data[0].astype(np.float32) | |||
| @@ -56,7 +65,7 @@ def test_lenet_mnist_fuzzing(): | |||
| # get test data | |||
| data_list = "./MNIST_unzip/test" | |||
| batch_size = 32 | |||
| ds = generate_mnist_dataset(data_list, batch_size, sparse=True) | |||
| ds = generate_mnist_dataset(data_list, batch_size, sparse=False) | |||
| test_images = [] | |||
| test_labels = [] | |||
| for data in ds.create_tuple_iterator(): | |||
| @@ -70,19 +79,20 @@ def test_lenet_mnist_fuzzing(): | |||
| # make initial seeds | |||
| for img, label in zip(test_images, test_labels): | |||
| initial_seeds.append([img, label]) | |||
| initial_seeds.append([img, label, 0]) | |||
| initial_seeds = initial_seeds[:100] | |||
| model_coverage_test.test_adequacy_coverage_calculate(np.array(test_images[:100]).astype(np.float32)) | |||
| LOGGER.info(TAG, 'KMNC of this test is : %s', model_coverage_test.get_kmnc()) | |||
| model_coverage_test.calculate_coverage( | |||
| np.array(test_images[:100]).astype(np.float32)) | |||
| LOGGER.info(TAG, 'KMNC of this test is : %s', | |||
| model_coverage_test.get_kmnc()) | |||
| model_fuzz_test = Fuzzing(initial_seeds, model, train_images, 20) | |||
| failed_tests = model_fuzz_test.fuzzing() | |||
| if failed_tests: | |||
| model_coverage_test.test_adequacy_coverage_calculate(np.array(failed_tests).astype(np.float32)) | |||
| LOGGER.info(TAG, 'KMNC of this test is : %s', model_coverage_test.get_kmnc()) | |||
| else: | |||
| LOGGER.info(TAG, 'Fuzzing test identifies none failed test') | |||
| model_fuzz_test = Fuzzer(model, train_images, 1000, 10) | |||
| _, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds, | |||
| eval_metric=True) | |||
| if metrics: | |||
| for key in metrics: | |||
| LOGGER.info(TAG, key + ': %s', metrics[key]) | |||
| if __name__ == '__main__': | |||
| @@ -227,8 +227,8 @@ class BasicIterativeMethod(IterativeGradientMethod): | |||
| clip_min, clip_max = self._bounds | |||
| clip_diff = clip_max - clip_min | |||
| for _ in range(self._nb_iter): | |||
| if 'self.prob' in globals(): | |||
| d_inputs = _transform_inputs(inputs, self.prob) | |||
| if 'self._prob' in globals(): | |||
| d_inputs = _transform_inputs(inputs, self._prob) | |||
| else: | |||
| d_inputs = inputs | |||
| adv_x = self._attack.generate(d_inputs, labels) | |||
| @@ -238,8 +238,8 @@ class BasicIterativeMethod(IterativeGradientMethod): | |||
| inputs = adv_x | |||
| else: | |||
| for _ in range(self._nb_iter): | |||
| if 'self.prob' in globals(): | |||
| d_inputs = _transform_inputs(inputs, self.prob) | |||
| if 'self._prob' in globals(): | |||
| d_inputs = _transform_inputs(inputs, self._prob) | |||
| else: | |||
| d_inputs = inputs | |||
| adv_x = self._attack.generate(d_inputs, labels) | |||
| @@ -311,8 +311,8 @@ class MomentumIterativeMethod(IterativeGradientMethod): | |||
| clip_min, clip_max = self._bounds | |||
| clip_diff = clip_max - clip_min | |||
| for _ in range(self._nb_iter): | |||
| if 'self.prob' in globals(): | |||
| d_inputs = _transform_inputs(inputs, self.prob) | |||
| if 'self._prob' in globals(): | |||
| d_inputs = _transform_inputs(inputs, self._prob) | |||
| else: | |||
| d_inputs = inputs | |||
| gradient = self._gradient(d_inputs, labels) | |||
| @@ -325,8 +325,8 @@ class MomentumIterativeMethod(IterativeGradientMethod): | |||
| inputs = adv_x | |||
| else: | |||
| for _ in range(self._nb_iter): | |||
| if 'self.prob' in globals(): | |||
| d_inputs = _transform_inputs(inputs, self.prob) | |||
| if 'self._prob' in globals(): | |||
| d_inputs = _transform_inputs(inputs, self._prob) | |||
| else: | |||
| d_inputs = inputs | |||
| gradient = self._gradient(d_inputs, labels) | |||
| @@ -476,7 +476,7 @@ class DiverseInputIterativeMethod(BasicIterativeMethod): | |||
| is_targeted=is_targeted, | |||
| nb_iter=nb_iter, | |||
| loss_fn=loss_fn) | |||
| self.prob = check_param_type('prob', prob, float) | |||
| self._prob = check_param_type('prob', prob, float) | |||
| class MomentumDiverseInputIterativeMethod(MomentumIterativeMethod): | |||
| @@ -511,7 +511,7 @@ class MomentumDiverseInputIterativeMethod(MomentumIterativeMethod): | |||
| is_targeted=is_targeted, | |||
| norm_level=norm_level, | |||
| loss_fn=loss_fn) | |||
| self.prob = check_param_type('prob', prob, float) | |||
| self._prob = check_param_type('prob', prob, float) | |||
| def _transform_inputs(inputs, prob, low=29, high=33, full_aug=False): | |||
| @@ -22,9 +22,11 @@ from mindspore import Tensor | |||
| from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics | |||
| from mindarmour.utils._check_param import check_model, check_numpy_param, \ | |||
| check_int_positive | |||
| from mindarmour.fuzzing.image_transform import Contrast, Brightness, Blur, Noise, \ | |||
| Translate, Scale, Shear, Rotate | |||
| check_param_multi_types, check_norm_level, check_param_in_range | |||
| from mindarmour.fuzzing.image_transform import Contrast, Brightness, Blur, \ | |||
| Noise, Translate, Scale, Shear, Rotate | |||
| from mindarmour.attacks import FastGradientSignMethod, \ | |||
| MomentumDiverseInputIterativeMethod, ProjectedGradientDescent | |||
| class Fuzzer: | |||
| @@ -35,129 +37,280 @@ class Fuzzer: | |||
| Neural Networks <https://dl.acm.org/doi/10.1145/3293882.3330579>`_ | |||
| Args: | |||
| initial_seeds (list): Initial fuzzing seed, format: [[image, label], | |||
| [image, label], ...]. | |||
| target_model (Model): Target fuzz model. | |||
| train_dataset (numpy.ndarray): Training dataset used for determining | |||
| the neurons' output boundaries. | |||
| const_k (int): The number of mutate tests for a seed. | |||
| mode (str): Image mode used in image transform, 'L' means grey graph. | |||
| Default: 'L'. | |||
| max_seed_num (int): The initial seeds max value. Default: 1000 | |||
| segmented_num (int): The number of segmented sections of neurons' | |||
| output intervals. | |||
| neuron_num (int): The number of testing neurons. | |||
| """ | |||
| def __init__(self, initial_seeds, target_model, train_dataset, const_K, | |||
| mode='L', max_seed_num=1000): | |||
| self.initial_seeds = initial_seeds | |||
| def __init__(self, target_model, train_dataset, segmented_num, neuron_num): | |||
| self.target_model = check_model('model', target_model, Model) | |||
| self.train_dataset = check_numpy_param('train_dataset', train_dataset) | |||
| self.const_k = check_int_positive('const_k', const_K) | |||
| self.mode = mode | |||
| self.max_seed_num = check_int_positive('max_seed_num', max_seed_num) | |||
| self.coverage_metrics = ModelCoverageMetrics(target_model, 1000, 10, | |||
| train_dataset) | |||
| def _image_value_expand(self, image): | |||
| return image*255 | |||
| def _image_value_compress(self, image): | |||
| return image / 255 | |||
| def _metamorphic_mutate(self, seed, try_num=50): | |||
| if self.mode == 'L': | |||
| seed = seed[0] | |||
| info = [seed, seed] | |||
| mutate_tests = [] | |||
| pixel_value_trans = ['Contrast', 'Brightness', 'Blur', 'Noise'] | |||
| affine_trans = ['Translate', 'Scale', 'Shear', 'Rotate'] | |||
| strages = {'Contrast': Contrast, 'Brightness': Brightness, 'Blur': Blur, | |||
| 'Noise': Noise, | |||
| 'Translate': Translate, 'Scale': Scale, 'Shear': Shear, | |||
| 'Rotate': Rotate} | |||
| for _ in range(self.const_k): | |||
| for _ in range(try_num): | |||
| if (info[0] == info[1]).all(): | |||
| trans_strage = self._random_pick_mutate(affine_trans, | |||
| pixel_value_trans) | |||
| else: | |||
| trans_strage = self._random_pick_mutate(pixel_value_trans, | |||
| []) | |||
| transform = strages[trans_strage]( | |||
| self._image_value_expand(seed), self.mode) | |||
| transform.set_params(auto_param=True) | |||
| mutate_test = transform.transform() | |||
| mutate_test = np.expand_dims( | |||
| self._image_value_compress(mutate_test), 0) | |||
| if self._is_trans_valid(seed, mutate_test): | |||
| if trans_strage in affine_trans: | |||
| info[1] = mutate_test | |||
| mutate_tests.append(mutate_test) | |||
| if not mutate_tests: | |||
| mutate_tests.append(seed) | |||
| return np.array(mutate_tests) | |||
| def fuzzing(self, coverage_metric='KMNC'): | |||
| self.coverage_metrics = ModelCoverageMetrics(target_model, | |||
| segmented_num, | |||
| neuron_num, train_dataset) | |||
| # Allowed mutate strategies so far. | |||
| self.strategies = {'Contrast': Contrast, 'Brightness': Brightness, | |||
| 'Blur': Blur, 'Noise': Noise, 'Translate': Translate, | |||
| 'Scale': Scale, 'Shear': Shear, 'Rotate': Rotate, | |||
| 'FGSM': FastGradientSignMethod, | |||
| 'PGD': ProjectedGradientDescent, | |||
| 'MDIIM': MomentumDiverseInputIterativeMethod} | |||
| self.affine_trans_list = ['Translate', 'Scale', 'Shear', 'Rotate'] | |||
| self.pixel_value_trans_list = ['Contrast', 'Brightness', 'Blur', | |||
| 'Noise'] | |||
| self.attacks_list = ['FGSM', 'PGD', 'MDIIM'] | |||
| self.attack_param_checklists = { | |||
| 'FGSM': {'params': {'eps': {'dtype': [float, int], 'range': [0, 1]}, | |||
| 'alpha': {'dtype': [float, int], | |||
| 'range': [0, 1]}, | |||
| 'bounds': {'dtype': [list, tuple], | |||
| 'range': None}, | |||
| }}, | |||
| 'PGD': {'params': {'eps': {'dtype': [float, int], 'range': [0, 1]}, | |||
| 'eps_iter': {'dtype': [float, int], | |||
| 'range': [0, 1e5]}, | |||
| 'nb_iter': {'dtype': [float, int], | |||
| 'range': [0, 1e5]}, | |||
| 'bounds': {'dtype': [list, tuple], | |||
| 'range': None}, | |||
| }}, | |||
| 'MDIIM': { | |||
| 'params': {'eps': {'dtype': [float, int], 'range': [0, 1]}, | |||
| 'norm_level': {'dtype': [str], 'range': None}, | |||
| 'prob': {'dtype': [float, int], 'range': [0, 1]}, | |||
| 'bounds': {'dtype': [list, tuple], 'range': None}, | |||
| }}} | |||
| def _check_attack_params(self, method, params): | |||
| """Check input parameters of attack methods.""" | |||
| allow_params = self.attack_param_checklists[method]['params'].keys() | |||
| for p in params: | |||
| if p not in allow_params: | |||
| msg = "parameters of {} must in {}".format(method, allow_params) | |||
| raise ValueError(msg) | |||
| if p == 'bounds': | |||
| bounds = check_param_multi_types('bounds', params[p], | |||
| [list, tuple]) | |||
| for b in bounds: | |||
| _ = check_param_multi_types('bound', b, [int, float]) | |||
| elif p == 'norm_level': | |||
| _ = check_norm_level(params[p]) | |||
| else: | |||
| allow_type = self.attack_param_checklists[method]['params'][p][ | |||
| 'dtype'] | |||
| allow_range = self.attack_param_checklists[method]['params'][p][ | |||
| 'range'] | |||
| _ = check_param_multi_types(str(p), params[p], allow_type) | |||
| _ = check_param_in_range(str(p), params[p], allow_range[0], | |||
| allow_range[1]) | |||
| def _metamorphic_mutate(self, seed, mutates, mutate_config, | |||
| mutate_num_per_seed): | |||
| """Mutate a seed using strategies random selected from mutate_config.""" | |||
| mutate_samples = [] | |||
| mutate_strategies = [] | |||
| only_pixel_trans = seed[2] | |||
| for _ in range(mutate_num_per_seed): | |||
| strage = choice(mutate_config) | |||
| # Choose a pixel value based transform method | |||
| if only_pixel_trans: | |||
| while strage['method'] not in self.pixel_value_trans_list: | |||
| strage = choice(mutate_config) | |||
| transform = mutates[strage['method']] | |||
| params = strage['params'] | |||
| method = strage['method'] | |||
| if method in list(self.pixel_value_trans_list + self.affine_trans_list): | |||
| transform.set_params(**params) | |||
| mutate_sample = transform.transform(seed[0]) | |||
| else: | |||
| for p in params: | |||
| transform.__setattr__('_'+str(p), params[p]) | |||
| mutate_sample = transform.generate([seed[0].astype(np.float32)], | |||
| [seed[1]])[0] | |||
| if method not in self.pixel_value_trans_list: | |||
| only_pixel_trans = 1 | |||
| mutate_sample = [mutate_sample, seed[1], only_pixel_trans] | |||
| if self._is_trans_valid(seed[0], mutate_sample[0]): | |||
| mutate_samples.append(mutate_sample) | |||
| mutate_strategies.append(method) | |||
| if not mutate_samples: | |||
| mutate_samples.append(seed) | |||
| mutate_strategies.append(None) | |||
| return np.array(mutate_samples), mutate_strategies | |||
| def _init_mutates(self, mutate_config): | |||
| """ Check whether the mutate_config meet the specification.""" | |||
| has_pixel_trans = False | |||
| for mutate in mutate_config: | |||
| if mutate['method'] in self.pixel_value_trans_list: | |||
| has_pixel_trans = True | |||
| break | |||
| if not has_pixel_trans: | |||
| msg = "mutate methods in mutate_config at lease have one in {}".format( | |||
| self.pixel_value_trans_list) | |||
| raise ValueError(msg) | |||
| mutates = {} | |||
| for mutate in mutate_config: | |||
| method = mutate['method'] | |||
| params = mutate['params'] | |||
| if method not in self.attacks_list: | |||
| mutates[method] = self.strategies[method]() | |||
| else: | |||
| self._check_attack_params(method, params) | |||
| network = self.target_model._network | |||
| loss_fn = self.target_model._loss_fn | |||
| mutates[method] = self.strategies[method](network, | |||
| loss_fn=loss_fn) | |||
| return mutates | |||
| def evaluate(self, fuzz_samples, gt_labels, fuzz_preds, | |||
| fuzz_strategies): | |||
| """ | |||
| Evaluate generated fuzzing samples in three dimention: accuracy, | |||
| attack success rate and neural coverage. | |||
| Args: | |||
| fuzz_samples (numpy.ndarray): Generated fuzzing samples according to seeds. | |||
| gt_labels (numpy.ndarray): Ground Truth of seeds. | |||
| fuzz_preds (numpy.ndarray): Predictions of generated fuzz samples. | |||
| fuzz_strategies (numpy.ndarray): Mutate strategies of fuzz samples. | |||
| Returns: | |||
| dict, evaluate metrics include accuarcy, attack success rate | |||
| and neural coverage. | |||
| """ | |||
| gt_labels = np.asarray(gt_labels) | |||
| fuzz_preds = np.asarray(fuzz_preds) | |||
| temp = np.argmax(gt_labels, axis=1) == np.argmax(fuzz_preds, axis=1) | |||
| acc = np.sum(temp) / np.size(temp) | |||
| cond = [elem in self.attacks_list for elem in fuzz_strategies] | |||
| temp = temp[cond] | |||
| attack_success_rate = 1 - np.sum(temp) / np.size(temp) | |||
| self.coverage_metrics.calculate_coverage( | |||
| np.array(fuzz_samples).astype(np.float32)) | |||
| kmnc = self.coverage_metrics.get_kmnc() | |||
| nbc = self.coverage_metrics.get_nbc() | |||
| snac = self.coverage_metrics.get_snac() | |||
| metrics = {} | |||
| metrics['Accuracy'] = acc | |||
| metrics['Attack_succrss_rate'] = attack_success_rate | |||
| metrics['Neural_coverage_KMNC'] = kmnc | |||
| metrics['Neural_coverage_NBC'] = nbc | |||
| metrics['Neural_coverage_SNAC'] = snac | |||
| return metrics | |||
| def fuzzing(self, mutate_config, initial_seeds, coverage_metric='KMNC', | |||
| eval_metric=True, max_iters=10000, mutate_num_per_seed=20): | |||
| """ | |||
| Fuzzing tests for deep neural networks. | |||
| Args: | |||
| mutate_config (list): Mutate configs. The format is | |||
| [{'method': 'Blur', | |||
| 'params': {'auto_param': True}}, | |||
| {'method': 'Contrast', | |||
| 'params': {'factor': 2}}, | |||
| ...]. The support methods list is in `self.strategies`, | |||
| The params of each method must within the range of changeable | |||
| parameters. | |||
| initial_seeds (numpy.ndarray): Initial seeds used to generate | |||
| mutated samples. | |||
| coverage_metric (str): Model coverage metric of neural networks. | |||
| Default: 'KMNC'. | |||
| eval_metric (bool): Whether to evaluate the generated fuzz samples. | |||
| Default: True. | |||
| max_iters (int): Max number of select a seed to mutate. | |||
| Default: 10000. | |||
| mutate_num_per_seed (int): The number of mutate times for a seed. | |||
| Default: 20. | |||
| Returns: | |||
| list, mutated tests mis-predicted by target DNN model. | |||
| list, mutated samples. | |||
| """ | |||
| seed = self._select_next() | |||
| failed_tests = [] | |||
| seed_num = 0 | |||
| while seed and seed_num < self.max_seed_num: | |||
| mutate_tests = self._metamorphic_mutate(seed[0]) | |||
| coverages, predicts = self._run(mutate_tests, coverage_metric) | |||
| # Check whether the mutate_config meet the specification. | |||
| mutates = self._init_mutates(mutate_config) | |||
| seed, initial_seeds = self._select_next(initial_seeds) | |||
| fuzz_samples = [] | |||
| gt_labels = [] | |||
| fuzz_preds = [] | |||
| fuzz_strategies = [] | |||
| iter_num = 0 | |||
| while initial_seeds and iter_num < max_iters: | |||
| # Mutate a seed. | |||
| mutate_samples, mutate_strategies = self._metamorphic_mutate(seed, | |||
| mutates, | |||
| mutate_config, | |||
| mutate_num_per_seed) | |||
| # Calculate the coverages and predictions of generated samples. | |||
| coverages, predicts = self._run(mutate_samples, coverage_metric) | |||
| coverage_gains = self._coverage_gains(coverages) | |||
| for mutate, cov, res in zip(mutate_tests, coverage_gains, predicts): | |||
| if np.argmax(seed[1]) != np.argmax(res): | |||
| failed_tests.append(mutate) | |||
| continue | |||
| for mutate, cov, pred, strategy in zip(mutate_samples, | |||
| coverage_gains, | |||
| predicts, mutate_strategies): | |||
| fuzz_samples.append(mutate[0]) | |||
| gt_labels.append(mutate[1]) | |||
| fuzz_preds.append(pred) | |||
| fuzz_strategies.append(strategy) | |||
| # if the mutate samples has coverage gains add this samples in | |||
| # the initial seeds to guide new mutates. | |||
| if cov > 0: | |||
| self.initial_seeds.append([mutate, seed[1]]) | |||
| seed = self._select_next() | |||
| seed_num += 1 | |||
| return failed_tests | |||
| initial_seeds.append(mutate) | |||
| seed, initial_seeds = self._select_next(initial_seeds) | |||
| iter_num += 1 | |||
| metrics = None | |||
| if eval_metric: | |||
| metrics = self.evaluate(fuzz_samples, gt_labels, fuzz_preds, | |||
| fuzz_strategies) | |||
| return fuzz_samples, gt_labels, fuzz_preds, fuzz_strategies, metrics | |||
| def _coverage_gains(self, coverages): | |||
| """ Calculate the coverage gains of mutated samples. """ | |||
| gains = [0] + coverages[:-1] | |||
| gains = np.array(coverages) - np.array(gains) | |||
| return gains | |||
| def _run(self, mutate_tests, coverage_metric="KNMC"): | |||
| def _run(self, mutate_samples, coverage_metric="KNMC"): | |||
| """ Calculate the coverages and predictions of generated samples.""" | |||
| samples = [s[0] for s in mutate_samples] | |||
| samples = np.array(samples) | |||
| coverages = [] | |||
| result = self.target_model.predict( | |||
| Tensor(mutate_tests.astype(np.float32))) | |||
| result = result.asnumpy() | |||
| for index in range(len(mutate_tests)): | |||
| mutate = np.expand_dims(mutate_tests[index], 0) | |||
| self.coverage_metrics.model_coverage_test( | |||
| mutate.astype(np.float32), batch_size=1) | |||
| predictions = self.target_model.predict(Tensor(samples.astype(np.float32))) | |||
| predictions = predictions.asnumpy() | |||
| for index in range(len(samples)): | |||
| mutate = samples[:index + 1] | |||
| self.coverage_metrics.calculate_coverage(mutate.astype(np.float32)) | |||
| if coverage_metric == "KMNC": | |||
| coverages.append(self.coverage_metrics.get_kmnc()) | |||
| if coverage_metric == 'NBC': | |||
| coverages.append(self.coverage_metrics.get_nbc()) | |||
| if coverage_metric == 'SNAC': | |||
| coverages.append(self.coverage_metrics.get_snac()) | |||
| return coverages, predictions | |||
| return coverages, result | |||
| def _select_next(self): | |||
| seed = choice(self.initial_seeds) | |||
| return seed | |||
| def _select_next(self, initial_seeds): | |||
| """Randomly select a seed from `initial_seeds`.""" | |||
| seed_num = choice(range(len(initial_seeds))) | |||
| seed = initial_seeds[seed_num] | |||
| del initial_seeds[seed_num] | |||
| return seed, initial_seeds | |||
| def _random_pick_mutate(self, affine_trans_list, pixel_value_trans_list): | |||
| strage = choice(affine_trans_list + pixel_value_trans_list) | |||
| return strage | |||
| def _is_trans_valid(self, seed, mutate_test): | |||
| def _is_trans_valid(self, seed, mutate_sample): | |||
| """ Check a mutated sample is valid. If the number of changed pixels in | |||
| a seed is less than pixels_change_rate*size(seed), this mutate is valid. | |||
| Else check the infinite norm of seed changes, if the value of the | |||
| infinite norm less than pixel_value_change_rate*255, this mutate is | |||
| valid too. Otherwise the opposite.""" | |||
| is_valid = False | |||
| pixels_change_rate = 0.02 | |||
| pixel_value_change_rate = 0.2 | |||
| diff = np.array(seed - mutate_test).flatten() | |||
| diff = np.array(seed - mutate_sample).flatten() | |||
| size = np.shape(diff)[0] | |||
| l0 = np.linalg.norm(diff, ord=0) | |||
| linf = np.linalg.norm(diff, ord=np.inf) | |||
| @@ -167,5 +320,4 @@ class Fuzzer: | |||
| else: | |||
| if linf < pixel_value_change_rate*255: | |||
| is_valid = True | |||
| return is_valid | |||
| @@ -88,7 +88,8 @@ def is_rgb(img): | |||
| Bool, True if input is RGB. | |||
| """ | |||
| if is_numpy(img): | |||
| if len(np.shape(img)) == 3: | |||
| img_shape = np.shape(img) | |||
| if len(np.shape(img)) == 3 and (img_shape[0] == 3 or img_shape[2] == 3): | |||
| return True | |||
| return False | |||
| raise TypeError('img should be Numpy array. Got {}'.format(type(img))) | |||
| @@ -127,6 +128,7 @@ class ImageTransform: | |||
| of the image is not normalized , it will be normalized between 0 to 1.""" | |||
| rgb = is_rgb(image) | |||
| chw = False | |||
| gray3dim = False | |||
| normalized = is_normalized(image) | |||
| if rgb: | |||
| chw = is_chw(image) | |||
| @@ -135,12 +137,16 @@ class ImageTransform: | |||
| else: | |||
| image = image | |||
| else: | |||
| image = image | |||
| if len(np.shape(image)) == 3: | |||
| gray3dim = True | |||
| image = image[0] | |||
| else: | |||
| image = image | |||
| if normalized: | |||
| image = np.uint8(image*255) | |||
| return rgb, chw, normalized, image | |||
| return rgb, chw, normalized, gray3dim, image | |||
| def _original_format(self, image, chw, normalized): | |||
| def _original_format(self, image, chw, normalized, gray3dim): | |||
| """ Return transformed image with original format. """ | |||
| if not is_numpy(image): | |||
| image = np.array(image) | |||
| @@ -148,6 +154,8 @@ class ImageTransform: | |||
| image = hwc_to_chw(image) | |||
| if normalized: | |||
| image = image / 255 | |||
| if gray3dim: | |||
| image = np.expand_dims(image, 0) | |||
| return image | |||
| def transform(self, image): | |||
| @@ -191,11 +199,12 @@ class Contrast(ImageTransform): | |||
| Returns: | |||
| numpy.ndarray, transformed image. | |||
| """ | |||
| _, chw, normalized, image = self._check(image) | |||
| _, chw, normalized, gray3dim, image = self._check(image) | |||
| image = to_pil(image) | |||
| img_contrast = ImageEnhance.Contrast(image) | |||
| trans_image = img_contrast.enhance(self.factor) | |||
| trans_image = self._original_format(trans_image, chw, normalized) | |||
| trans_image = self._original_format(trans_image, chw, normalized, | |||
| gray3dim) | |||
| return trans_image | |||
| @@ -237,11 +246,12 @@ class Brightness(ImageTransform): | |||
| Returns: | |||
| numpy.ndarray, transformed image. | |||
| """ | |||
| _, chw, normalized, image = self._check(image) | |||
| _, chw, normalized, gray3dim, image = self._check(image) | |||
| image = to_pil(image) | |||
| img_contrast = ImageEnhance.Brightness(image) | |||
| trans_image = img_contrast.enhance(self.factor) | |||
| trans_image = self._original_format(trans_image, chw, normalized) | |||
| trans_image = self._original_format(trans_image, chw, normalized, | |||
| gray3dim) | |||
| return trans_image | |||
| @@ -280,10 +290,11 @@ class Blur(ImageTransform): | |||
| Returns: | |||
| numpy.ndarray, transformed image. | |||
| """ | |||
| _, chw, normalized, image = self._check(image) | |||
| _, chw, normalized, gray3dim, image = self._check(image) | |||
| image = to_pil(image) | |||
| trans_image = image.filter(ImageFilter.GaussianBlur(radius=self.radius)) | |||
| trans_image = self._original_format(trans_image, chw, normalized) | |||
| trans_image = self._original_format(trans_image, chw, normalized, | |||
| gray3dim) | |||
| return trans_image | |||
| @@ -324,12 +335,13 @@ class Noise(ImageTransform): | |||
| Returns: | |||
| numpy.ndarray, transformed image. | |||
| """ | |||
| _, chw, normalized, image = self._check(image) | |||
| _, chw, normalized, gray3dim, image = self._check(image) | |||
| noise = np.random.uniform(low=-1, high=1, size=np.shape(image)) | |||
| trans_image = np.copy(image) | |||
| trans_image[noise < -self.factor] = 0 | |||
| trans_image[noise > self.factor] = 1 | |||
| trans_image = self._original_format(trans_image, chw, normalized) | |||
| trans_image = self._original_format(trans_image, chw, normalized, | |||
| gray3dim) | |||
| return trans_image | |||
| @@ -375,7 +387,7 @@ class Translate(ImageTransform): | |||
| Returns: | |||
| numpy.ndarray, transformed image. | |||
| """ | |||
| _, chw, normalized, image = self._check(image) | |||
| _, chw, normalized, gray3dim, image = self._check(image) | |||
| img = to_pil(image) | |||
| if self.auto_param: | |||
| image_shape = np.shape(image) | |||
| @@ -383,7 +395,8 @@ class Translate(ImageTransform): | |||
| self.y_bias = image_shape[1]*self.y_bias | |||
| trans_image = img.transform(img.size, Image.AFFINE, | |||
| (1, 0, self.x_bias, 0, 1, self.y_bias)) | |||
| trans_image = self._original_format(trans_image, chw, normalized) | |||
| trans_image = self._original_format(trans_image, chw, normalized, | |||
| gray3dim) | |||
| return trans_image | |||
| @@ -431,7 +444,7 @@ class Scale(ImageTransform): | |||
| Returns: | |||
| numpy.ndarray, transformed image. | |||
| """ | |||
| rgb, chw, normalized, image = self._check(image) | |||
| rgb, chw, normalized, gray3dim, image = self._check(image) | |||
| if rgb: | |||
| h, w, _ = np.shape(image) | |||
| else: | |||
| @@ -442,7 +455,8 @@ class Scale(ImageTransform): | |||
| trans_image = img.transform(img.size, Image.AFFINE, | |||
| (self.factor_x, 0, move_x_centor, | |||
| 0, self.factor_y, move_y_centor)) | |||
| trans_image = self._original_format(trans_image, chw, normalized) | |||
| trans_image = self._original_format(trans_image, chw, normalized, | |||
| gray3dim) | |||
| return trans_image | |||
| @@ -500,7 +514,7 @@ class Shear(ImageTransform): | |||
| Returns: | |||
| numpy.ndarray, transformed image. | |||
| """ | |||
| rgb, chw, normalized, image = self._check(image) | |||
| rgb, chw, normalized, gray3dim, image = self._check(image) | |||
| img = to_pil(image) | |||
| if rgb: | |||
| h, w, _ = np.shape(image) | |||
| @@ -523,7 +537,8 @@ class Shear(ImageTransform): | |||
| trans_image = img.transform(img.size, Image.AFFINE, | |||
| (scale, scale*self.factor_x, move_x_cen, | |||
| scale*self.factor_y, scale, move_y_cen)) | |||
| trans_image = self._original_format(trans_image, chw, normalized) | |||
| trans_image = self._original_format(trans_image, chw, normalized, | |||
| gray3dim) | |||
| return trans_image | |||
| @@ -562,8 +577,9 @@ class Rotate(ImageTransform): | |||
| Returns: | |||
| numpy.ndarray, transformed image. | |||
| """ | |||
| _, chw, normalized, image = self._check(image) | |||
| _, chw, normalized, gray3dim, image = self._check(image) | |||
| img = to_pil(image) | |||
| trans_image = img.rotate(self.angle, expand=True) | |||
| trans_image = self._original_format(trans_image, chw, normalized) | |||
| trans_image = self._original_format(trans_image, chw, normalized, | |||
| gray3dim) | |||
| return trans_image | |||
| @@ -0,0 +1,172 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """ | |||
| Model-fuzz coverage test. | |||
| """ | |||
| import numpy as np | |||
| import pytest | |||
| from mindspore import context | |||
| from mindspore import nn | |||
| from mindspore.common.initializer import TruncatedNormal | |||
| from mindspore.ops import operations as P | |||
| from mindspore.train import Model | |||
| from mindarmour.fuzzing.fuzzing import Fuzzer | |||
| from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics | |||
| from mindarmour.utils.logger import LogUtil | |||
| LOGGER = LogUtil.get_instance() | |||
| TAG = 'Fuzzing test' | |||
| LOGGER.set_level('INFO') | |||
| def conv(in_channels, out_channels, kernel_size, stride=1, padding=0): | |||
| weight = weight_variable() | |||
| return nn.Conv2d(in_channels, out_channels, | |||
| kernel_size=kernel_size, stride=stride, padding=padding, | |||
| weight_init=weight, has_bias=False, pad_mode="valid") | |||
| def fc_with_initialize(input_channels, out_channels): | |||
| weight = weight_variable() | |||
| bias = weight_variable() | |||
| return nn.Dense(input_channels, out_channels, weight, bias) | |||
| def weight_variable(): | |||
| return TruncatedNormal(0.02) | |||
| class Net(nn.Cell): | |||
| """ | |||
| Lenet network | |||
| """ | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.conv1 = conv(1, 6, 5) | |||
| self.conv2 = conv(6, 16, 5) | |||
| self.fc1 = fc_with_initialize(16*5*5, 120) | |||
| self.fc2 = fc_with_initialize(120, 84) | |||
| self.fc3 = fc_with_initialize(84, 10) | |||
| self.relu = nn.ReLU() | |||
| self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) | |||
| self.reshape = P.Reshape() | |||
| def construct(self, x): | |||
| x = self.conv1(x) | |||
| x = self.relu(x) | |||
| x = self.max_pool2d(x) | |||
| x = self.conv2(x) | |||
| x = self.relu(x) | |||
| x = self.max_pool2d(x) | |||
| x = self.reshape(x, (-1, 16*5*5)) | |||
| x = self.fc1(x) | |||
| x = self.relu(x) | |||
| x = self.fc2(x) | |||
| x = self.relu(x) | |||
| x = self.fc3(x) | |||
| return x | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.env_onecard | |||
| @pytest.mark.component_mindarmour | |||
| def test_fuzzing_ascend(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
| # load network | |||
| net = Net() | |||
| model = Model(net) | |||
| batch_size = 8 | |||
| num_classe = 10 | |||
| mutate_config = [{'method': 'Blur', | |||
| 'params': {'auto_param': True}}, | |||
| {'method': 'Contrast', | |||
| 'params': {'factor': 2}}, | |||
| {'method': 'Translate', | |||
| 'params': {'x_bias': 0.1, 'y_bias': 0.2}}, | |||
| {'method': 'FGSM', | |||
| 'params': {'eps': 0.1, 'alpha': 0.1}} | |||
| ] | |||
| # initialize fuzz test with training dataset | |||
| train_images = np.random.rand(32, 1, 32, 32).astype(np.float32) | |||
| model_coverage_test = ModelCoverageMetrics(model, 1000, 10, train_images) | |||
| # fuzz test with original test data | |||
| # get test data | |||
| test_images = np.random.rand(batch_size, 1, 32, 32).astype(np.float32) | |||
| test_labels = np.random.randint(num_classe, size=batch_size).astype(np.int32) | |||
| test_labels = (np.eye(num_classe)[test_labels]).astype(np.float32) | |||
| initial_seeds = [] | |||
| # make initial seeds | |||
| for img, label in zip(test_images, test_labels): | |||
| initial_seeds.append([img, label, 0]) | |||
| initial_seeds = initial_seeds[:100] | |||
| model_coverage_test.calculate_coverage( | |||
| np.array(test_images[:100]).astype(np.float32)) | |||
| LOGGER.info(TAG, 'KMNC of this test is : %s', | |||
| model_coverage_test.get_kmnc()) | |||
| model_fuzz_test = Fuzzer(model, train_images, 1000, 10) | |||
| _, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds) | |||
| print(metrics) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| @pytest.mark.component_mindarmour | |||
| def test_fuzzing_cpu(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||
| # load network | |||
| net = Net() | |||
| model = Model(net) | |||
| batch_size = 8 | |||
| num_classe = 10 | |||
| mutate_config = [{'method': 'Blur', | |||
| 'params': {'auto_param': True}}, | |||
| {'method': 'Contrast', | |||
| 'params': {'factor': 2}}, | |||
| {'method': 'Translate', | |||
| 'params': {'x_bias': 0.1, 'y_bias': 0.2}}, | |||
| {'method': 'FGSM', | |||
| 'params': {'eps': 0.1, 'alpha': 0.1}} | |||
| ] | |||
| # initialize fuzz test with training dataset | |||
| train_images = np.random.rand(32, 1, 32, 32).astype(np.float32) | |||
| model_coverage_test = ModelCoverageMetrics(model, 1000, 10, train_images) | |||
| # fuzz test with original test data | |||
| # get test data | |||
| test_images = np.random.rand(batch_size, 1, 32, 32).astype(np.float32) | |||
| test_labels = np.random.randint(num_classe, size=batch_size).astype(np.int32) | |||
| test_labels = (np.eye(num_classe)[test_labels]).astype(np.float32) | |||
| initial_seeds = [] | |||
| # make initial seeds | |||
| for img, label in zip(test_images, test_labels): | |||
| initial_seeds.append([img, label, 0]) | |||
| initial_seeds = initial_seeds[:100] | |||
| model_coverage_test.calculate_coverage( | |||
| np.array(test_images[:100]).astype(np.float32)) | |||
| LOGGER.info(TAG, 'KMNC of this test is : %s', | |||
| model_coverage_test.get_kmnc()) | |||
| model_fuzz_test = Fuzzer(model, train_images, 1000, 10) | |||
| _, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds) | |||
| print(metrics) | |||