@@ -19,7 +19,7 @@ from mindspore import context
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from lenet5_net import LeNet5
-from mindarmour.fuzzing.fuzzing import Fuzzing
+from mindarmour.fuzzing.fuzzing import Fuzzer
 from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics
 from mindarmour.utils.logger import LogUtil

@@ -38,11 +38,20 @@ def test_lenet_mnist_fuzzing():
     load_dict = load_checkpoint(ckpt_name)
     load_param_into_net(net, load_dict)
     model = Model(net)
+    mutate_config = [{'method': 'Blur',
+                      'params': {'auto_param': True}},
+                     {'method': 'Contrast',
+                      'params': {'factor': 2}},
+                     {'method': 'Translate',
+                      'params': {'x_bias': 0.1, 'y_bias': 0.2}},
+                     {'method': 'FGSM',
+                      'params': {'eps': 0.1, 'alpha': 0.1}}
+                     ]
     # get training data
     data_list = "./MNIST_unzip/train"
     batch_size = 32
-    ds = generate_mnist_dataset(data_list, batch_size, sparse=True)
+    ds = generate_mnist_dataset(data_list, batch_size, sparse=False)
     train_images = []
     for data in ds.create_tuple_iterator():
         images = data[0].astype(np.float32)

@@ -56,7 +65,7 @@ def test_lenet_mnist_fuzzing():
     # get test data
     data_list = "./MNIST_unzip/test"
     batch_size = 32
-    ds = generate_mnist_dataset(data_list, batch_size, sparse=True)
+    ds = generate_mnist_dataset(data_list, batch_size, sparse=False)
     test_images = []
     test_labels = []
     for data in ds.create_tuple_iterator():

@@ -70,19 +79,20 @@ def test_lenet_mnist_fuzzing():
     # make initial seeds
     for img, label in zip(test_images, test_labels):
-        initial_seeds.append([img, label])
+        initial_seeds.append([img, label, 0])
     initial_seeds = initial_seeds[:100]
-    model_coverage_test.test_adequacy_coverage_calculate(np.array(test_images[:100]).astype(np.float32))
-    LOGGER.info(TAG, 'KMNC of this test is : %s', model_coverage_test.get_kmnc())
+    model_coverage_test.calculate_coverage(
+        np.array(test_images[:100]).astype(np.float32))
+    LOGGER.info(TAG, 'KMNC of this test is : %s',
+                model_coverage_test.get_kmnc())

-    model_fuzz_test = Fuzzing(initial_seeds, model, train_images, 20)
-    failed_tests = model_fuzz_test.fuzzing()
-    if failed_tests:
-        model_coverage_test.test_adequacy_coverage_calculate(np.array(failed_tests).astype(np.float32))
-        LOGGER.info(TAG, 'KMNC of this test is : %s', model_coverage_test.get_kmnc())
-    else:
-        LOGGER.info(TAG, 'Fuzzing test identifies none failed test')
+    model_fuzz_test = Fuzzer(model, train_images, 1000, 10)
+    _, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds,
+                                                  eval_metric=True)
+    if metrics:
+        for key in metrics:
+            LOGGER.info(TAG, key + ': %s', metrics[key])

 if __name__ == '__main__':
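For readers of the updated example: `metrics` is the dict built by the new `Fuzzer.evaluate` method further down in this patch. A small sketch of how the loop consumes it (the key names come from that method; the numeric values are placeholders, not measured results):

```python
# Keys produced by Fuzzer.evaluate when eval_metric=True; all values are
# rates in [0, 1]. The numbers below are placeholders, not real results.
metrics = {'Accuracy': 0.9, 'Attack_success_rate': 0.3,
           'Neural_coverage_KMNC': 0.5, 'Neural_coverage_NBC': 0.4,
           'Neural_coverage_SNAC': 0.4}
for key in metrics:
    print(key, ':', metrics[key])
```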

@@ -227,8 +227,8 @@ class BasicIterativeMethod(IterativeGradientMethod):
             clip_min, clip_max = self._bounds
             clip_diff = clip_max - clip_min
             for _ in range(self._nb_iter):
-                if 'self.prob' in globals():
-                    d_inputs = _transform_inputs(inputs, self.prob)
+                if 'self._prob' in globals():
+                    d_inputs = _transform_inputs(inputs, self._prob)
                 else:
                     d_inputs = inputs
                 adv_x = self._attack.generate(d_inputs, labels)

@@ -238,8 +238,8 @@ class BasicIterativeMethod(IterativeGradientMethod):
                 inputs = adv_x
         else:
             for _ in range(self._nb_iter):
-                if 'self.prob' in globals():
-                    d_inputs = _transform_inputs(inputs, self.prob)
+                if 'self._prob' in globals():
+                    d_inputs = _transform_inputs(inputs, self._prob)
                 else:
                     d_inputs = inputs
                 adv_x = self._attack.generate(d_inputs, labels)

@@ -311,8 +311,8 @@ class MomentumIterativeMethod(IterativeGradientMethod):
             clip_min, clip_max = self._bounds
             clip_diff = clip_max - clip_min
             for _ in range(self._nb_iter):
-                if 'self.prob' in globals():
-                    d_inputs = _transform_inputs(inputs, self.prob)
+                if 'self._prob' in globals():
+                    d_inputs = _transform_inputs(inputs, self._prob)
                 else:
                     d_inputs = inputs
                 gradient = self._gradient(d_inputs, labels)

@@ -325,8 +325,8 @@ class MomentumIterativeMethod(IterativeGradientMethod):
                 inputs = adv_x
         else:
             for _ in range(self._nb_iter):
-                if 'self.prob' in globals():
-                    d_inputs = _transform_inputs(inputs, self.prob)
+                if 'self._prob' in globals():
+                    d_inputs = _transform_inputs(inputs, self._prob)
                 else:
                     d_inputs = inputs
                 gradient = self._gradient(d_inputs, labels)

@@ -476,7 +476,7 @@ class DiverseInputIterativeMethod(BasicIterativeMethod):
                                                    is_targeted=is_targeted,
                                                    nb_iter=nb_iter,
                                                    loss_fn=loss_fn)
-        self.prob = check_param_type('prob', prob, float)
+        self._prob = check_param_type('prob', prob, float)

 class MomentumDiverseInputIterativeMethod(MomentumIterativeMethod):

@@ -511,7 +511,7 @@ class MomentumDiverseInputIterativeMethod(MomentumIterativeMethod):
                                                            is_targeted=is_targeted,
                                                            norm_level=norm_level,
                                                            loss_fn=loss_fn)
-        self.prob = check_param_type('prob', prob, float)
+        self._prob = check_param_type('prob', prob, float)

 def _transform_inputs(inputs, prob, low=29, high=33, full_aug=False):
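For context on the `_transform_inputs` helper referenced above (its body is unchanged and not shown in this patch), the diverse-input methods transform each batch with some probability before taking a gradient step, in the spirit of DI-FGSM's random resize-and-pad. A minimal NumPy sketch of that idea, under our own assumptions about the resize and pad scheme (this is an illustration, not a transcription of the MindArmour implementation):

```python
import numpy as np

def diverse_input_sketch(image, prob, low=29, high=33, rng=np.random):
    """Illustration only: with probability `prob`, nearest-neighbour resize an
    HxW image to a random size in [low, high) and zero-pad it back to HxW."""
    if rng.rand() >= prob:
        return image
    h, w = image.shape
    new_size = rng.randint(low, min(high, min(h, w) + 1))
    rows = (np.arange(new_size) * h / new_size).astype(int)
    cols = (np.arange(new_size) * w / new_size).astype(int)
    resized = image[np.ix_(rows, cols)]
    padded = np.zeros_like(image)
    top = rng.randint(0, h - new_size + 1)
    left = rng.randint(0, w - new_size + 1)
    padded[top:top + new_size, left:left + new_size] = resized
    return padded
```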

@@ -22,9 +22,11 @@ from mindspore import Tensor
 from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics
 from mindarmour.utils._check_param import check_model, check_numpy_param, \
-    check_int_positive
-from mindarmour.fuzzing.image_transform import Contrast, Brightness, Blur, Noise, \
-    Translate, Scale, Shear, Rotate
+    check_param_multi_types, check_norm_level, check_param_in_range
+from mindarmour.fuzzing.image_transform import Contrast, Brightness, Blur, \
+    Noise, Translate, Scale, Shear, Rotate
+from mindarmour.attacks import FastGradientSignMethod, \
+    MomentumDiverseInputIterativeMethod, ProjectedGradientDescent

 class Fuzzer:

@@ -35,129 +37,280 @@ class Fuzzer:
     Neural Networks <https://dl.acm.org/doi/10.1145/3293882.3330579>`_

     Args:
-        initial_seeds (list): Initial fuzzing seed, format: [[image, label],
-            [image, label], ...].
         target_model (Model): Target fuzz model.
         train_dataset (numpy.ndarray): Training dataset used for determining
             the neurons' output boundaries.
-        const_k (int): The number of mutate tests for a seed.
-        mode (str): Image mode used in image transform, 'L' means grey graph.
-            Default: 'L'.
-        max_seed_num (int): The initial seeds max value. Default: 1000
+        segmented_num (int): The number of segmented sections of neurons'
+            output intervals.
+        neuron_num (int): The number of testing neurons.
     """

-    def __init__(self, initial_seeds, target_model, train_dataset, const_K,
-                 mode='L', max_seed_num=1000):
-        self.initial_seeds = initial_seeds
+    def __init__(self, target_model, train_dataset, segmented_num, neuron_num):
         self.target_model = check_model('model', target_model, Model)
         self.train_dataset = check_numpy_param('train_dataset', train_dataset)
-        self.const_k = check_int_positive('const_k', const_K)
-        self.mode = mode
-        self.max_seed_num = check_int_positive('max_seed_num', max_seed_num)
-        self.coverage_metrics = ModelCoverageMetrics(target_model, 1000, 10,
-                                                     train_dataset)
-
-    def _image_value_expand(self, image):
-        return image*255
-
-    def _image_value_compress(self, image):
-        return image / 255
-
-    def _metamorphic_mutate(self, seed, try_num=50):
-        if self.mode == 'L':
-            seed = seed[0]
-        info = [seed, seed]
-        mutate_tests = []
-        pixel_value_trans = ['Contrast', 'Brightness', 'Blur', 'Noise']
-        affine_trans = ['Translate', 'Scale', 'Shear', 'Rotate']
-        strages = {'Contrast': Contrast, 'Brightness': Brightness, 'Blur': Blur,
-                   'Noise': Noise,
-                   'Translate': Translate, 'Scale': Scale, 'Shear': Shear,
-                   'Rotate': Rotate}
-        for _ in range(self.const_k):
-            for _ in range(try_num):
-                if (info[0] == info[1]).all():
-                    trans_strage = self._random_pick_mutate(affine_trans,
-                                                            pixel_value_trans)
-                else:
-                    trans_strage = self._random_pick_mutate(pixel_value_trans,
-                                                            [])
-                transform = strages[trans_strage](
-                    self._image_value_expand(seed), self.mode)
-                transform.set_params(auto_param=True)
-                mutate_test = transform.transform()
-                mutate_test = np.expand_dims(
-                    self._image_value_compress(mutate_test), 0)
-                if self._is_trans_valid(seed, mutate_test):
-                    if trans_strage in affine_trans:
-                        info[1] = mutate_test
-                    mutate_tests.append(mutate_test)
-        if not mutate_tests:
-            mutate_tests.append(seed)
-        return np.array(mutate_tests)
-
-    def fuzzing(self, coverage_metric='KMNC'):
+        self.coverage_metrics = ModelCoverageMetrics(target_model,
+                                                     segmented_num,
+                                                     neuron_num, train_dataset)
+        # Allowed mutate strategies so far.
+        self.strategies = {'Contrast': Contrast, 'Brightness': Brightness,
+                           'Blur': Blur, 'Noise': Noise, 'Translate': Translate,
+                           'Scale': Scale, 'Shear': Shear, 'Rotate': Rotate,
+                           'FGSM': FastGradientSignMethod,
+                           'PGD': ProjectedGradientDescent,
+                           'MDIIM': MomentumDiverseInputIterativeMethod}
+        self.affine_trans_list = ['Translate', 'Scale', 'Shear', 'Rotate']
+        self.pixel_value_trans_list = ['Contrast', 'Brightness', 'Blur',
+                                       'Noise']
+        self.attacks_list = ['FGSM', 'PGD', 'MDIIM']
+        self.attack_param_checklists = {
+            'FGSM': {'params': {'eps': {'dtype': [float, int], 'range': [0, 1]},
+                                'alpha': {'dtype': [float, int],
+                                          'range': [0, 1]},
+                                'bounds': {'dtype': [list, tuple],
+                                           'range': None},
+                                }},
+            'PGD': {'params': {'eps': {'dtype': [float, int], 'range': [0, 1]},
+                               'eps_iter': {'dtype': [float, int],
+                                            'range': [0, 1e5]},
+                               'nb_iter': {'dtype': [float, int],
+                                           'range': [0, 1e5]},
+                               'bounds': {'dtype': [list, tuple],
+                                          'range': None},
+                               }},
+            'MDIIM': {
+                'params': {'eps': {'dtype': [float, int], 'range': [0, 1]},
+                           'norm_level': {'dtype': [str], 'range': None},
+                           'prob': {'dtype': [float, int], 'range': [0, 1]},
+                           'bounds': {'dtype': [list, tuple], 'range': None},
+                           }}}
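The checklist above is what `_check_attack_params` (next) validates against. For illustration, attack entries in a `mutate_config` that stay inside the declared dtype and range constraints might look like the following (the values are arbitrary in-range picks, not recommended settings; a real config must also contain at least one pixel-value transform such as Contrast, which `_init_mutates` enforces):

```python
# Arbitrary but in-range attack entries; combine them with at least one
# pixel-value transform (e.g. Contrast) to satisfy _init_mutates.
attack_entries = [
    {'method': 'FGSM',
     'params': {'eps': 0.3, 'alpha': 0.1, 'bounds': (0.0, 1.0)}},
    {'method': 'PGD',
     'params': {'eps': 0.3, 'eps_iter': 0.03, 'nb_iter': 40,
                'bounds': (0.0, 1.0)}},
    {'method': 'MDIIM',
     'params': {'eps': 0.3, 'norm_level': 'l2', 'prob': 0.5,
                'bounds': (0.0, 1.0)}},
]
```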

+    def _check_attack_params(self, method, params):
+        """Check input parameters of attack methods."""
+        allow_params = self.attack_param_checklists[method]['params'].keys()
+        for p in params:
+            if p not in allow_params:
+                msg = "parameters of {} must be in {}".format(method, allow_params)
+                raise ValueError(msg)
+            if p == 'bounds':
+                bounds = check_param_multi_types('bounds', params[p],
+                                                 [list, tuple])
+                for b in bounds:
+                    _ = check_param_multi_types('bound', b, [int, float])
+            elif p == 'norm_level':
+                _ = check_norm_level(params[p])
+            else:
+                allow_type = self.attack_param_checklists[method]['params'][p][
+                    'dtype']
+                allow_range = self.attack_param_checklists[method]['params'][p][
+                    'range']
+                _ = check_param_multi_types(str(p), params[p], allow_type)
+                _ = check_param_in_range(str(p), params[p], allow_range[0],
+                                         allow_range[1])
+
+    def _metamorphic_mutate(self, seed, mutates, mutate_config,
+                            mutate_num_per_seed):
+        """Mutate a seed using strategies randomly selected from mutate_config."""
+        mutate_samples = []
+        mutate_strategies = []
+        only_pixel_trans = seed[2]
+        for _ in range(mutate_num_per_seed):
+            strategy = choice(mutate_config)
+            # If the seed has already been affine transformed, only pixel
+            # value transforms may be applied to it.
+            if only_pixel_trans:
+                while strategy['method'] not in self.pixel_value_trans_list:
+                    strategy = choice(mutate_config)
+            transform = mutates[strategy['method']]
+            params = strategy['params']
+            method = strategy['method']
+            if method in list(self.pixel_value_trans_list + self.affine_trans_list):
+                transform.set_params(**params)
+                mutate_sample = transform.transform(seed[0])
+            else:
+                for p in params:
+                    transform.__setattr__('_' + str(p), params[p])
+                mutate_sample = transform.generate([seed[0].astype(np.float32)],
+                                                   [seed[1]])[0]
+            if method not in self.pixel_value_trans_list:
+                only_pixel_trans = 1
+            mutate_sample = [mutate_sample, seed[1], only_pixel_trans]
+            if self._is_trans_valid(seed[0], mutate_sample[0]):
+                mutate_samples.append(mutate_sample)
+                mutate_strategies.append(method)
+        if not mutate_samples:
+            mutate_samples.append(seed)
+            mutate_strategies.append(None)
+        return np.array(mutate_samples), mutate_strategies
+
+    def _init_mutates(self, mutate_config):
+        """Check whether the mutate_config meets the specification."""
+        has_pixel_trans = False
+        for mutate in mutate_config:
+            if mutate['method'] in self.pixel_value_trans_list:
+                has_pixel_trans = True
+                break
+        if not has_pixel_trans:
+            msg = "mutate methods in mutate_config must include at least " \
+                  "one of {}".format(self.pixel_value_trans_list)
+            raise ValueError(msg)
+        mutates = {}
+        for mutate in mutate_config:
+            method = mutate['method']
+            params = mutate['params']
+            if method not in self.attacks_list:
+                mutates[method] = self.strategies[method]()
+            else:
+                self._check_attack_params(method, params)
+                network = self.target_model._network
+                loss_fn = self.target_model._loss_fn
+                mutates[method] = self.strategies[method](network,
+                                                          loss_fn=loss_fn)
+        return mutates
+
+    def evaluate(self, fuzz_samples, gt_labels, fuzz_preds,
+                 fuzz_strategies):
+        """
+        Evaluate generated fuzzing samples in three dimensions: accuracy,
+        attack success rate and neural coverage.
+
+        Args:
+            fuzz_samples (numpy.ndarray): Generated fuzzing samples according to seeds.
+            gt_labels (numpy.ndarray): Ground truth labels of seeds.
+            fuzz_preds (numpy.ndarray): Predictions of generated fuzz samples.
+            fuzz_strategies (numpy.ndarray): Mutate strategies of fuzz samples.
+
+        Returns:
+            dict, evaluation metrics including accuracy, attack success rate
+                and neural coverage.
+        """
+        gt_labels = np.asarray(gt_labels)
+        fuzz_preds = np.asarray(fuzz_preds)
+        temp = np.argmax(gt_labels, axis=1) == np.argmax(fuzz_preds, axis=1)
+        acc = np.sum(temp) / np.size(temp)
+        cond = [elem in self.attacks_list for elem in fuzz_strategies]
+        temp = temp[cond]
+        attack_success_rate = 1 - np.sum(temp) / np.size(temp)
+        self.coverage_metrics.calculate_coverage(
+            np.array(fuzz_samples).astype(np.float32))
+        kmnc = self.coverage_metrics.get_kmnc()
+        nbc = self.coverage_metrics.get_nbc()
+        snac = self.coverage_metrics.get_snac()
+        metrics = {}
+        metrics['Accuracy'] = acc
+        metrics['Attack_success_rate'] = attack_success_rate
+        metrics['Neural_coverage_KMNC'] = kmnc
+        metrics['Neural_coverage_NBC'] = nbc
+        metrics['Neural_coverage_SNAC'] = snac
+        return metrics
+
+    def fuzzing(self, mutate_config, initial_seeds, coverage_metric='KMNC',
+                eval_metric=True, max_iters=10000, mutate_num_per_seed=20):
         """
         Fuzzing tests for deep neural networks.

         Args:
+            mutate_config (list): Mutate configs. The format is
+                [{'method': 'Blur',
+                  'params': {'auto_param': True}},
+                 {'method': 'Contrast',
+                  'params': {'factor': 2}},
+                 ...]. The supported methods are listed in `self.strategies`,
+                and the params of each method must be within the range of
+                changeable parameters.
+            initial_seeds (list): Initial seeds used to generate mutated
+                samples, in the format [[image, label, 0], ...].
             coverage_metric (str): Model coverage metric of neural networks.
                 Default: 'KMNC'.
+            eval_metric (bool): Whether to evaluate the generated fuzz samples.
+                Default: True.
+            max_iters (int): Maximum number of seed-selection iterations.
+                Default: 10000.
+            mutate_num_per_seed (int): The number of mutate times for a seed.
+                Default: 20.

         Returns:
-            list, mutated tests mis-predicted by target DNN model.
+            - list, mutated samples.
+            - list, ground truth labels of mutated samples.
+            - list, predictions of mutated samples.
+            - list, mutate strategies of mutated samples.
+            - dict, evaluation metrics of the fuzzer (None if eval_metric is False).
         """
-        seed = self._select_next()
-        failed_tests = []
-        seed_num = 0
-        while seed and seed_num < self.max_seed_num:
-            mutate_tests = self._metamorphic_mutate(seed[0])
-            coverages, predicts = self._run(mutate_tests, coverage_metric)
+        # Check whether the mutate_config meets the specification.
+        mutates = self._init_mutates(mutate_config)
+        seed, initial_seeds = self._select_next(initial_seeds)
+        fuzz_samples = []
+        gt_labels = []
+        fuzz_preds = []
+        fuzz_strategies = []
+        iter_num = 0
+        while initial_seeds and iter_num < max_iters:
+            # Mutate a seed.
+            mutate_samples, mutate_strategies = self._metamorphic_mutate(seed,
+                                                                         mutates,
+                                                                         mutate_config,
+                                                                         mutate_num_per_seed)
+            # Calculate the coverages and predictions of generated samples.
+            coverages, predicts = self._run(mutate_samples, coverage_metric)
             coverage_gains = self._coverage_gains(coverages)
-            for mutate, cov, res in zip(mutate_tests, coverage_gains, predicts):
-                if np.argmax(seed[1]) != np.argmax(res):
-                    failed_tests.append(mutate)
-                    continue
+            for mutate, cov, pred, strategy in zip(mutate_samples,
+                                                   coverage_gains,
+                                                   predicts, mutate_strategies):
+                fuzz_samples.append(mutate[0])
+                gt_labels.append(mutate[1])
+                fuzz_preds.append(pred)
+                fuzz_strategies.append(strategy)
+                # If the mutated sample gains coverage, add it to the initial
+                # seeds to guide new mutations.
                 if cov > 0:
-                    self.initial_seeds.append([mutate, seed[1]])
-            seed = self._select_next()
-            seed_num += 1
-        return failed_tests
+                    initial_seeds.append(mutate)
+            seed, initial_seeds = self._select_next(initial_seeds)
+            iter_num += 1
+        metrics = None
+        if eval_metric:
+            metrics = self.evaluate(fuzz_samples, gt_labels, fuzz_preds,
+                                    fuzz_strategies)
+        return fuzz_samples, gt_labels, fuzz_preds, fuzz_strategies, metrics

     def _coverage_gains(self, coverages):
+        """Calculate the coverage gains of mutated samples."""
         gains = [0] + coverages[:-1]
         gains = np.array(coverages) - np.array(gains)
         return gains

-    def _run(self, mutate_tests, coverage_metric="KNMC"):
+    def _run(self, mutate_samples, coverage_metric="KMNC"):
+        """Calculate the coverages and predictions of generated samples."""
+        samples = [s[0] for s in mutate_samples]
+        samples = np.array(samples)
         coverages = []
-        result = self.target_model.predict(
-            Tensor(mutate_tests.astype(np.float32)))
-        result = result.asnumpy()
-        for index in range(len(mutate_tests)):
-            mutate = np.expand_dims(mutate_tests[index], 0)
-            self.coverage_metrics.model_coverage_test(
-                mutate.astype(np.float32), batch_size=1)
+        predictions = self.target_model.predict(Tensor(samples.astype(np.float32)))
+        predictions = predictions.asnumpy()
+        for index in range(len(samples)):
+            mutate = samples[:index + 1]
+            self.coverage_metrics.calculate_coverage(mutate.astype(np.float32))
             if coverage_metric == "KMNC":
                 coverages.append(self.coverage_metrics.get_kmnc())
+            if coverage_metric == 'NBC':
+                coverages.append(self.coverage_metrics.get_nbc())
+            if coverage_metric == 'SNAC':
+                coverages.append(self.coverage_metrics.get_snac())
-        return coverages, result
+        return coverages, predictions

-    def _select_next(self):
-        seed = choice(self.initial_seeds)
-        return seed
+    def _select_next(self, initial_seeds):
+        """Randomly select a seed from `initial_seeds`."""
+        seed_num = choice(range(len(initial_seeds)))
+        seed = initial_seeds[seed_num]
+        del initial_seeds[seed_num]
+        return seed, initial_seeds

-    def _random_pick_mutate(self, affine_trans_list, pixel_value_trans_list):
-        strage = choice(affine_trans_list + pixel_value_trans_list)
-        return strage

-    def _is_trans_valid(self, seed, mutate_test):
+    def _is_trans_valid(self, seed, mutate_sample):
+        """Check whether a mutated sample is valid. If the number of changed
+        pixels in a seed is less than pixels_change_rate*size(seed), the
+        mutation is valid. Otherwise, if the infinite norm of the changes is
+        less than pixel_value_change_rate*255, the mutation is also valid.
+        Otherwise it is invalid."""
         is_valid = False
         pixels_change_rate = 0.02
         pixel_value_change_rate = 0.2
-        diff = np.array(seed - mutate_test).flatten()
+        diff = np.array(seed - mutate_sample).flatten()
         size = np.shape(diff)[0]
         l0 = np.linalg.norm(diff, ord=0)
         linf = np.linalg.norm(diff, ord=np.inf)

@@ -167,5 +320,4 @@ class Fuzzer:
         else:
             if linf < pixel_value_change_rate*255:
                 is_valid = True
         return is_valid
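Putting the refactor together, the new Fuzzer is driven roughly as below; a self-contained sketch with random stand-in data and a toy dense network (the patch's own smoke test, added further down, does the same thing with a LeNet):

```python
import numpy as np
from mindspore import nn
from mindspore.train import Model
from mindarmour.fuzzing.fuzzing import Fuzzer

# Toy stand-ins: any 10-class network and float32 NCHW data will do here.
net = nn.SequentialCell([nn.Flatten(), nn.Dense(32 * 32, 10)])
model = Model(net)
train_images = np.random.rand(32, 1, 32, 32).astype(np.float32)
test_images = np.random.rand(8, 1, 32, 32).astype(np.float32)
labels = np.eye(10)[np.random.randint(10, size=8)].astype(np.float32)

# Seed format is [image, one-hot label, 0]; the trailing flag records whether
# the seed has already been through a non-pixel-value transform.
initial_seeds = [[img, label, 0] for img, label in zip(test_images, labels)]

mutate_config = [{'method': 'Blur', 'params': {'auto_param': True}},
                 {'method': 'Contrast', 'params': {'factor': 2}},
                 {'method': 'FGSM', 'params': {'eps': 0.1, 'alpha': 0.1}}]

fuzzer = Fuzzer(model, train_images, segmented_num=1000, neuron_num=10)
samples, gt_labels, preds, strategies, metrics = fuzzer.fuzzing(
    mutate_config, initial_seeds, max_iters=100, mutate_num_per_seed=5)
print(metrics)
```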

@@ -88,7 +88,8 @@ def is_rgb(img):
         Bool, True if input is RGB.
     """
     if is_numpy(img):
-        if len(np.shape(img)) == 3:
+        img_shape = np.shape(img)
+        if len(np.shape(img)) == 3 and (img_shape[0] == 3 or img_shape[2] == 3):
             return True
         return False
     raise TypeError('img should be Numpy array. Got {}'.format(type(img)))

@@ -127,6 +128,7 @@ class ImageTransform:
        of the image is not normalized , it will be normalized between 0 to 1."""
         rgb = is_rgb(image)
         chw = False
+        gray3dim = False
         normalized = is_normalized(image)
         if rgb:
             chw = is_chw(image)

@@ -135,12 +137,16 @@ class ImageTransform:
             else:
                 image = image
         else:
-            image = image
+            if len(np.shape(image)) == 3:
+                gray3dim = True
+                image = image[0]
+            else:
+                image = image
         if normalized:
             image = np.uint8(image*255)
-        return rgb, chw, normalized, image
+        return rgb, chw, normalized, gray3dim, image

-    def _original_format(self, image, chw, normalized):
+    def _original_format(self, image, chw, normalized, gray3dim):
         """ Return transformed image with original format. """
         if not is_numpy(image):
             image = np.array(image)

@@ -148,6 +154,8 @@ class ImageTransform:
             image = hwc_to_chw(image)
         if normalized:
             image = image / 255
+        if gray3dim:
+            image = np.expand_dims(image, 0)
         return image

     def transform(self, image):

@@ -191,11 +199,12 @@ class Contrast(ImageTransform):
         Returns:
             numpy.ndarray, transformed image.
         """
-        _, chw, normalized, image = self._check(image)
+        _, chw, normalized, gray3dim, image = self._check(image)
         image = to_pil(image)
         img_contrast = ImageEnhance.Contrast(image)
         trans_image = img_contrast.enhance(self.factor)
-        trans_image = self._original_format(trans_image, chw, normalized)
+        trans_image = self._original_format(trans_image, chw, normalized,
+                                            gray3dim)
         return trans_image

@@ -237,11 +246,12 @@ class Brightness(ImageTransform):
         Returns:
             numpy.ndarray, transformed image.
         """
-        _, chw, normalized, image = self._check(image)
+        _, chw, normalized, gray3dim, image = self._check(image)
         image = to_pil(image)
         img_contrast = ImageEnhance.Brightness(image)
         trans_image = img_contrast.enhance(self.factor)
-        trans_image = self._original_format(trans_image, chw, normalized)
+        trans_image = self._original_format(trans_image, chw, normalized,
+                                            gray3dim)
         return trans_image

@@ -280,10 +290,11 @@ class Blur(ImageTransform):
         Returns:
             numpy.ndarray, transformed image.
         """
-        _, chw, normalized, image = self._check(image)
+        _, chw, normalized, gray3dim, image = self._check(image)
         image = to_pil(image)
         trans_image = image.filter(ImageFilter.GaussianBlur(radius=self.radius))
-        trans_image = self._original_format(trans_image, chw, normalized)
+        trans_image = self._original_format(trans_image, chw, normalized,
+                                            gray3dim)
         return trans_image

@@ -324,12 +335,13 @@ class Noise(ImageTransform):
         Returns:
             numpy.ndarray, transformed image.
         """
-        _, chw, normalized, image = self._check(image)
+        _, chw, normalized, gray3dim, image = self._check(image)
         noise = np.random.uniform(low=-1, high=1, size=np.shape(image))
         trans_image = np.copy(image)
         trans_image[noise < -self.factor] = 0
         trans_image[noise > self.factor] = 1
-        trans_image = self._original_format(trans_image, chw, normalized)
+        trans_image = self._original_format(trans_image, chw, normalized,
+                                            gray3dim)
         return trans_image

@@ -375,7 +387,7 @@ class Translate(ImageTransform):
         Returns:
             numpy.ndarray, transformed image.
         """
-        _, chw, normalized, image = self._check(image)
+        _, chw, normalized, gray3dim, image = self._check(image)
         img = to_pil(image)
         if self.auto_param:
             image_shape = np.shape(image)

@@ -383,7 +395,8 @@ class Translate(ImageTransform):
             self.y_bias = image_shape[1]*self.y_bias
         trans_image = img.transform(img.size, Image.AFFINE,
                                     (1, 0, self.x_bias, 0, 1, self.y_bias))
-        trans_image = self._original_format(trans_image, chw, normalized)
+        trans_image = self._original_format(trans_image, chw, normalized,
+                                            gray3dim)
         return trans_image

@@ -431,7 +444,7 @@ class Scale(ImageTransform):
         Returns:
             numpy.ndarray, transformed image.
         """
-        rgb, chw, normalized, image = self._check(image)
+        rgb, chw, normalized, gray3dim, image = self._check(image)
         if rgb:
             h, w, _ = np.shape(image)
         else:

@@ -442,7 +455,8 @@ class Scale(ImageTransform):
         trans_image = img.transform(img.size, Image.AFFINE,
                                     (self.factor_x, 0, move_x_centor,
                                      0, self.factor_y, move_y_centor))
-        trans_image = self._original_format(trans_image, chw, normalized)
+        trans_image = self._original_format(trans_image, chw, normalized,
+                                            gray3dim)
         return trans_image

@@ -500,7 +514,7 @@ class Shear(ImageTransform):
         Returns:
             numpy.ndarray, transformed image.
         """
-        rgb, chw, normalized, image = self._check(image)
+        rgb, chw, normalized, gray3dim, image = self._check(image)
         img = to_pil(image)
         if rgb:
             h, w, _ = np.shape(image)

@@ -523,7 +537,8 @@ class Shear(ImageTransform):
         trans_image = img.transform(img.size, Image.AFFINE,
                                     (scale, scale*self.factor_x, move_x_cen,
                                      scale*self.factor_y, scale, move_y_cen))
-        trans_image = self._original_format(trans_image, chw, normalized)
+        trans_image = self._original_format(trans_image, chw, normalized,
+                                            gray3dim)
         return trans_image

@@ -562,8 +577,9 @@ class Rotate(ImageTransform):
         Returns:
             numpy.ndarray, transformed image.
         """
-        _, chw, normalized, image = self._check(image)
+        _, chw, normalized, gray3dim, image = self._check(image)
         img = to_pil(image)
         trans_image = img.rotate(self.angle, expand=True)
-        trans_image = self._original_format(trans_image, chw, normalized)
+        trans_image = self._original_format(trans_image, chw, normalized,
+                                            gray3dim)
         return trans_image
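One consequence of the new `gray3dim` flag worth calling out: a single-channel CHW array now round-trips through the transforms with its leading axis and normalization preserved. A small sketch (the `set_params(factor=...)` / `transform(image)` calling convention is the one `Fuzzer._metamorphic_mutate` uses in this patch; the factor value is arbitrary):

```python
import numpy as np
from mindarmour.fuzzing.image_transform import Contrast

image = np.random.rand(1, 32, 32).astype(np.float32)  # CHW grayscale in [0, 1]
trans = Contrast()
trans.set_params(factor=1.5)
out = trans.transform(image)

assert out.shape == (1, 32, 32)   # gray3dim restores the leading channel axis
assert 0.0 <= out.min() <= out.max() <= 1.0  # normalized input stays normalized
```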

@@ -0,0 +1,172 @@
+# Copyright 2019 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Model-fuzz coverage test.
+"""
+import numpy as np
+import pytest
+
+from mindspore import context
+from mindspore import nn
+from mindspore.common.initializer import TruncatedNormal
+from mindspore.ops import operations as P
+from mindspore.train import Model
+
+from mindarmour.fuzzing.fuzzing import Fuzzer
+from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics
+from mindarmour.utils.logger import LogUtil
+
+LOGGER = LogUtil.get_instance()
+TAG = 'Fuzzing test'
+LOGGER.set_level('INFO')
+
+
+def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
+    weight = weight_variable()
+    return nn.Conv2d(in_channels, out_channels,
+                     kernel_size=kernel_size, stride=stride, padding=padding,
+                     weight_init=weight, has_bias=False, pad_mode="valid")
+
+
+def fc_with_initialize(input_channels, out_channels):
+    weight = weight_variable()
+    bias = weight_variable()
+    return nn.Dense(input_channels, out_channels, weight, bias)
+
+
+def weight_variable():
+    return TruncatedNormal(0.02)
+
+
+class Net(nn.Cell):
+    """
+    Lenet network
+    """
+
+    def __init__(self):
+        super(Net, self).__init__()
+        self.conv1 = conv(1, 6, 5)
+        self.conv2 = conv(6, 16, 5)
+        self.fc1 = fc_with_initialize(16*5*5, 120)
+        self.fc2 = fc_with_initialize(120, 84)
+        self.fc3 = fc_with_initialize(84, 10)
+        self.relu = nn.ReLU()
+        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
+        self.reshape = P.Reshape()
+
+    def construct(self, x):
+        x = self.conv1(x)
+        x = self.relu(x)
+        x = self.max_pool2d(x)
+        x = self.conv2(x)
+        x = self.relu(x)
+        x = self.max_pool2d(x)
+        x = self.reshape(x, (-1, 16*5*5))
+        x = self.fc1(x)
+        x = self.relu(x)
+        x = self.fc2(x)
+        x = self.relu(x)
+        x = self.fc3(x)
+        return x
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.env_onecard
+@pytest.mark.component_mindarmour
+def test_fuzzing_ascend():
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    # load network
+    net = Net()
+    model = Model(net)
+    batch_size = 8
+    num_classes = 10
+    mutate_config = [{'method': 'Blur',
+                      'params': {'auto_param': True}},
+                     {'method': 'Contrast',
+                      'params': {'factor': 2}},
+                     {'method': 'Translate',
+                      'params': {'x_bias': 0.1, 'y_bias': 0.2}},
+                     {'method': 'FGSM',
+                      'params': {'eps': 0.1, 'alpha': 0.1}}
+                     ]
+    # initialize fuzz test with training dataset
+    train_images = np.random.rand(32, 1, 32, 32).astype(np.float32)
+    model_coverage_test = ModelCoverageMetrics(model, 1000, 10, train_images)
+
+    # fuzz test with original test data
+    # get test data
+    test_images = np.random.rand(batch_size, 1, 32, 32).astype(np.float32)
+    test_labels = np.random.randint(num_classes, size=batch_size).astype(np.int32)
+    test_labels = (np.eye(num_classes)[test_labels]).astype(np.float32)
+
+    initial_seeds = []
+    # make initial seeds
+    for img, label in zip(test_images, test_labels):
+        initial_seeds.append([img, label, 0])
+
+    initial_seeds = initial_seeds[:100]
+    model_coverage_test.calculate_coverage(
+        np.array(test_images[:100]).astype(np.float32))
+    LOGGER.info(TAG, 'KMNC of this test is : %s',
+                model_coverage_test.get_kmnc())
+
+    model_fuzz_test = Fuzzer(model, train_images, 1000, 10)
+    _, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds)
+    print(metrics)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+@pytest.mark.component_mindarmour
+def test_fuzzing_cpu():
+    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
+    # load network
+    net = Net()
+    model = Model(net)
+    batch_size = 8
+    num_classes = 10
+    mutate_config = [{'method': 'Blur',
+                      'params': {'auto_param': True}},
+                     {'method': 'Contrast',
+                      'params': {'factor': 2}},
+                     {'method': 'Translate',
+                      'params': {'x_bias': 0.1, 'y_bias': 0.2}},
+                     {'method': 'FGSM',
+                      'params': {'eps': 0.1, 'alpha': 0.1}}
+                     ]
+    # initialize fuzz test with training dataset
+    train_images = np.random.rand(32, 1, 32, 32).astype(np.float32)
+    model_coverage_test = ModelCoverageMetrics(model, 1000, 10, train_images)
+
+    # fuzz test with original test data
+    # get test data
+    test_images = np.random.rand(batch_size, 1, 32, 32).astype(np.float32)
+    test_labels = np.random.randint(num_classes, size=batch_size).astype(np.int32)
+    test_labels = (np.eye(num_classes)[test_labels]).astype(np.float32)
+
+    initial_seeds = []
+    # make initial seeds
+    for img, label in zip(test_images, test_labels):
+        initial_seeds.append([img, label, 0])
+
+    initial_seeds = initial_seeds[:100]
+    model_coverage_test.calculate_coverage(
+        np.array(test_images[:100]).astype(np.float32))
+    LOGGER.info(TAG, 'KMNC of this test is : %s',
+                model_coverage_test.get_kmnc())
+
+    model_fuzz_test = Fuzzer(model, train_images, 1000, 10)
+    _, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds)
+    print(metrics)
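These smoke tests feed random data through the new interface, so they run quickly. A minimal way to invoke just the CPU variant (the file path is hypothetical and depends on where this test module lands in the repository):

```python
# Hypothetical path -- adjust to the actual location of this test module.
import pytest
pytest.main(['-v', '-k', 'fuzzing_cpu', 'tests/ut/python/fuzzing/test_fuzzer.py'])
```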