diff --git a/examples/ai_fuzzer/lenet5_mnist_fuzzing.py b/examples/ai_fuzzer/lenet5_mnist_fuzzing.py
index b7f2a77..3f92dea 100644
--- a/examples/ai_fuzzer/lenet5_mnist_fuzzing.py
+++ b/examples/ai_fuzzer/lenet5_mnist_fuzzing.py
@@ -63,9 +63,11 @@ def test_lenet_mnist_fuzzing():
         images = data[0].astype(np.float32)
         train_images.append(images)
     train_images = np.concatenate(train_images, axis=0)
+    neuron_num = 10
+    segmented_num = 1000
 
     # initialize fuzz test with training dataset
-    model_coverage_test = ModelCoverageMetrics(model, 10, 1000, train_images)
+    model_coverage_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images)
 
     # fuzz test with original test data
     # get test data
@@ -93,9 +95,8 @@ def test_lenet_mnist_fuzzing():
     LOGGER.info(TAG, 'KMNC of this test is : %s',
                 model_coverage_test.get_kmnc())
 
-    model_fuzz_test = Fuzzer(model, train_images, 10, 1000)
-    _, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds,
-                                                  eval_metrics='auto')
+    model_fuzz_test = Fuzzer(model, train_images, neuron_num, segmented_num)
+    _, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds, eval_metrics='auto')
     if metrics:
         for key in metrics:
             LOGGER.info(TAG, key + ': %s', metrics[key])
diff --git a/mindarmour/adv_robustness/attacks/black/natural_evolutionary_strategy.py b/mindarmour/adv_robustness/attacks/black/natural_evolutionary_strategy.py
index 8d3a18f..5dd3f3f 100644
--- a/mindarmour/adv_robustness/attacks/black/natural_evolutionary_strategy.py
+++ b/mindarmour/adv_robustness/attacks/black/natural_evolutionary_strategy.py
@@ -36,48 +36,45 @@ def _bound(image, epislon):
 class NES(Attack):
     """
-    The class is an implementation of the Natural Evolutionary Strategies Attack,
-    including three settings: Query-Limited setting, Partial-Information setting
-    and Label-Only setting.
+    The class is an implementation of the Natural Evolutionary Strategies Attack
+    Method. NES uses natural evolutionary strategies to estimate gradients,
+    which improves query efficiency. NES covers three settings: the
+    Query-Limited setting, the Partial-Information setting and the Label-Only
+    setting. In the Query-Limited setting, the attack has a limited number of
+    queries to the target model but access to the probabilities of all classes.
+    In the Partial-Information setting, the attack only has access to the
+    probabilities of the top-k classes. In the Label-Only setting, the attack
+    only has access to a list of k inferred labels ordered by their predicted
+    probabilities. In the Partial-Information and Label-Only settings, NES
+    performs a targeted attack, so the user needs to call the set_target_images
+    method to set the target images of the target classes.
 
     References: `Andrew Ilyas, Logan Engstrom, Anish Athalye, and Jessy Lin.
     Black-box adversarial attacks with limited queries and information. In
     ICML, July 2018 <https://arxiv.org/abs/1804.08598>`_
 
     Args:
-        model (BlackModel): Target model.
-        scene (str): Scene in 'Label_Only', 'Partial_Info' or
-            'Query_Limit'.
-        max_queries (int): Maximum query numbers to generate an adversarial
-            example. Default: 500000.
-        top_k (int): For Partial-Info or Label-Only setting, indicating how
-            much (Top-k) information is available for the attacker. For
-            Query-Limited setting, this input should be set as -1. Default: -1.
+        model (BlackModel): Target model to be attacked.
+        scene (str): Scene in 'Label_Only', 'Partial_Info' or 'Query_Limit'.
+        max_queries (int): Maximum number of queries allowed to generate an adversarial example. Default: 500000.
+        top_k (int): For the Partial-Info or Label-Only setting, indicating how much (top-k) information is
+            available for the attacker. For the Query-Limited setting, this input should be set as -1. Default: -1.
         num_class (int): Number of classes in dataset. Default: 10.
         batch_size (int): Batch size. Default: 96.
         epsilon (float): Maximum perturbation allowed in attack. Default: 0.3.
-        samples_per_draw (int): Number of samples draw in antithetic sampling.
-            Default: 96.
+        samples_per_draw (int): Number of samples drawn in antithetic sampling. Default: 96.
         momentum (float): Momentum. Default: 0.9.
         learning_rate (float): Learning rate. Default: 1e-2.
         max_lr (float): Max Learning rate. Default: 1e-2.
         min_lr (float): Min Learning rate. Default: 5e-5.
         sigma (float): Step size of random noise. Default: 1e-3.
-        plateau_length (int): Length of plateau used in Annealing algorithm.
-            Default: 20.
-        plateau_drop (float): Drop of plateau used in Annealing algorithm.
-            Default: 2.0.
+        plateau_length (int): Length of plateau used in the annealing algorithm. Default: 20.
+        plateau_drop (float): Drop of plateau used in the annealing algorithm. Default: 2.0.
         adv_thresh (float): Threshold of adversarial. Default: 0.15.
-        zero_iters (int): Number of points to use for the proxy score.
-            Default: 10.
-        starting_eps (float): Starting epsilon used in Label-Only setting.
-            Default: 1.0.
-        starting_delta_eps (float): Delta epsilon used in Label-Only setting.
-            Default: 0.5.
-        label_only_sigma (float): Sigma used in Label-Only setting.
-            Default: 1e-3.
-        conservative (int): Conservation used in epsilon decay, it will
-            increase if no convergence. Default: 2.
+        zero_iters (int): Number of points to use for the proxy score. Default: 10.
+        starting_eps (float): Starting epsilon used in the Label-Only setting. Default: 1.0.
+        starting_delta_eps (float): Delta epsilon used in the Label-Only setting. Default: 0.5.
+        label_only_sigma (float): Sigma used in the Label-Only setting. Default: 1e-3.
+        conservative (int): Conservative coefficient used in epsilon decay; it will increase if there is
+            no convergence. Default: 2.
         sparse (bool): If True, input labels are sparse-encoded. If False,
             input labels are one-hot-encoded. Default: True.
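
For intuition, the estimator configured by `sigma`, `samples_per_draw` and `batch_size` above is plain
NES gradient estimation with antithetic sampling. The following minimal NumPy sketch illustrates the
idea; it is not code from this patch, and `predict_prob` is a hypothetical black-box scoring function:

    import numpy as np

    def nes_gradient(predict_prob, x, sigma=1e-3, samples_per_draw=96):
        """Estimate the gradient of predict_prob at x using only queries."""
        half = samples_per_draw // 2
        noise_pos = np.random.normal(size=(half,) + x.shape)
        noise = np.concatenate((noise_pos, -noise_pos), axis=0)   # antithetic pairs
        scores = np.array([predict_prob(x + sigma*n) for n in noise])
        scores = scores.reshape((-1,) + (1,)*x.ndim)              # broadcast over x's shape
        # NES estimator: mean of score-weighted noise, scaled by 1/sigma
        return np.mean(scores*noise, axis=0) / sigma

The attack then takes sign-gradient steps on such an estimate, annealing the step size between
`min_lr` and `max_lr` as governed by `plateau_length` and `plateau_drop`.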
@@ -94,13 +91,10 @@ class NES(Attack):
         >>> tag, adv, queries = nes_instance.generate([initial_img], [target_class])
     """
 
-    def __init__(self, model, scene, max_queries=10000, top_k=-1, num_class=10,
-                 batch_size=128, epsilon=0.3, samples_per_draw=128,
-                 momentum=0.9, learning_rate=1e-3, max_lr=5e-2, min_lr=5e-4,
-                 sigma=1e-3, plateau_length=20, plateau_drop=2.0,
-                 adv_thresh=0.25, zero_iters=10, starting_eps=1.0,
-                 starting_delta_eps=0.5, label_only_sigma=1e-3, conservative=2,
-                 sparse=True):
+    def __init__(self, model, scene, max_queries=10000, top_k=-1, num_class=10, batch_size=128, epsilon=0.3,
+                 samples_per_draw=128, momentum=0.9, learning_rate=1e-3, max_lr=5e-2, min_lr=5e-4, sigma=1e-3,
+                 plateau_length=20, plateau_drop=2.0, adv_thresh=0.25, zero_iters=10, starting_eps=1.0,
+                 starting_delta_eps=0.5, label_only_sigma=1e-3, conservative=2, sparse=True):
         super(NES, self).__init__()
         self._model = check_model('model', model, BlackModel)
         self._scene = scene
@@ -108,17 +102,14 @@ class NES(Attack):
         self._max_queries = check_int_positive('max_queries', max_queries)
         self._num_class = check_int_positive('num_class', num_class)
         self._batch_size = check_int_positive('batch_size', batch_size)
-        self._samples_per_draw = check_int_positive('samples_per_draw',
-                                                    samples_per_draw)
+        self._samples_per_draw = check_int_positive('samples_per_draw', samples_per_draw)
         self._goal_epsilon = check_value_positive('epsilon', epsilon)
         self._momentum = check_value_positive('momentum', momentum)
-        self._learning_rate = check_value_positive('learning_rate',
-                                                   learning_rate)
+        self._learning_rate = check_value_positive('learning_rate', learning_rate)
         self._max_lr = check_value_positive('max_lr', max_lr)
         self._min_lr = check_value_positive('min_lr', min_lr)
         self._sigma = check_value_positive('sigma', sigma)
-        self._plateau_length = check_int_positive('plateau_length',
-                                                  plateau_length)
+        self._plateau_length = check_int_positive('plateau_length', plateau_length)
         self._plateau_drop = check_value_positive('plateau_drop', plateau_drop)
         # partial information arguments
         self._k = top_k
@@ -126,10 +117,8 @@ class NES(Attack):
         # label only arguments
         self._zero_iters = check_int_positive('zero_iters', zero_iters)
         self._starting_eps = check_value_positive('starting_eps', starting_eps)
-        self._starting_delta_eps = check_value_positive('starting_delta_eps',
-                                                        starting_delta_eps)
-        self._label_only_sigma = check_value_positive('label_only_sigma',
-                                                      label_only_sigma)
+        self._starting_delta_eps = check_value_positive('starting_delta_eps', starting_delta_eps)
+        self._label_only_sigma = check_value_positive('label_only_sigma', label_only_sigma)
         self._conservative = check_int_positive('conservative', conservative)
         self._sparse = check_param_type('sparse', sparse, bool)
         self.target_imgs = None
@@ -152,39 +141,32 @@ class NES(Attack):
             - numpy.ndarray, query times for each sample.
 
         Raises:
-            ValueError: If the top_k less than 0 in Label-Only or Partial-Info
-                setting.
-            ValueError: If the target_imgs is None in Label-Only or
-                Partial-Info setting.
-            ValueError: If scene is not in ['Label_Only', 'Partial_Info',
-                'Query_Limit']
+            ValueError: If top_k is less than 1 in the Label-Only or Partial-Info setting.
+            ValueError: If target_imgs is None in the Label-Only or Partial-Info setting.
+            ValueError: If scene is not in ['Label_Only', 'Partial_Info', 'Query_Limit'].
 
         Examples:
             >>> advs = attack.generate([[0.2, 0.3, 0.4], [0.3, 0.3, 0.2]],
             >>> [1, 2])
         """
-        inputs, labels = check_pair_numpy_param('inputs', inputs,
-                                                'labels', labels)
+        inputs, labels = check_pair_numpy_param('inputs', inputs, 'labels', labels)
         if not self._sparse:
             labels = np.argmax(labels, axis=1)
         if self._scene == 'Label_Only' or self._scene == 'Partial_Info':
-            if self._k < 0:
-                msg = "In 'Label_Only' or 'Partial_Info' mode, " \
-                      "'top_k' must more than 0."
+            if self._k < 1:
+                msg = "In 'Label_Only' or 'Partial_Info' mode, 'top_k' must be greater than 0."
                 LOGGER.error(TAG, msg)
                 raise ValueError(msg)
             if self.target_imgs is None:
-                msg = "In 'Label_Only' or 'Partial_Info' mode, " \
-                      "'target_imgs' must be set."
+                msg = "In 'Label_Only' or 'Partial_Info' mode, 'target_imgs' must be set."
                 LOGGER.error(TAG, msg)
                 raise ValueError(msg)
         elif self._scene == 'Query_Limit':
             self._k = self._num_class
         else:
-            msg = "scene must be string in 'Label_Only', " \
-                  "'Partial_Info' or 'Query_Limit' "
+            msg = "scene must be one of 'Label_Only', 'Partial_Info' or 'Query_Limit'."
             LOGGER.error(TAG, msg)
             raise ValueError(msg)
@@ -201,7 +183,7 @@ class NES(Attack):
 
     def set_target_images(self, target_images):
         """
-        Set target samples for target attack.
+        Set target samples for the targeted attack in the Partial-Info or Label-Only setting.
 
         Args:
             target_images (numpy.ndarray): Target samples for target attack.
@@ -253,8 +235,8 @@ class NES(Attack):
             return True, adv, num_queries
 
         # antithetic sampling noise
-        noise_pos = np.random.normal(
-            size=(self._batch_size // 2,) + origin_image.shape)
+        size = (self._batch_size // 2,) + origin_image.shape
+        noise_pos = np.random.normal(size=size)
         noise = np.concatenate((noise_pos, -noise_pos), axis=0)
         eval_points = adv + self._sigma*noise
@@ -274,8 +256,7 @@ class NES(Attack):
         while current_lr >= self._min_lr:
             # in partial information only or label only setting
             if self._scene == 'Label_Only' or self._scene == 'Partial_Info':
-                proposed_epsilon = max(self._epsilon - prop_delta_eps,
-                                       goal_epsilon)
+                proposed_epsilon = max(self._epsilon - prop_delta_eps, goal_epsilon)
                 lower, upper = _bound(origin_image, proposed_epsilon)
             proposed_adv = adv - current_lr*np.sign(gradient)
             proposed_adv = np.clip(proposed_adv, lower, upper)
@@ -288,23 +269,19 @@ class NES(Attack):
                     delta_epsilon = max(prop_delta_eps, 0.1)
                     last_ls = []
                 adv = proposed_adv
-                self._epsilon = max(
-                    self._epsilon - prop_delta_eps / self._conservative,
-                    goal_epsilon)
+                self._epsilon = self._epsilon - prop_delta_eps / self._conservative
+                self._epsilon = max(self._epsilon, goal_epsilon)
                 break
             elif current_lr >= self._min_lr*2:
                 current_lr = current_lr / 2
-                LOGGER.debug(TAG, "backtracking learning rate to %.3f",
-                             current_lr)
+                LOGGER.debug(TAG, "backtracking learning rate to %.3f", current_lr)
             else:
                 prop_delta_eps = prop_delta_eps / 2
                 if prop_delta_eps < 2e-3:
                     LOGGER.debug(TAG, "Did not converge.")
                     return False, adv, num_queries
                 current_lr = self._max_lr
-                LOGGER.debug(TAG,
-                             "backtracking epsilon to %.3f",
-                             self._epsilon - prop_delta_eps)
+                LOGGER.debug(TAG, "backtracking epsilon to %.3f", self._epsilon - prop_delta_eps)
 
         # update the number of queries
         if self._scene == 'Label_Only':
@@ -323,12 +300,10 @@ class NES(Attack):
 
     def _plateau_annealing(self, last_loss):
         last_loss = last_loss[-self._plateau_length:]
-        if last_loss[-1] > last_loss[0] and len(
-                last_loss) == self._plateau_length:
+        if last_loss[-1] > last_loss[0] and len(last_loss) == self._plateau_length:
             if self._max_lr > self._min_lr:
                 LOGGER.debug(TAG, "Annealing max learning rate.")
-                self._max_lr = max(self._max_lr / self._plateau_drop,
-                                   self._min_lr)
+                self._max_lr = max(self._max_lr / self._plateau_drop, self._min_lr)
             last_loss = []
         return last_loss
@@ -346,8 +321,7 @@ class NES(Attack):
             Loss in Query-Limit setting.
         """
         LOGGER.debug(TAG, 'enter the function _query_limit_loss().')
-        loss = self._softmax_cross_entropy_with_logit(
-            self._model.predict(eval_points))
+        loss = self._softmax_cross_entropy_with_logit(self._model.predict(eval_points))
         return loss, noise
 
@@ -359,8 +333,7 @@ class NES(Attack):
         logit = self._model.predict(eval_points)
         loss = np.sort(softmax(logit, axis=1))[:, -self._k:]
         inds = np.argsort(logit)[:, -self._k:]
-        good_loss = np.where(np.equal(inds, self.target_class), loss,
-                             np.zeros(np.shape(inds)))
+        good_loss = np.where(np.equal(inds, self.target_class), loss, np.zeros(np.shape(inds)))
         good_loss = np.max(good_loss, axis=1)
         losses = -np.log(good_loss)
         return losses, noise
@@ -370,22 +343,16 @@ class NES(Attack):
         Loss in Label-Only setting.
         """
         LOGGER.debug(TAG, 'enter the function _label_only_loss().')
-        tiled_points = np.tile(np.expand_dims(eval_points, 0),
-                               [self._zero_iters,
-                                *[1]*len(eval_points.shape)])
-        noised_eval_im = tiled_points \
-                         + np.random.randn(self._zero_iters,
-                                           self._batch_size,
-                                           *origin_image.shape) \
-                         *self._label_only_sigma
-        noised_eval_im = np.reshape(noised_eval_im, (
-            self._zero_iters*self._batch_size, *origin_image.shape))
+        tiled_points = np.tile(np.expand_dims(eval_points, 0), [self._zero_iters, *[1]*len(eval_points.shape)])
+        noised_eval_im = tiled_points + np.random.randn(self._zero_iters,
+                                                        self._batch_size,
+                                                        *origin_image.shape)*self._label_only_sigma
+        noised_eval_im = np.reshape(noised_eval_im, (self._zero_iters*self._batch_size, *origin_image.shape))
         logits = self._model.predict(noised_eval_im)
         inds = np.argsort(logits)[:, -self._k:]
         real_inds = np.reshape(inds, (self._zero_iters, self._batch_size, -1))
         rank_range = np.arange(1, self._k + 1, 1, dtype=np.float32)
-        tiled_rank_range = np.tile(np.reshape(rank_range, (1, 1, self._k)),
-                                   [self._zero_iters, self._batch_size, 1])
+        tiled_rank_range = np.tile(np.reshape(rank_range, (1, 1, self._k)), [self._zero_iters, self._batch_size, 1])
         batches_in = np.where(np.equal(real_inds, self.target_class),
                               tiled_rank_range,
                               np.zeros(np.shape(tiled_rank_range)))
@@ -408,16 +375,13 @@ class NES(Attack):
         grads = []
         for _ in range(self._samples_per_draw // self._batch_size):
             if self._scene == 'Label_Only':
-                loss, np_noise = self._label_only_loss(origin_image,
-                                                       eval_points,
-                                                       noise)
+                loss, np_noise = self._label_only_loss(origin_image, eval_points, noise)
             elif self._scene == 'Partial_Info':
                 loss, np_noise = self._partial_info_loss(eval_points, noise)
             else:
                 loss, np_noise = self._query_limit_loss(eval_points, noise)
             # only support three channel images
-            losses_tiled = np.tile(np.reshape(loss, (-1, 1, 1, 1)),
-                                   (1,) + origin_image.shape)
+            losses_tiled = np.tile(np.reshape(loss, (-1, 1, 1, 1)), (1,) + origin_image.shape)
             grad = np.mean(losses_tiled*np_noise, axis=0) / self._sigma
             grads.append(grad)
diff --git a/mindarmour/fuzz_testing/fuzzing.py b/mindarmour/fuzz_testing/fuzzing.py
index 6da8d4d..2ddd9ad 100644
--- a/mindarmour/fuzz_testing/fuzzing.py
+++ b/mindarmour/fuzz_testing/fuzzing.py
@@ -44,14 +44,15 @@ def _select_next(initial_seeds):
 
 def _coverage_gains(coverages):
-    """ Calculate the coverage gains of mutated samples.
""" + """ Calculate the coverage gains of mutated samples.""" gains = [0] + coverages[:-1] gains = np.array(coverages) - np.array(gains) return gains def _is_trans_valid(seed, mutate_sample): - """ Check a mutated sample is valid. If the number of changed pixels in + """ + Check a mutated sample is valid. If the number of changed pixels in a seed is less than pixels_change_rate*size(seed), this mutate is valid. Else check the infinite norm of seed changes, if the value of the infinite norm less than pixel_value_change_rate*255, this mutate is @@ -80,21 +81,18 @@ def _check_eval_metrics(eval_metrics): available_metrics = ['accuracy', 'attack_success_rate', 'kmnc', 'nbc', 'snac'] for elem in eval_metrics: if elem not in available_metrics: - msg = 'metric in list `eval_metrics` must be in {}, but ' \ - 'got {}.'.format(available_metrics, elem) + msg = 'metric in list `eval_metrics` must be in {}, but got {}.'.format(available_metrics, elem) LOGGER.error(TAG, msg) raise ValueError(msg) eval_metrics_.append(elem.lower()) elif isinstance(eval_metrics, str): if eval_metrics != 'auto': - msg = "the value of `eval_metrics` must be 'auto' if it's type " \ - "is str, but got {}.".format(eval_metrics) + msg = "the value of `eval_metrics` must be 'auto' if it's type is str, but got {}.".format(eval_metrics) LOGGER.error(TAG, msg) raise ValueError(msg) eval_metrics_ = 'auto' else: - msg = "the type of `eval_metrics` must be str, list or tuple, " \ - "but got {}.".format(type(eval_metrics)) + msg = "the type of `eval_metrics` must be str, list or tuple, but got {}.".format(type(eval_metrics)) LOGGER.error(TAG, msg) raise TypeError(msg) return eval_metrics_ @@ -109,11 +107,9 @@ class Fuzzer: Args: target_model (Model): Target fuzz model. - train_dataset (numpy.ndarray): Training dataset used for determining - the neurons' output boundaries. + train_dataset (numpy.ndarray): Training dataset used for determining the neurons' output boundaries. neuron_num (int): The number of testing neurons. - segmented_num (int): The number of segmented sections of neurons' - output intervals. Default: 1000. + segmented_num (int): The number of segmented sections of neurons' output intervals. Default: 1000. Examples: >>> net = Net() @@ -126,7 +122,9 @@ class Fuzzer: >>> {'method': 'FGSM', >>> 'params': {'eps': [0.1, 0.2, 0.3], 'alpha': [0.1]}}] >>> train_images = np.random.rand(32, 1, 32, 32).astype(np.float32) - >>> model_fuzz_test = Fuzzer(model, train_images, 10, 1000) + >>> neuron_num = 10 + >>> segmented_num = 1000 + >>> model_fuzz_test = Fuzzer(model, train_images, neuron_num, segmented_num) >>> samples, labels, preds, strategies, report = model_fuzz_test.fuzz_testing(mutate_config, initial_seeds) """ @@ -134,10 +132,7 @@ class Fuzzer: segmented_num=1000): self._target_model = check_model('model', target_model, Model) train_dataset = check_numpy_param('train_dataset', train_dataset) - self._coverage_metrics = ModelCoverageMetrics(target_model, - neuron_num, - segmented_num, - train_dataset) + self._coverage_metrics = ModelCoverageMetrics(target_model, neuron_num, segmented_num, train_dataset) # Allowed mutate strategies so far. 
         self._strategies = {'Contrast': Contrast,
                             'Brightness': Brightness,
                             'PGD': ProjectedGradientDescent,
                             'MDIIM': MomentumDiverseInputIterativeMethod}
         self._affine_trans_list = ['Translate', 'Scale', 'Shear', 'Rotate']
-        self._pixel_value_trans_list = ['Contrast', 'Brightness', 'Blur',
-                                        'Noise']
+        self._pixel_value_trans_list = ['Contrast', 'Brightness', 'Blur', 'Noise']
         self._attacks_list = ['FGSM', 'PGD', 'MDIIM']
         self._attack_param_checklists = {
-            'FGSM': {'eps': {'dtype': [float],
-                             'range': [0, 1]},
-                     'alpha': {'dtype': [float],
-                               'range': [0, 1]},
+            'FGSM': {'eps': {'dtype': [float], 'range': [0, 1]},
+                     'alpha': {'dtype': [float], 'range': [0, 1]},
                      'bounds': {'dtype': [tuple]}},
-            'PGD': {'eps': {'dtype': [float],
-                            'range': [0, 1]},
-                    'eps_iter': {
-                        'dtype': [float],
-                        'range': [0, 1]},
-                    'nb_iter': {'dtype': [int],
-                                'range': [0, 100000]},
+            'PGD': {'eps': {'dtype': [float], 'range': [0, 1]},
+                    'eps_iter': {'dtype': [float], 'range': [0, 1]},
+                    'nb_iter': {'dtype': [int], 'range': [0, 100000]},
                     'bounds': {'dtype': [tuple]}},
-            'MDIIM': {'eps': {'dtype': [float],
-                              'range': [0, 1]},
-                      'norm_level': {'dtype': [str, int],
-                                     'range': [1, 2, '1', '2', 'l1', 'l2',
-                                               'inf', 'np.inf']},
-                      'prob': {'dtype': [float],
-                               'range': [0, 1]},
+            'MDIIM': {'eps': {'dtype': [float], 'range': [0, 1]},
+                      'norm_level': {'dtype': [str, int], 'range': [1, 2, '1', '2', 'l1', 'l2', 'inf', 'np.inf']},
+                      'prob': {'dtype': [float], 'range': [0, 1]},
                       'bounds': {'dtype': [tuple]}}}
 
     def fuzzing(self, mutate_config, initial_seeds, coverage_metric='KMNC',
@@ -239,13 +223,11 @@ class Fuzzer:
         # Check parameters.
         eval_metrics_ = _check_eval_metrics(eval_metrics)
         if coverage_metric not in ['KMNC', 'NBC', 'SNAC']:
-            msg = "coverage_metric must be in ['KMNC', 'NBC', 'SNAC'], " \
-                  "but got {}.".format(coverage_metric)
+            msg = "coverage_metric must be in ['KMNC', 'NBC', 'SNAC'], but got {}.".format(coverage_metric)
             LOGGER.error(TAG, msg)
             raise ValueError(msg)
         max_iters = check_int_positive('max_iters', max_iters)
-        mutate_num_per_seed = check_int_positive('mutate_num_per_seed',
-                                                 mutate_num_per_seed)
+        mutate_num_per_seed = check_int_positive('mutate_num_per_seed', mutate_num_per_seed)
         mutate_config = self._check_mutate_config(mutate_config)
         mutates = self._init_mutates(mutate_config)
@@ -276,39 +258,30 @@ class Fuzzer:
                                                  mutate_config, mutate_num_per_seed)
             # Calculate the coverages and predictions of generated samples.
-            coverages, predicts = self._get_coverages_and_predict(mutate_samples,
-                                                                  coverage_metric)
+            coverages, predicts = self._get_coverages_and_predict(mutate_samples, coverage_metric)
             coverage_gains = _coverage_gains(coverages)
-            for mutate, cov, pred, strategy in zip(mutate_samples,
-                                                   coverage_gains,
-                                                   predicts, mutate_strategies):
+            for mutate, cov, pred, strategy in zip(mutate_samples, coverage_gains, predicts, mutate_strategies):
                 fuzz_samples.append(mutate[0])
                 true_labels.append(mutate[1])
                 fuzz_preds.append(pred)
                 fuzz_strategies.append(strategy)
                 # if the mutate samples has coverage gains add this samples in
-                # the initial seeds to guide new mutates.
+                # the initial_seeds to guide new mutates.
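
The coverage-gain bookkeeping referenced in the comment above is a first difference over the running
coverage values: a mutated sample re-enters the seed queue only if it increased coverage. With
hypothetical numbers (a hand-checkable sketch mirroring `_coverage_gains`, not part of this patch):

    import numpy as np

    # cumulative coverage measured after each mutated sample (hypothetical values)
    coverages = [0.10, 0.15, 0.15, 0.20]
    gains = np.array(coverages) - np.array([0] + coverages[:-1])
    print(gains)  # [0.1  0.05 0.   0.05] -> the third sample added nothing and is not re-queued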
                 if cov > 0:
                     initial_seeds.append(mutate)
             seed, initial_seeds = _select_next(initial_seeds)
             iter_num += 1
         metrics_report = None
         if eval_metrics_ is not None:
-            metrics_report = self._evaluate(fuzz_samples,
-                                            true_labels,
-                                            fuzz_preds,
-                                            fuzz_strategies,
-                                            eval_metrics_)
+            metrics_report = self._evaluate(fuzz_samples, true_labels, fuzz_preds, fuzz_strategies, eval_metrics_)
         return fuzz_samples, true_labels, fuzz_preds, fuzz_strategies, metrics_report
 
-    def _get_coverages_and_predict(self, mutate_samples,
-                                   coverage_metric="KNMC"):
+    def _get_coverages_and_predict(self, mutate_samples, coverage_metric="KMNC"):
         """ Calculate the coverages and predictions of generated samples."""
         samples = [s[0] for s in mutate_samples]
         samples = np.array(samples)
         coverages = []
-        predictions = self._target_model.predict(
-            Tensor(samples.astype(np.float32)))
+        predictions = self._target_model.predict(Tensor(samples.astype(np.float32)))
         predictions = predictions.asnumpy()
         for index in range(len(samples)):
             mutate = samples[:index + 1]
@@ -349,10 +322,8 @@ class Fuzzer:
                 mutate_sample = transform.transform(seed[0])
             else:
                 for param_name in selected_param:
-                    transform.__setattr__('_' + str(param_name),
-                                          selected_param[param_name])
-                mutate_sample = transform.generate(np.array([seed[0].astype(np.float32)]),
-                                                   np.array([seed[1]]))[0]
+                    transform.__setattr__('_' + str(param_name), selected_param[param_name])
+                mutate_sample = transform.generate(np.array([seed[0].astype(np.float32)]), np.array([seed[1]]))[0]
             if method not in self._pixel_value_trans_list:
                 only_pixel_trans = 1
             mutate_sample = [mutate_sample, seed[1], only_pixel_trans]
@@ -372,8 +343,7 @@ class Fuzzer:
         for config in mutate_config:
             check_param_type("config", config, dict)
             if set(config.keys()) != {'method', 'params'}:
-                msg = "The key of each config must be in ('method', 'params'), " \
-                      "but got {}.".format(set(config.keys()))
+                msg = "The key of each config must be in ('method', 'params'), but got {}.".format(set(config.keys()))
                 LOGGER.error(TAG, msg)
                 raise KeyError(msg)
 
             method = config['method']
             params = config['params']
 
             # Method must be in the optional range.
             if method not in self._strategies.keys():
-                msg = "Config methods must be in {}, but got {}." \
-                    .format(self._strategies.keys(), method)
+                msg = "Config methods must be in {}, but got {}.".format(self._strategies.keys(), method)
                 LOGGER.error(TAG, msg)
                 raise ValueError(msg)
@@ -401,8 +370,7 @@ class Fuzzer:
         # Methods in `metate_config` should at least have one in the type of
         # pixel value based transform methods.
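
To make the validation rules above concrete, a `mutate_config` that passes `_check_mutate_config`
must use exactly the keys 'method' and 'params', draw methods from the registered strategies, and
include at least one pixel-value transform. A sketch assembled from this patch's own examples:

    mutate_config = [{'method': 'Blur',                # pixel-value transform: at least one is required
                      'params': {'auto_param': [True]}},
                     {'method': 'Contrast',            # pixel-value transform
                      'params': {'factor': [2]}},
                     {'method': 'FGSM',                # attack-based mutation: optional
                      'params': {'eps': [0.1, 0.2, 0.3], 'alpha': [0.1]}}]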
         if not has_pixel_trans:
-            msg = "mutate methods in mutate_config should at least have one " \
-                  "in {}".format(self._pixel_value_trans_list)
+            msg = "mutate methods in mutate_config should at least have one in {}".format(self._pixel_value_trans_list)
             raise ValueError(msg)
         return mutate_config
@@ -424,11 +392,9 @@ class Fuzzer:
                           'but got its length as{}'.format(len(bounds))
                     raise ValueError(msg)
                 for bound_value in bounds:
-                    _ = check_param_multi_types('bound', bound_value,
-                                                [int, float])
+                    _ = check_param_multi_types('bound', bound_value, [int, float])
                 if bounds[0] >= bounds[1]:
-                    msg = "upper bound must more than lower bound, " \
-                          "but upper bound got {}, lower bound " \
+                    msg = "lower bound must be less than upper bound, but lower bound " \
                           "got {}".format(bounds[0], bounds[1])
                     raise ValueError(msg)
             elif param_name == 'norm_level':
                 allow_type = self._attack_param_checklists[method][param_name]['dtype']
                 allow_range = self._attack_param_checklists[method][param_name]['range']
                 _ = check_param_multi_types(str(param_name), param_value, allow_type)
-                _ = check_param_in_range(str(param_name),
-                                         param_value,
-                                         allow_range[0],
-                                         allow_range[1])
+                _ = check_param_in_range(str(param_name), param_value, allow_range[0], allow_range[1])
 
     def _init_mutates(self, mutate_config):
         """ Check whether the mutate_config meet the specification."""
                 loss_fn = self._target_model._loss_fn
                 if loss_fn is None:
                     loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
-                mutates[method] = self._strategies[method](network,
-                                                           loss_fn=loss_fn)
+                mutates[method] = self._strategies[method](network, loss_fn=loss_fn)
         return mutates
 
     def _evaluate(self, fuzz_samples, true_labels, fuzz_preds,
@@ -497,8 +459,7 @@ class Fuzzer:
             metrics_report['Attack_success_rate'] = attack_success_rate
 
         if metrics == 'auto' or 'kmnc' in metrics or 'nbc' in metrics or 'snac' in metrics:
-            self._coverage_metrics.calculate_coverage(
-                np.array(fuzz_samples).astype(np.float32))
+            self._coverage_metrics.calculate_coverage(np.array(fuzz_samples).astype(np.float32))
 
             if metrics == 'auto' or 'kmnc' in metrics:
                 kmnc = self._coverage_metrics.get_kmnc()
diff --git a/mindarmour/fuzz_testing/model_coverage_metrics.py b/mindarmour/fuzz_testing/model_coverage_metrics.py
index b623761..ce647c2 100644
--- a/mindarmour/fuzz_testing/model_coverage_metrics.py
+++ b/mindarmour/fuzz_testing/model_coverage_metrics.py
@@ -56,7 +56,9 @@ class ModelCoverageMetrics:
         >>> train_images = np.random.random((10000, 1, 32, 32)).astype(np.float32)
         >>> test_images = np.random.random((5000, 1, 32, 32)).astype(np.float32)
         >>> model = Model(net)
-        >>> model_fuzz_test = ModelCoverageMetrics(model, 10, 1000, train_images)
+        >>> neuron_num = 10
+        >>> segmented_num = 1000
+        >>> model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images)
         >>> model_fuzz_test.calculate_coverage(test_images)
         >>> print('KMNC of this test is : %s', model_fuzz_test.get_kmnc())
         >>> print('NBC of this test is : %s', model_fuzz_test.get_nbc())
@@ -68,16 +70,14 @@ class ModelCoverageMetrics:
         self._segmented_num = check_int_positive('segmented_num', segmented_num)
         self._neuron_num = check_int_positive('neuron_num', neuron_num)
         if self._neuron_num > 1e+9:
-            msg = 'neuron_num should be less than 1e+10, otherwise a MemoryError' \
-                  'would occur'
+            msg = 'neuron_num should be no more than 1e+9, otherwise a MemoryError would occur'
             LOGGER.error(TAG, msg)
             raise ValueError(msg)
         train_dataset = check_numpy_param('train_dataset', train_dataset)
         self._lower_bounds = [np.inf]*self._neuron_num
         self._upper_bounds = [-np.inf]*self._neuron_num
         self._var = [0]*self._neuron_num
-        self._main_section_hits = [[0 for _ in range(self._segmented_num)] for _ in
-                                   range(self._neuron_num)]
+        self._main_section_hits = [[0 for _ in range(self._segmented_num)] for _ in range(self._neuron_num)]
         self._lower_corner_hits = [0]*self._neuron_num
         self._upper_corner_hits = [0]*self._neuron_num
         self._bounds_get(train_dataset)
@@ -99,19 +99,16 @@ class ModelCoverageMetrics:
             inputs = train_dataset[i*batch_size: (i + 1)*batch_size]
             output = self._model.predict(Tensor(inputs)).asnumpy()
             output_mat.append(output)
-            lower_compare_array = np.concatenate(
-                [output, np.array([self._lower_bounds])], axis=0)
+            lower_compare_array = np.concatenate([output, np.array([self._lower_bounds])], axis=0)
             self._lower_bounds = np.min(lower_compare_array, axis=0)
-            upper_compare_array = np.concatenate(
-                [output, np.array([self._upper_bounds])], axis=0)
+            upper_compare_array = np.concatenate([output, np.array([self._upper_bounds])], axis=0)
             self._upper_bounds = np.max(upper_compare_array, axis=0)
         if batches == 0:
             output = self._model.predict(Tensor(train_dataset)).asnumpy()
             self._lower_bounds = np.min(output, axis=0)
             self._upper_bounds = np.max(output, axis=0)
             output_mat.append(output)
-        self._var = np.std(np.concatenate(np.array(output_mat), axis=0),
-                           axis=0)
+        self._var = np.std(np.concatenate(np.array(output_mat), axis=0), axis=0)
 
     def _sections_hits_count(self, dataset, intervals):
         """
 
         Args:
             dataset (numpy.ndarray): Testing data.
-            intervals (list[float]): Segmentation intervals of neurons'
-                outputs.
+            intervals (list[float]): Segmentation intervals of neurons' outputs.
         """
         dataset = check_numpy_param('dataset', dataset)
         batch_output = self._model.predict(Tensor(dataset)).asnumpy()
@@ -142,11 +138,12 @@ class ModelCoverageMetrics:
             dataset (numpy.ndarray): Data for fuzz test.
             bias_coefficient (Union[int, float]): The coefficient used
                 for changing the neurons' output boundaries. Default: 0.
-            batch_size (int): The number of samples in a predict batch.
-                Default: 32.
+            batch_size (int): The number of samples in a predict batch. Default: 32.
 
         Examples:
-            >>> model_fuzz_test = ModelCoverageMetrics(model, 10000, 10, train_images)
+            >>> neuron_num = 10
+            >>> segmented_num = 1000
+            >>> model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images)
             >>> model_fuzz_test.calculate_coverage(test_images)
         """
 
         intervals = (self._upper_bounds - self._lower_bounds) / self._segmented_num
         batches = dataset.shape[0] // batch_size
         for i in range(batches):
-            self._sections_hits_count(
-                dataset[i*batch_size: (i + 1)*batch_size], intervals)
+            self._sections_hits_count(dataset[i*batch_size: (i + 1)*batch_size], intervals)
 
     def get_kmnc(self):
         """
-        Get the metric of 'k-multisection neuron coverage'.
+        Get the metric of 'k-multisection neuron coverage'. KMNC measures how
+        thoroughly the given set of test inputs covers the range of neuron
+        output values derived from the training dataset.
 
         Returns:
             float, the metric of 'k-multisection neuron coverage'.
@@ -176,7 +174,11 @@ def get_nbc(self):
         """
-        Get the metric of 'neuron boundary coverage'.
+        Get the metric of 'neuron boundary coverage'
+        :math:`NBC = (|UpperCornerNeuron| + |LowerCornerNeuron|) / (2 \times |N|)`,
+        where :math:`|N|` is the number of neurons. NBC refers to the proportion
+        of neurons whose output values in the test dataset exceed the upper and
+        lower bounds of the corresponding neurons' output values in the
+        training dataset.
 
         Returns:
             float, the metric of 'neuron boundary coverage'.
 
         Examples:
             >>> model_fuzz_test.get_nbc()
         """
-        nbc = (np.sum(self._lower_corner_hits) + np.sum(
-            self._upper_corner_hits)) / (2*self._neuron_num)
+        nbc = (np.sum(self._lower_corner_hits) + np.sum(self._upper_corner_hits)) / (2*self._neuron_num)
         return nbc
 
     def get_snac(self):
         """
         Get the metric of 'strong neuron activation coverage'.
+        :math:`SNAC = |UpperCornerNeuron| / |N|`. SNAC refers to the proportion
+        of neurons whose output values in the test dataset exceed the upper
+        bounds of the corresponding neurons' output values in the training
+        dataset.
 
         Returns:
             float, the metric of 'strong neuron activation coverage'.
diff --git a/tests/ut/python/fuzzing/test_coverage_metrics.py b/tests/ut/python/fuzzing/test_coverage_metrics.py
index b4912a5..b8a7287 100644
--- a/tests/ut/python/fuzzing/test_coverage_metrics.py
+++ b/tests/ut/python/fuzzing/test_coverage_metrics.py
@@ -69,8 +69,10 @@ def test_lenet_mnist_coverage_cpu():
     model = Model(net)
 
     # initialize fuzz test with training dataset
+    neuron_num = 10
+    segmented_num = 1000
     training_data = (np.random.random((10000, 10))*20).astype(np.float32)
-    model_fuzz_test = ModelCoverageMetrics(model, 10, 1000, training_data)
+    model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, training_data)
 
     # fuzz test with original test data
     # get test data
@@ -103,8 +105,10 @@ def test_lenet_mnist_coverage_ascend():
     model = Model(net)
 
     # initialize fuzz test with training dataset
+    neuron_num = 10
+    segmented_num = 1000
     training_data = (np.random.random((10000, 10))*20).astype(np.float32)
-    model_fuzz_test = ModelCoverageMetrics(model, 10, 1000, training_data)
+    model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, training_data)
 
     # fuzz test with original test data
     # get test data
diff --git a/tests/ut/python/fuzzing/test_fuzzer.py b/tests/ut/python/fuzzing/test_fuzzer.py
index 381e608..5066bcf 100644
--- a/tests/ut/python/fuzzing/test_fuzzer.py
+++ b/tests/ut/python/fuzzing/test_fuzzer.py
@@ -101,8 +101,10 @@ def test_fuzzing_ascend():
                       'params': {'eps': [0.1, 0.2, 0.3], 'alpha': [0.1]}}
                      ]
     # initialize fuzz test with training dataset
+    neuron_num = 10
+    segmented_num = 1000
     train_images = np.random.rand(32, 1, 32, 32).astype(np.float32)
-    model_coverage_test = ModelCoverageMetrics(model, 10, 1000, train_images)
+    model_coverage_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images)
 
     # fuzz test with original test data
     # get test data
@@ -121,7 +123,7 @@ def test_fuzzing_ascend():
     LOGGER.info(TAG, 'KMNC of this test is : %s',
                 model_coverage_test.get_kmnc())
 
-    model_fuzz_test = Fuzzer(model, train_images, 10, 1000)
+    model_fuzz_test = Fuzzer(model, train_images, neuron_num, segmented_num)
     _, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds)
     print(metrics)
 
@@ -137,6 +139,8 @@ def test_fuzzing_cpu():
     model = Model(net)
     batch_size = 8
     num_classe = 10
+    neuron_num = 10
+    segmented_num = 1000
     mutate_config = [{'method': 'Blur',
                       'params': {'auto_param': [True]}},
                      {'method': 'Contrast',
                       'params': {'factor': [2]}},
                      ]
     # initialize fuzz test with training dataset
     train_images = np.random.rand(32, 1, 32, 32).astype(np.float32)
-    model_coverage_test = ModelCoverageMetrics(model, 10, 1000, train_images)
+    model_coverage_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images)
 
     # fuzz test with original test data
     # get test data
@@ -167,6 +171,6 @@ def test_fuzzing_cpu():
     LOGGER.info(TAG, 'KMNC of this test is : %s',
                 model_coverage_test.get_kmnc())
 
-    model_fuzz_test = Fuzzer(model, train_images, 10, 1000)
+    model_fuzz_test = Fuzzer(model, train_images, neuron_num, segmented_num)
     _, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds)
     print(metrics)
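
With hypothetical corner-hit counts, the NBC and SNAC formulas documented in
model_coverage_metrics.py above can be checked by hand (illustrative values only, not from this patch):

    import numpy as np

    neuron_num = 10
    # 1 if a neuron's test-time output fell below its training-time lower bound, else 0
    lower_corner_hits = np.array([1, 0, 0, 1, 0, 0, 0, 0, 0, 0])
    # 1 if a neuron's test-time output exceeded its training-time upper bound, else 0
    upper_corner_hits = np.array([1, 1, 0, 0, 1, 0, 0, 0, 0, 0])

    nbc = (lower_corner_hits.sum() + upper_corner_hits.sum()) / (2*neuron_num)  # (2 + 3) / 20 = 0.25
    snac = upper_corner_hits.sum() / neuron_num                                 # 3 / 10 = 0.3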