| @@ -14,11 +14,10 @@ | |||
| import numpy as np | |||
| from mindspore import Model | |||
| from mindspore import context | |||
| from mindspore.nn import SoftmaxCrossEntropyWithLogits | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from mindarmour.adv_robustness.attacks import FastGradientSignMethod | |||
| from mindarmour.fuzz_testing import ModelCoverageMetrics | |||
| from mindarmour.fuzz_testing.model_coverage_metrics import NeuronCoverage, TopKNeuronCoverage, NeuronBoundsCoverage,\ | |||
| SuperNeuronActivateCoverage, KMultisectionNeuronCoverage | |||
| from mindarmour.utils.logger import LogUtil | |||
| from examples.common.dataset.data_processing import generate_mnist_dataset | |||
| @@ -46,13 +45,6 @@ def test_lenet_mnist_coverage(): | |||
| images = data[0].astype(np.float32) | |||
| train_images.append(images) | |||
| train_images = np.concatenate(train_images, axis=0) | |||
| neuron_num = 10 | |||
| segmented_num = 1000 | |||
| top_k = 3 | |||
| threshold = 0.1 | |||
| # initialize fuzz test with training dataset | |||
| model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images) | |||
| # fuzz test with original test data | |||
| # get test data | |||
| @@ -67,31 +59,31 @@ def test_lenet_mnist_coverage(): | |||
| test_images.append(images) | |||
| test_labels.append(labels) | |||
| test_images = np.concatenate(test_images, axis=0) | |||
| test_labels = np.concatenate(test_labels, axis=0) | |||
| model_fuzz_test.calculate_coverage(test_images) | |||
| LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc()) | |||
| LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc()) | |||
| LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac()) | |||
| model_fuzz_test.calculate_effective_coverage(test_images, top_k, threshold) | |||
| LOGGER.info(TAG, 'NC of this test is : %s', model_fuzz_test.get_nc()) | |||
| LOGGER.info(TAG, 'Effective_NC of this test is : %s', model_fuzz_test.get_effective_nc()) | |||
| # calculate neuron coverage metrics of the original test data | |||
| nc = NeuronCoverage(model, threshold=0.1) | |||
| nc_metric = nc.get_metrics(test_images) | |||
| tknc = TopKNeuronCoverage(model, top_k=3) | |||
| tknc_metrics = tknc.get_metrics(test_images) | |||
| snac = SuperNeuronActivateCoverage(model, train_images) | |||
| snac_metrics = snac.get_metrics(test_images) | |||
| nbc = NeuronBoundsCoverage(model, train_images) | |||
| nbc_metrics = nbc.get_metrics(test_images) | |||
| # generate adv_data | |||
| loss = SoftmaxCrossEntropyWithLogits(sparse=True) | |||
| attack = FastGradientSignMethod(net, eps=0.3, loss_fn=loss) | |||
| adv_data = attack.batch_generate(test_images, test_labels, batch_size=32) | |||
| model_fuzz_test.calculate_coverage(adv_data, bias_coefficient=0.5) | |||
| LOGGER.info(TAG, 'KMNC of this adv data is : %s', model_fuzz_test.get_kmnc()) | |||
| LOGGER.info(TAG, 'NBC of this adv data is : %s', model_fuzz_test.get_nbc()) | |||
| LOGGER.info(TAG, 'SNAC of this adv data is : %s', model_fuzz_test.get_snac()) | |||
| kmnc = KMultisectionNeuronCoverage(model, train_images, segmented_num=100) | |||
| kmnc_metrics = kmnc.get_metrics(test_images) | |||
| model_fuzz_test.calculate_effective_coverage(adv_data, top_k, threshold) | |||
| LOGGER.info(TAG, 'NC of this test is : %s', model_fuzz_test.get_nc()) | |||
| LOGGER.info(TAG, 'Effective_NC of this test is : %s', model_fuzz_test.get_effective_nc()) | |||
| print('KMNC of this test is: ', kmnc_metrics) | |||
| print('NBC of this test is: ', nbc_metrics) | |||
| print('SNAC of this test is: ', snac_metrics) | |||
| print('NC of this test is: ', nc_metric) | |||
| print('TKNC of this test is: ', tknc_metrics) | |||
| if __name__ == '__main__': | |||
| # device_target can be "CPU", "GPU" or "Ascend" | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||
| test_lenet_mnist_coverage() | |||
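| A minimal consolidated sketch of the new coverage API exercised above (assuming `model`, `train_images` and `test_images` are prepared as in this example): | |||
| from mindarmour.fuzz_testing import NeuronCoverage, TopKNeuronCoverage, NeuronBoundsCoverage, \ | |||
| SuperNeuronActivateCoverage, KMultisectionNeuronCoverage | |||
| # Threshold/top-k metrics need only the model; bound-based metrics also need the training data. | |||
| nc = NeuronCoverage(model, threshold=0.1) | |||
| tknc = TopKNeuronCoverage(model, top_k=3) | |||
| snac = SuperNeuronActivateCoverage(model, train_images) | |||
| nbc = NeuronBoundsCoverage(model, train_images) | |||
| kmnc = KMultisectionNeuronCoverage(model, train_images, segmented_num=100) | |||
| for metric in (nc, tknc, snac, nbc, kmnc): | |||
| print(type(metric).__name__, metric.get_metrics(test_images)) | |||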
| @@ -14,11 +14,11 @@ | |||
| import numpy as np | |||
| from mindspore import Model | |||
| from mindspore import context | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from mindspore import load_checkpoint, load_param_into_net | |||
| from mindarmour.fuzz_testing import Fuzzer | |||
| from mindarmour.fuzz_testing import ModelCoverageMetrics | |||
| from mindarmour.utils.logger import LogUtil | |||
| from mindarmour.fuzz_testing import KMultisectionNeuronCoverage | |||
| from mindarmour.utils import LogUtil | |||
| from examples.common.dataset.data_processing import generate_mnist_dataset | |||
| from examples.common.networks.lenet5.lenet5_net_for_fuzzing import LeNet5 | |||
| @@ -52,7 +52,7 @@ def test_lenet_mnist_fuzzing(): | |||
| 'params': {'auto_param': [True]}}, | |||
| {'method': 'FGSM', | |||
| 'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1]}} | |||
| ] | |||
| ] | |||
| # get training data | |||
| data_list = "../common/dataset/MNIST/train" | |||
| @@ -63,11 +63,6 @@ def test_lenet_mnist_fuzzing(): | |||
| images = data[0].astype(np.float32) | |||
| train_images.append(images) | |||
| train_images = np.concatenate(train_images, axis=0) | |||
| neuron_num = 10 | |||
| segmented_num = 1000 | |||
| # initialize fuzz test with training dataset | |||
| model_coverage_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images) | |||
| # fuzz test with original test data | |||
| # get test data | |||
| @@ -88,21 +83,20 @@ def test_lenet_mnist_fuzzing(): | |||
| # make initial seeds | |||
| for img, label in zip(test_images, test_labels): | |||
| initial_seeds.append([img, label]) | |||
| coverage = KMultisectionNeuronCoverage(model, train_images, segmented_num=100, incremental=True) | |||
| kmnc = coverage.get_metrics(test_images[:100]) | |||
| print('KMNC of initial seeds is: ', kmnc) | |||
| initial_seeds = initial_seeds[:100] | |||
| model_coverage_test.calculate_coverage( | |||
| np.array(test_images[:100]).astype(np.float32)) | |||
| LOGGER.info(TAG, 'KMNC of this test is : %s', | |||
| model_coverage_test.get_kmnc()) | |||
| model_fuzz_test = Fuzzer(model) | |||
| _, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds, coverage, evaluate=True, max_iters=10, | |||
| mutate_num_per_seed=20) | |||
| model_fuzz_test = Fuzzer(model, train_images, neuron_num, segmented_num) | |||
| _, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds, eval_metrics='auto') | |||
| if metrics: | |||
| for key in metrics: | |||
| LOGGER.info(TAG, key + ': %s', metrics[key]) | |||
| print(key + ': ', metrics[key]) | |||
| if __name__ == '__main__': | |||
| # device_target can be "CPU", "GPU" or "Ascend" | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
| # device_target can be "CPU"GPU, "" or "Ascend" | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||
| test_lenet_mnist_fuzzing() | |||
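| A short sketch of the incremental mode used above: with incremental=True the neuron activation table persists across get_metrics calls, so coverage accumulates over the whole fuzz run (batch sizes here are illustrative): | |||
| cov = KMultisectionNeuronCoverage(model, train_images, segmented_num=100, incremental=True) | |||
| print(cov.get_metrics(test_images[:50]))     # coverage of the first batch | |||
| print(cov.get_metrics(test_images[50:100]))  # accumulated coverage of both batches | |||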
| @@ -20,19 +20,21 @@ from mindspore.ops import TensorSummary | |||
| def conv(in_channels, out_channels, kernel_size, stride=1, padding=0): | |||
| """Wrap conv.""" | |||
| weight = weight_variable() | |||
| return nn.Conv2d(in_channels, out_channels, | |||
| kernel_size=kernel_size, stride=stride, padding=padding, | |||
| return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, | |||
| weight_init=weight, has_bias=False, pad_mode="valid") | |||
| def fc_with_initialize(input_channels, out_channels): | |||
| """Wrap initialize method of full connection layer.""" | |||
| weight = weight_variable() | |||
| bias = weight_variable() | |||
| return nn.Dense(input_channels, out_channels, weight, bias) | |||
| def weight_variable(): | |||
| """Wrap initialize variable.""" | |||
| return TruncatedNormal(0.05) | |||
| @@ -50,7 +52,6 @@ class LeNet5(nn.Cell): | |||
| self.relu = nn.ReLU() | |||
| self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) | |||
| self.flatten = nn.Flatten() | |||
| self.summary = TensorSummary() | |||
| def construct(self, x): | |||
| @@ -59,8 +60,6 @@ class LeNet5(nn.Cell): | |||
| Returns: | |||
| x (tensor): network output | |||
| """ | |||
| self.summary('input', x) | |||
| x = self.conv1(x) | |||
| self.summary('1', x) | |||
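| The coverage metrics only see layers that are recorded with TensorSummary() during predict, as in the network above; a minimal sketch of the instrumentation pattern (TinyNet is a hypothetical example): | |||
| import mindspore.nn as nn | |||
| from mindspore.ops import TensorSummary | |||
| class TinyNet(nn.Cell): | |||
| def __init__(self): | |||
| super(TinyNet, self).__init__() | |||
| self.fc = fc_with_initialize(4, 2) | |||
| self.summary = TensorSummary() | |||
| def construct(self, x): | |||
| x = self.fc(x) | |||
| self.summary('fc', x)  # recorded tensors are later read back via _get_summary_tensor_data() | |||
| return x | |||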
| @@ -16,7 +16,13 @@ This module provides a neuron coverage-gain based fuzz method to evaluate the | |||
| robustness of given model. | |||
| """ | |||
| from .fuzzing import Fuzzer | |||
| from .model_coverage_metrics import ModelCoverageMetrics | |||
| from .model_coverage_metrics import CoverageMetrics, NeuronCoverage, TopKNeuronCoverage, NeuronBoundsCoverage, \ | |||
| SuperNeuronActivateCoverage, KMultisectionNeuronCoverage | |||
| __all__ = ['Fuzzer', | |||
| 'ModelCoverageMetrics'] | |||
| 'CoverageMetrics', | |||
| 'NeuronCoverage', | |||
| 'TopKNeuronCoverage', | |||
| 'NeuronBoundsCoverage', | |||
| 'SuperNeuronActivateCoverage', | |||
| 'KMultisectionNeuronCoverage'] | |||
| @@ -21,15 +21,14 @@ from mindspore import Model | |||
| from mindspore import Tensor | |||
| from mindspore import nn | |||
| from mindarmour.utils._check_param import check_model, check_numpy_param, \ | |||
| check_param_multi_types, check_norm_level, check_param_in_range, \ | |||
| check_param_type, check_int_positive | |||
| from mindarmour.utils._check_param import check_model, check_numpy_param, check_param_multi_types, check_norm_level, \ | |||
| check_param_in_range, check_param_type, check_int_positive, check_param_bounds | |||
| from mindarmour.utils.logger import LogUtil | |||
| from ..adv_robustness.attacks import FastGradientSignMethod, \ | |||
| MomentumDiverseInputIterativeMethod, ProjectedGradientDescent | |||
| from .image_transform import Contrast, Brightness, Blur, \ | |||
| Noise, Translate, Scale, Shear, Rotate | |||
| from .model_coverage_metrics import ModelCoverageMetrics | |||
| from .model_coverage_metrics import CoverageMetrics, KMultisectionNeuronCoverage | |||
| LOGGER = LogUtil.get_instance() | |||
| TAG = 'Fuzzer' | |||
| @@ -43,11 +42,22 @@ def _select_next(initial_seeds): | |||
| return seed, initial_seeds | |||
| def _coverage_gains(coverages): | |||
| """ Calculate the coverage gains of mutated samples.""" | |||
| gains = [0] + coverages[:-1] | |||
| def _coverage_gains(pre_coverage, coverages): | |||
| """ | |||
| Calculate the coverage gains of mutated samples. | |||
| Args: | |||
| pre_coverage (float): Last value of coverages for previous mutated samples. | |||
| coverages (list): Coverage of mutated samples. | |||
| Returns: | |||
| - list, coverage gains for mutated samples. | |||
| - float, last value in parameter coverages. | |||
| """ | |||
| gains = [pre_coverage] + coverages[:-1] | |||
| gains = np.array(coverages) - np.array(gains) | |||
| return gains | |||
| return gains, coverages[-1] | |||
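| A worked example of the gain arithmetic, with hypothetical coverage values (running coverage 0.10; three mutated samples report cumulative coverage 0.12, 0.12 and 0.15): | |||
| gains, last = _coverage_gains(0.10, [0.12, 0.12, 0.15]) | |||
| print(gains)  # ~[0.02 0.   0.03] -- per-sample coverage gains | |||
| print(last)   # 0.15 -- carried forward as pre_coverage for the next iteration | |||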
| def _is_trans_valid(seed, mutate_sample): | |||
| @@ -65,37 +75,22 @@ def _is_trans_valid(seed, mutate_sample): | |||
| size = np.shape(diff)[0] | |||
| l0_norm = np.linalg.norm(diff, ord=0) | |||
| linf = np.linalg.norm(diff, ord=np.inf) | |||
| if l0_norm > pixels_change_rate*size: | |||
| if l0_norm > pixels_change_rate * size: | |||
| if linf < 256: | |||
| is_valid = True | |||
| else: | |||
| if linf < pixel_value_change_rate*255: | |||
| if linf < pixel_value_change_rate * 255: | |||
| is_valid = True | |||
| return is_valid | |||
| def _check_eval_metrics(eval_metrics): | |||
| """ Check evaluation metrics.""" | |||
| if isinstance(eval_metrics, (list, tuple)): | |||
| eval_metrics_ = [] | |||
| available_metrics = ['accuracy', 'attack_success_rate', 'kmnc', 'nbc', 'snac'] | |||
| for elem in eval_metrics: | |||
| if elem not in available_metrics: | |||
| msg = 'metric in list `eval_metrics` must be in {}, but got {}.'.format(available_metrics, elem) | |||
| LOGGER.error(TAG, msg) | |||
| raise ValueError(msg) | |||
| eval_metrics_.append(elem.lower()) | |||
| elif isinstance(eval_metrics, str): | |||
| if eval_metrics != 'auto': | |||
| msg = "the value of `eval_metrics` must be 'auto' if it's type is str, but got {}.".format(eval_metrics) | |||
| LOGGER.error(TAG, msg) | |||
| raise ValueError(msg) | |||
| eval_metrics_ = 'auto' | |||
| def _gain_threshold(coverage): | |||
| """Get threshold for given neuron coverage class.""" | |||
| if isinstance(coverage, KMultisectionNeuronCoverage): | |||
| gain_threshold = 0.1 / coverage.segmented_num | |||
| else: | |||
| msg = "the type of `eval_metrics` must be str, list or tuple, but got {}.".format(type(eval_metrics)) | |||
| LOGGER.error(TAG, msg) | |||
| raise TypeError(msg) | |||
| return eval_metrics_ | |||
| gain_threshold = 0 | |||
| return gain_threshold | |||
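| For example (assuming a prepared model and train_images): with KMultisectionNeuronCoverage and segmented_num=100, only mutations that raise KMNC by more than 0.1 / 100 = 0.001 are kept as new seeds, while any positive gain suffices for the other metrics: | |||
| kmnc = KMultisectionNeuronCoverage(model, train_images, segmented_num=100) | |||
| print(_gain_threshold(kmnc))  # 0.1 / 100 = 0.001; any other CoverageMetrics subclass yields 0 | |||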
| class Fuzzer: | |||
| @@ -113,6 +108,7 @@ class Fuzzer: | |||
| Examples: | |||
| >>> net = Net() | |||
| >>> model = Model(net) | |||
| >>> mutate_config = [{'method': 'Blur', | |||
| >>> 'params': {'auto_param': [True]}}, | |||
| >>> {'method': 'Contrast', | |||
| @@ -121,18 +117,15 @@ class Fuzzer: | |||
| >>> 'params': {'x_bias': [0.1, 0.2], 'y_bias': [0.2]}}, | |||
| >>> {'method': 'FGSM', | |||
| >>> 'params': {'eps': [0.1, 0.2, 0.3], 'alpha': [0.1]}}] | |||
| >>> train_images = np.random.rand(32, 1, 32, 32).astype(np.float32) | |||
| >>> neuron_num = 10 | |||
| >>> segmented_num = 1000 | |||
| >>> model_fuzz_test = Fuzzer(model, train_images, neuron_num, segmented_num) | |||
| >>> samples, labels, preds, strategies, report = model_fuzz_test.fuzz_testing(mutate_config, initial_seeds) | |||
| >>> nc = KMultisectionNeuronCoverage(model, train_images, segmented_num=100) | |||
| >>> model_fuzz_test = Fuzzer(model) | |||
| >>> samples, gt_labels, preds, strategies, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds, | |||
| >>> nc, max_iters=100) | |||
| """ | |||
| def __init__(self, target_model, train_dataset, neuron_num, | |||
| segmented_num=1000): | |||
| def __init__(self, target_model): | |||
| self._target_model = check_model('model', target_model, Model) | |||
| train_dataset = check_numpy_param('train_dataset', train_dataset) | |||
| self._coverage_metrics = ModelCoverageMetrics(target_model, neuron_num, segmented_num, train_dataset) | |||
| # Allowed mutate strategies so far. | |||
| self._strategies = {'Contrast': Contrast, | |||
| 'Brightness': Brightness, | |||
| @@ -161,8 +154,7 @@ class Fuzzer: | |||
| 'prob': {'dtype': [float], 'range': [0, 1]}, | |||
| 'bounds': {'dtype': [tuple]}}} | |||
| def fuzzing(self, mutate_config, initial_seeds, coverage_metric='KMNC', | |||
| eval_metrics='auto', max_iters=10000, mutate_num_per_seed=20): | |||
| def fuzzing(self, mutate_config, initial_seeds, coverage, evaluate=True, max_iters=10000, mutate_num_per_seed=20): | |||
| """ | |||
| Fuzzing tests for deep neural networks. | |||
| @@ -175,32 +167,20 @@ class Fuzzer: | |||
| {'method': 'FGSM', | |||
| 'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1]}}, | |||
| ...]. | |||
| The supported methods list is in `self._strategies`, and the | |||
| params of each method must within the range of optional parameters. | |||
| Supported methods are grouped in three types: | |||
| Firstly, pixel value based transform methods include: | |||
| 'Contrast', 'Brightness', 'Blur' and 'Noise'. Secondly, affine | |||
| transform methods include: 'Translate', 'Scale', 'Shear' and | |||
| 'Rotate'. Thirdly, attack methods include: 'FGSM', 'PGD' and 'MDIIM'. | |||
| `mutate_config` must have method in the type of pixel value based | |||
| transform methods. The way of setting parameters for first and | |||
| second type methods can be seen in 'mindarmour/fuzz_testing/image_transform.py'. | |||
| For third type methods, the optional parameters refer to | |||
| The supported methods list is in `self._strategies`, and the params of each method must within the | |||
| range of optional parameters. Supported methods are grouped in three types: Firstly, pixel value based | |||
| transform methods include: 'Contrast', 'Brightness', 'Blur' and 'Noise'. Secondly, affine transform | |||
| methods include: 'Translate', 'Scale', 'Shear' and 'Rotate'. Thirdly, attack methods include: 'FGSM', | |||
| 'PGD' and 'MDIIM'. `mutate_config` must have method in the type of pixel value based transform methods. | |||
| The way of setting parameters for first and second type methods can be seen in | |||
| 'mindarmour/fuzz_testing/image_transform.py'. For third type methods, the optional parameters refer to | |||
| `self._attack_param_checklists`. | |||
| initial_seeds (list[list]): Initial seeds used to generate mutated | |||
| samples. The format of initial seeds is [[image_data, label], | |||
| [...], ...] and the label must be one-hot. | |||
| coverage_metric (str): Model coverage metric of neural networks. All | |||
| supported metrics are: 'KMNC', 'NBC', 'SNAC'. Default: 'KMNC'. | |||
| eval_metrics (Union[list, tuple, str]): Evaluation metrics. If the | |||
| type is 'auto', it will calculate all the metrics, else if the | |||
| type is list or tuple, it will calculate the metrics specified | |||
| by user. All supported evaluate methods are 'accuracy', | |||
| 'attack_success_rate', 'kmnc', 'nbc', 'snac'. Default: 'auto'. | |||
| max_iters (int): Max number of select a seed to mutate. | |||
| Default: 10000. | |||
| mutate_num_per_seed (int): The number of mutate times for a seed. | |||
| Default: 20. | |||
| initial_seeds (list[list]): Initial seeds used to generate mutated samples. The format of initial seeds is | |||
| [[image_data, label], [...], ...] and the label must be one-hot. | |||
| coverage (CoverageMetrics): Class of neuron coverage metrics. | |||
| evaluate (bool): Whether to return an evaluation report or not. Default: True. | |||
| max_iters (int): Max number of iterations for selecting a seed to mutate. Default: 10000. | |||
| mutate_num_per_seed (int): The number of mutated samples generated per seed. Default: 20. | |||
| Returns: | |||
| - list, mutated samples in fuzz_testing. | |||
| @@ -214,18 +194,18 @@ class Fuzzer: | |||
| - dict, metrics report of fuzzer. | |||
| Raises: | |||
| TypeError: If the type of `eval_metrics` is not str, list or tuple. | |||
| TypeError: If the type of metric in list `eval_metrics` is not str. | |||
| ValueError: If `eval_metrics` is not equal to 'auto' when it's type is str. | |||
| ValueError: If metric in list `eval_metrics` is not in ['accuracy', | |||
| 'attack_success_rate', 'kmnc', 'nbc', 'snac']. | |||
| ValueError: If `coverage` is not a subclass of CoverageMetrics. | |||
| ValueError: If `initial_seeds` is empty. | |||
| ValueError: If a seed in `initial_seeds` does not contain exactly two elements (image and label). | |||
| """ | |||
| # Check parameters. | |||
| eval_metrics_ = _check_eval_metrics(eval_metrics) | |||
| if coverage_metric not in ['KMNC', 'NBC', 'SNAC']: | |||
| msg = "coverage_metric must be in ['KMNC', 'NBC', 'SNAC'], but got {}.".format(coverage_metric) | |||
| if not isinstance(coverage, CoverageMetrics): | |||
| msg = 'coverage must be subclass of CoverageMetrics' | |||
| LOGGER.error(TAG, msg) | |||
| raise ValueError(msg) | |||
| evaluate = check_param_type('evaluate', evaluate, bool) | |||
| max_iters = check_int_positive('max_iters', max_iters) | |||
| mutate_num_per_seed = check_int_positive('mutate_num_per_seed', mutate_num_per_seed) | |||
| mutate_config = self._check_mutate_config(mutate_config) | |||
| @@ -235,15 +215,21 @@ class Fuzzer: | |||
| if not initial_seeds: | |||
| msg = 'initial_seeds must not be empty.' | |||
| raise ValueError(msg) | |||
| initial_samples = [] | |||
| for seed in initial_seeds: | |||
| check_param_type('seed', seed, list) | |||
| if len(seed) != 2: | |||
| msg = 'seed in initial seeds must have two element image and ' \ | |||
| 'label, but got {} element.'.format(len(seed)) | |||
| msg = 'seed in initial seeds must have two element image and label, but got {} element.'.format( | |||
| len(seed)) | |||
| raise ValueError(msg) | |||
| check_numpy_param('seed[0]', seed[0]) | |||
| check_numpy_param('seed[1]', seed[1]) | |||
| initial_samples.append(seed[0]) | |||
| seed.append(0) | |||
| initial_samples = np.array(initial_samples) | |||
| # calculate the coverage of initial seeds | |||
| pre_coverage = coverage.get_metrics(initial_samples) | |||
| gain_threshold = _gain_threshold(coverage) | |||
| seed, initial_seeds = _select_next(initial_seeds) | |||
| fuzz_samples = [] | |||
| @@ -253,30 +239,27 @@ class Fuzzer: | |||
| iter_num = 0 | |||
| while initial_seeds and iter_num < max_iters: | |||
| # Mutate a seed. | |||
| mutate_samples, mutate_strategies = self._metamorphic_mutate(seed, | |||
| mutates, | |||
| mutate_config, | |||
| mutate_samples, mutate_strategies = self._metamorphic_mutate(seed, mutates, mutate_config, | |||
| mutate_num_per_seed) | |||
| # Calculate the coverages and predictions of generated samples. | |||
| coverages, predicts = self._get_coverages_and_predict(mutate_samples, coverage_metric) | |||
| coverage_gains = _coverage_gains(coverages) | |||
| coverages, predicts = self._get_coverages_and_predict(mutate_samples, coverage) | |||
| coverage_gains, pre_coverage = _coverage_gains(pre_coverage, coverages) | |||
| for mutate, cov, pred, strategy in zip(mutate_samples, coverage_gains, predicts, mutate_strategies): | |||
| fuzz_samples.append(mutate[0]) | |||
| true_labels.append(mutate[1]) | |||
| fuzz_preds.append(pred) | |||
| fuzz_strategies.append(strategy) | |||
| # if the mutate samples has coverage gains add this samples in | |||
| # the initial_seeds to guide new mutates. | |||
| if cov > 0: | |||
| # if the mutated sample gains coverage, add it to initial_seeds to guide new mutations. | |||
| if cov > gain_threshold: | |||
| initial_seeds.append(mutate) | |||
| seed, initial_seeds = _select_next(initial_seeds) | |||
| iter_num += 1 | |||
| metrics_report = None | |||
| if eval_metrics_ is not None: | |||
| metrics_report = self._evaluate(fuzz_samples, true_labels, fuzz_preds, fuzz_strategies, eval_metrics_) | |||
| if evaluate: | |||
| metrics_report = self._evaluate(fuzz_samples, true_labels, fuzz_preds, fuzz_strategies, coverage) | |||
| return fuzz_samples, true_labels, fuzz_preds, fuzz_strategies, metrics_report | |||
| def _get_coverages_and_predict(self, mutate_samples, coverage_metric="KNMC"): | |||
| def _get_coverages_and_predict(self, mutate_samples, coverage): | |||
| """ Calculate the coverages and predictions of generated samples.""" | |||
| samples = [s[0] for s in mutate_samples] | |||
| samples = np.array(samples) | |||
| @@ -285,17 +268,10 @@ class Fuzzer: | |||
| predictions = predictions.asnumpy() | |||
| for index in range(len(samples)): | |||
| mutate = samples[:index + 1] | |||
| self._coverage_metrics.calculate_coverage(mutate.astype(np.float32)) | |||
| if coverage_metric == 'KMNC': | |||
| coverages.append(self._coverage_metrics.get_kmnc()) | |||
| if coverage_metric == 'NBC': | |||
| coverages.append(self._coverage_metrics.get_nbc()) | |||
| if coverage_metric == 'SNAC': | |||
| coverages.append(self._coverage_metrics.get_snac()) | |||
| coverages.append(coverage.get_metrics(mutate)) | |||
| return coverages, predictions | |||
| def _metamorphic_mutate(self, seed, mutates, mutate_config, | |||
| mutate_num_per_seed): | |||
| def _metamorphic_mutate(self, seed, mutates, mutate_config, mutate_num_per_seed): | |||
| """Mutate a seed using strategies random selected from mutate_config.""" | |||
| mutate_samples = [] | |||
| mutate_strategies = [] | |||
| @@ -310,8 +286,8 @@ class Fuzzer: | |||
| params = strategy['params'] | |||
| method = strategy['method'] | |||
| selected_param = {} | |||
| for p in params: | |||
| selected_param[p] = choice(params[p]) | |||
| for param in params: | |||
| selected_param[param] = choice(params[param]) | |||
| if method in list(self._pixel_value_trans_list + self._affine_trans_list): | |||
| if method == 'Shear': | |||
| @@ -367,8 +343,7 @@ class Fuzzer: | |||
| else: | |||
| for key in params.keys(): | |||
| check_param_type(str(key), params[key], list) | |||
| # Methods in `metate_config` should at least have one in the type of | |||
| # pixel value based transform methods. | |||
| # Methods in `mutate_config` should at least have one in the type of pixel value based transform methods. | |||
| if not has_pixel_trans: | |||
| msg = "mutate methods in mutate_config should at least have one in {}".format(self._pixel_value_trans_list) | |||
| raise ValueError(msg) | |||
| @@ -386,17 +361,7 @@ class Fuzzer: | |||
| check_param_type(param_name, params[param_name], list) | |||
| for param_value in params[param_name]: | |||
| if param_name == 'bounds': | |||
| bounds = check_param_multi_types('bounds', param_value, [tuple]) | |||
| if len(bounds) != 2: | |||
| msg = 'The format of bounds must be format (lower_bound, upper_bound),' \ | |||
| 'but got its length as{}'.format(len(bounds)) | |||
| raise ValueError(msg) | |||
| for bound_value in bounds: | |||
| _ = check_param_multi_types('bound', bound_value, [int, float]) | |||
| if bounds[0] >= bounds[1]: | |||
| msg = "upper bound must more than lower bound, but upper bound got {}, lower bound " \ | |||
| "got {}".format(bounds[0], bounds[1]) | |||
| raise ValueError(msg) | |||
| _ = check_param_bounds('bounds', param_value) | |||
| elif param_name == 'norm_level': | |||
| _ = check_norm_level(param_value) | |||
| else: | |||
| @@ -420,57 +385,40 @@ class Fuzzer: | |||
| mutates[method] = self._strategies[method](network, loss_fn=loss_fn) | |||
| return mutates | |||
| def _evaluate(self, fuzz_samples, true_labels, fuzz_preds, | |||
| fuzz_strategies, metrics): | |||
| def _evaluate(self, fuzz_samples, true_labels, fuzz_preds, fuzz_strategies, coverage): | |||
| """ | |||
| Evaluate generated fuzz_testing samples in three dimensions: accuracy, | |||
| attack success rate and neural coverage. | |||
| Evaluate generated fuzz_testing samples in three dimensions: accuracy, attack success rate and neural coverage. | |||
| Args: | |||
| fuzz_samples ([numpy.ndarray, list]): Generated fuzz_testing samples | |||
| according to seeds. | |||
| fuzz_samples ([numpy.ndarray, list]): Generated fuzz_testing samples according to seeds. | |||
| true_labels ([numpy.ndarray, list]): Ground truth labels of seeds. | |||
| fuzz_preds ([numpy.ndarray, list]): Predictions of generated fuzz samples. | |||
| fuzz_strategies ([numpy.ndarray, list]): Mutate strategies of fuzz samples. | |||
| metrics (Union[list, tuple, str]): evaluation metrics. | |||
| coverage (CoverageMetrics): Neuron coverage metrics class. | |||
| Returns: | |||
| dict, evaluate metrics include accuracy, attack success rate | |||
| and neural coverage. | |||
| dict, evaluate metrics include accuracy, attack success rate and neural coverage. | |||
| """ | |||
| fuzz_samples = np.array(fuzz_samples) | |||
| true_labels = np.asarray(true_labels) | |||
| fuzz_preds = np.asarray(fuzz_preds) | |||
| temp = np.argmax(true_labels, axis=1) == np.argmax(fuzz_preds, axis=1) | |||
| metrics_report = {} | |||
| if metrics == 'auto' or 'accuracy' in metrics: | |||
| if temp.any(): | |||
| acc = np.sum(temp) / np.size(temp) | |||
| else: | |||
| acc = 0 | |||
| metrics_report['Accuracy'] = acc | |||
| if metrics == 'auto' or 'attack_success_rate' in metrics: | |||
| cond = [elem in self._attacks_list for elem in fuzz_strategies] | |||
| temp = temp[cond] | |||
| if temp.any(): | |||
| attack_success_rate = 1 - np.sum(temp) / np.size(temp) | |||
| else: | |||
| attack_success_rate = None | |||
| metrics_report['Attack_success_rate'] = attack_success_rate | |||
| if metrics == 'auto' or 'kmnc' in metrics or 'nbc' in metrics or 'snac' in metrics: | |||
| self._coverage_metrics.calculate_coverage(np.array(fuzz_samples).astype(np.float32)) | |||
| if metrics == 'auto' or 'kmnc' in metrics: | |||
| kmnc = self._coverage_metrics.get_kmnc() | |||
| metrics_report['Neural_coverage_KMNC'] = kmnc | |||
| if metrics == 'auto' or 'nbc' in metrics: | |||
| nbc = self._coverage_metrics.get_nbc() | |||
| metrics_report['Neural_coverage_NBC'] = nbc | |||
| if metrics == 'auto' or 'snac' in metrics: | |||
| snac = self._coverage_metrics.get_snac() | |||
| metrics_report['Neural_coverage_SNAC'] = snac | |||
| if temp.any(): | |||
| acc = np.sum(temp) / np.size(temp) | |||
| else: | |||
| acc = 0 | |||
| metrics_report['Accuracy'] = acc | |||
| cond = [elem in self._attacks_list for elem in fuzz_strategies] | |||
| temp = temp[cond] | |||
| if temp.any(): | |||
| attack_success_rate = 1 - np.sum(temp) / np.size(temp) | |||
| else: | |||
| attack_success_rate = None | |||
| metrics_report['Attack_success_rate'] = attack_success_rate | |||
| metrics_report['Coverage_metrics'] = coverage.get_metrics(fuzz_samples) | |||
| return metrics_report | |||
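| The resulting report is a plain dict; a hypothetical example of its contents: | |||
| # e.g. {'Accuracy': 0.52, 'Attack_success_rate': 0.81, 'Coverage_metrics': 0.37} | |||
| # 'Attack_success_rate' is None when no attack-generated sample was predicted correctly | |||
| # (including the case where no attack-based mutation was applied at all). | |||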
| @@ -14,311 +14,396 @@ | |||
| """ | |||
| Model-Test Coverage Metrics. | |||
| """ | |||
| from abc import abstractmethod | |||
| from collections import defaultdict | |||
| import math | |||
| import numpy as np | |||
| from mindspore import Tensor | |||
| from mindspore import Model | |||
| from mindspore.train.summary.summary_record import _get_summary_tensor_data | |||
| from mindarmour.utils._check_param import check_model, check_numpy_param, \ | |||
| check_int_positive, check_param_multi_types | |||
| from mindarmour.utils._check_param import check_model, check_numpy_param, check_int_positive, \ | |||
| check_param_type, check_value_positive | |||
| from mindarmour.utils.logger import LogUtil | |||
| LOGGER = LogUtil.get_instance() | |||
| TAG = 'ModelCoverageMetrics' | |||
| TAG = 'CoverageMetrics' | |||
| class ModelCoverageMetrics: | |||
| class CoverageMetrics: | |||
| """ | |||
| As we all known, each neuron output of a network will have a output range | |||
| after training (we call it original range), and test dataset is used to | |||
| estimate the accuracy of the trained network. However, neurons' output | |||
| distribution would be different with different test datasets. Therefore, | |||
| similar to function fuzz, model fuzz means testing those neurons' outputs | |||
| and estimating the proportion of original range that has emerged with test | |||
| The abstract base class for neuron coverage classes calculating coverage metrics. | |||
| As is well known, each neuron output of a network will have an output range after training (we call it the | |||
| original range), and a test dataset is used to estimate the accuracy of the trained network. However, the neurons' | |||
| output distribution varies with different test datasets. Therefore, similar to function fuzzing, model fuzzing | |||
| means testing those neurons' outputs and estimating the proportion of the original range that has emerged with | |||
| the test datasets. | |||
| Reference: `DeepGauge: Multi-Granularity Testing Criteria for Deep | |||
| Learning Systems <https://arxiv.org/abs/1803.07519>`_ | |||
| Reference: `DeepGauge: Multi-Granularity Testing Criteria for Deep Learning Systems | |||
| <https://arxiv.org/abs/1803.07519>`_ | |||
| Args: | |||
| model (Model): The pre-trained model to be tested. | |||
| neuron_num (int): The number of testing neurons. | |||
| segmented_num (int): The number of segmented sections of neurons' output intervals. | |||
| train_dataset (numpy.ndarray): Training dataset used for determine | |||
| the neurons' output boundaries. | |||
| Raises: | |||
| ValueError: If neuron_num is too big (for example, bigger than 1e+9). | |||
| Examples: | |||
| >>> net = LeNet5() | |||
| >>> train_images = np.random.random((10000, 1, 32, 32)).astype(np.float32) | |||
| >>> test_images = np.random.random((5000, 1, 32, 32)).astype(np.float32) | |||
| >>> model = Model(net) | |||
| >>> neuron_num = 10 | |||
| >>> segmented_num = 1000 | |||
| >>> model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images) | |||
| >>> model_fuzz_test.calculate_coverage(test_images) | |||
| >>> print('KMNC of this test is : %s', model_fuzz_test.get_kmnc()) | |||
| >>> print('NBC of this test is : %s', model_fuzz_test.get_nbc()) | |||
| >>> print('SNAC of this test is : %s', model_fuzz_test.get_snac()) | |||
| >>> model_fuzz_test.calculate_effective_coverage(test_images, top_k, threshold) | |||
| >>> print('NC of this test is : %s', model_fuzz_test.get_nc()) | |||
| >>> print('Effective_NC of this test is : %s', model_fuzz_test.get_effective_nc()) | |||
| incremental (bool): Whether metrics will be calculated incrementally or not. Default: False. | |||
| batch_size (int): The number of samples in a fuzz test batch. Default: 32. | |||
| """ | |||
| def __init__(self, model, neuron_num, segmented_num, train_dataset): | |||
| def __init__(self, model, incremental=False, batch_size=32): | |||
| self._model = check_model('model', model, Model) | |||
| self._segmented_num = check_int_positive('segmented_num', segmented_num) | |||
| self._neuron_num = check_int_positive('neuron_num', neuron_num) | |||
| if self._neuron_num > 1e+9: | |||
| msg = 'neuron_num should be less than 1e+10, otherwise a MemoryError would occur' | |||
| LOGGER.error(TAG, msg) | |||
| raise ValueError(msg) | |||
| train_dataset = check_numpy_param('train_dataset', train_dataset) | |||
| self._lower_bounds = [np.inf]*self._neuron_num | |||
| self._upper_bounds = [-np.inf]*self._neuron_num | |||
| self._var = [0]*self._neuron_num | |||
| self._main_section_hits = [[0 for _ in range(self._segmented_num)] for _ in range(self._neuron_num)] | |||
| self._lower_corner_hits = [0]*self._neuron_num | |||
| self._upper_corner_hits = [0]*self._neuron_num | |||
| self._bounds_get(train_dataset) | |||
| self._model_layer_dict = defaultdict(bool) | |||
| self._effective_model_layer_dict = defaultdict(bool) | |||
| def _set_init_effective_coverage_table(self, dataset): | |||
| self.incremental = check_param_type('incremental', incremental, bool) | |||
| self.batch_size = check_int_positive('batch_size', batch_size) | |||
| self._activate_table = defaultdict(list) | |||
| @abstractmethod | |||
| def get_metrics(self, dataset): | |||
| """ | |||
| Initialise the coverage table of each neuron in the model. | |||
| Calculate coverage metrics of given dataset. | |||
| Args: | |||
| dataset (numpy.ndarray): Dataset used for initialising the coverage table. | |||
| dataset (numpy.ndarray): Dataset used to calculate coverage metrics. | |||
| Raises: | |||
| NotImplementedError: It is an abstract method. | |||
| """ | |||
| self._model.predict(Tensor(dataset[0:1])) | |||
| tensors = _get_summary_tensor_data() | |||
| for name, tensor in tensors.items(): | |||
| if 'input' in name: | |||
| continue | |||
| for num_neuron in range(tensor.shape[1]): | |||
| self._model_layer_dict[(name, num_neuron)] = False | |||
| self._effective_model_layer_dict[(name, num_neuron)] = False | |||
| def _bounds_get(self, train_dataset, batch_size=32): | |||
| msg = 'The function get_metrics() is an abstract method in class `CoverageMetrics`, and should be' \ | |||
| ' implemented in child class.' | |||
| LOGGER.error(TAG, msg) | |||
| raise NotImplementedError(msg) | |||
| def _init_neuron_activate_table(self, data): | |||
| """ | |||
| Update the lower and upper boundaries of neurons' outputs. | |||
| Initialise the activate table of each neuron in the model with format: | |||
| {'layer1': [n1, n2, n3, ..., nn], 'layer2': [n1, n2, n3, ..., nn], ...} | |||
| Args: | |||
| train_dataset (numpy.ndarray): Training dataset used for | |||
| determine the neurons' output boundaries. | |||
| batch_size (int): The number of samples in a predict batch. | |||
| Default: 32. | |||
| data (numpy.ndarray): Data used for initialising the activate table. | |||
| Return: | |||
| dict, the initialized activate table. | |||
| """ | |||
| batch_size = check_int_positive('batch_size', batch_size) | |||
| output_mat = [] | |||
| batches = train_dataset.shape[0] // batch_size | |||
| for i in range(batches): | |||
| inputs = train_dataset[i*batch_size: (i + 1)*batch_size] | |||
| output = self._model.predict(Tensor(inputs)).asnumpy() | |||
| output_mat.append(output) | |||
| lower_compare_array = np.concatenate([output, np.array([self._lower_bounds])], axis=0) | |||
| self._lower_bounds = np.min(lower_compare_array, axis=0) | |||
| upper_compare_array = np.concatenate([output, np.array([self._upper_bounds])], axis=0) | |||
| self._upper_bounds = np.max(upper_compare_array, axis=0) | |||
| if batches == 0: | |||
| output = self._model.predict(Tensor(train_dataset)).asnumpy() | |||
| self._lower_bounds = np.min(output, axis=0) | |||
| self._upper_bounds = np.max(output, axis=0) | |||
| output_mat.append(output) | |||
| self._var = np.std(np.concatenate(np.array(output_mat), axis=0), axis=0) | |||
| def _sections_hits_count(self, dataset, intervals): | |||
| self._model.predict(Tensor(data)) | |||
| layer_out = _get_summary_tensor_data() | |||
| if not layer_out: | |||
| msg = 'User must use the TensorSummary() operation to specify the middle layers of the model that participate ' \ | |||
| 'in the coverage calculation.' | |||
| LOGGER.error(TAG, msg) | |||
| raise ValueError(msg) | |||
| activate_table = defaultdict() | |||
| for layer, value in layer_out.items(): | |||
| activate_table[layer] = np.zeros(value.shape[1], np.bool) | |||
| return activate_table | |||
| def _get_bounds(self, train_dataset): | |||
| """ | |||
| Update the coverage matrix of neurons' output subsections. | |||
| Update the lower and upper boundaries of neurons' outputs. | |||
| Args: | |||
| dataset (numpy.ndarray): Testing data. | |||
| intervals (list[float]): Segmentation intervals of neurons' outputs. | |||
| train_dataset (numpy.ndarray): Training dataset used to determine the neurons' output boundaries. | |||
| Return: | |||
| - numpy.ndarray, upper bounds of neurons' outputs. | |||
| - numpy.ndarray, lower bounds of neurons' outputs. | |||
| """ | |||
| dataset = check_numpy_param('dataset', dataset) | |||
| batch_output = self._model.predict(Tensor(dataset)).asnumpy() | |||
| batch_section_indexes = (batch_output - self._lower_bounds) // intervals | |||
| for section_indexes in batch_section_indexes: | |||
| for i in range(self._neuron_num): | |||
| if section_indexes[i] < 0: | |||
| self._lower_corner_hits[i] = 1 | |||
| elif section_indexes[i] >= self._segmented_num: | |||
| self._upper_corner_hits[i] = 1 | |||
| upper_bounds = defaultdict(list) | |||
| lower_bounds = defaultdict(list) | |||
| batches = math.ceil(train_dataset.shape[0] / self.batch_size) | |||
| for i in range(batches): | |||
| inputs = train_dataset[i * self.batch_size: (i + 1) * self.batch_size] | |||
| self._model.predict(Tensor(inputs)) | |||
| layer_out = _get_summary_tensor_data() | |||
| for layer, tensor in layer_out.items(): | |||
| value = tensor.asnumpy() | |||
| value = np.mean(value, axis=tuple([i for i in range(2, len(value.shape))])) | |||
| min_value = np.min(value, axis=0) | |||
| max_value = np.max(value, axis=0) | |||
| if np.any(upper_bounds[layer]): | |||
| max_flag = upper_bounds[layer] > max_value | |||
| min_flag = lower_bounds[layer] < min_value | |||
| upper_bounds[layer] = upper_bounds[layer] * max_flag + max_value * (1 - max_flag) | |||
| lower_bounds[layer] = lower_bounds[layer] * min_flag + min_value * (1 - min_flag) | |||
| else: | |||
| self._main_section_hits[i][int(section_indexes[i])] = 1 | |||
| upper_bounds[layer] = max_value | |||
| lower_bounds[layer] = min_value | |||
| return upper_bounds, lower_bounds | |||
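| The flag arithmetic above is an element-wise running maximum/minimum over batches; a small numeric check: | |||
| upper = np.array([0.5, 1.2]) | |||
| max_value = np.array([0.8, 1.0]) | |||
| max_flag = upper > max_value                          # [False  True] | |||
| print(upper * max_flag + max_value * (1 - max_flag))  # [0.8 1.2], i.e. np.maximum(upper, max_value) | |||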
| def _coverage_update(self, name, tensor, scaled_mean, scaled_rank, top_k, threshold): | |||
| def _activate_rate(self): | |||
| """ | |||
| Calculate the activate rate of neurons. | |||
| """ | |||
| Update the coverage matrix of neural coverage and effective neural coverage. | |||
| total_neurons = 0 | |||
| activated_neurons = 0 | |||
| for _, value in self._activate_table.items(): | |||
| activated_neurons += np.sum(value) | |||
| total_neurons += len(value) | |||
| activate_rate = activated_neurons / total_neurons | |||
| Args: | |||
| name (string): the name of the tensor. | |||
| tensor (tensor): the tensor in the network. | |||
| scaled_mean (numpy.ndarray): feature map of the tensor. | |||
| scaled_rank (numpy.ndarray): rank of tensor value. | |||
| top_k (int): neuron is covered when its output has the top k largest value in that hidden layer. | |||
| threshold (float): neuron is covered when its output is greater than the threshold. | |||
| return activate_rate | |||
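| A worked example of the rate: two recorded layers with 4 and 6 neurons, of which 1 and 2 are activated respectively (layer names are illustrative): | |||
| table = {'conv1': np.array([True, False, False, False]), | |||
| 'fc1': np.array([True, True, False, False, False, False])} | |||
| rate = sum(np.sum(v) for v in table.values()) / sum(len(v) for v in table.values()) | |||
| print(rate)  # (1 + 2) / (4 + 6) = 0.3 | |||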
| class NeuronCoverage(CoverageMetrics): | |||
| """ | |||
| Calculate the neuron activation coverage. A neuron is activated when its output is greater than the threshold. | |||
| Neuron coverage equals the proportion of activated neurons to total neurons in the network. | |||
| Args: | |||
| model (Model): The pre-trained model to be tested. | |||
| threshold (float): Threshold used to determine whether a neuron is activated or not. Default: 0.1. | |||
| incremental (bool): Whether metrics will be calculated incrementally or not. Default: False. | |||
| batch_size (int): The number of samples in a fuzz test batch. Default: 32. | |||
| """ | |||
| def __init__(self, model, threshold=0.1, incremental=False, batch_size=32): | |||
| super(NeuronCoverage, self).__init__(model, incremental, batch_size) | |||
| self.threshold = check_value_positive('threshold', threshold) | |||
| def get_metrics(self, dataset): | |||
| """ | |||
| for num_neuron in range(tensor.shape[1]): | |||
| if num_neuron >= (len(scaled_rank) - top_k) and not \ | |||
| self._effective_model_layer_dict[(name, scaled_rank[num_neuron])]: | |||
| self._effective_model_layer_dict[(name, scaled_rank[num_neuron])] = True | |||
| if scaled_mean[num_neuron] > threshold and not \ | |||
| self._model_layer_dict[(name, num_neuron)]: | |||
| self._model_layer_dict[(name, num_neuron)] = True | |||
| def calculate_coverage(self, dataset, bias_coefficient=0, batch_size=32): | |||
| """ | |||
| Calculate the testing adequacy of the given dataset. | |||
| Get the metric of neuron coverage: the proportion of activated neurons to total neurons in the network. | |||
| Args: | |||
| dataset (numpy.ndarray): Data for fuzz test. | |||
| bias_coefficient (Union[int, float]): The coefficient used | |||
| for changing the neurons' output boundaries. Default: 0. | |||
| batch_size (int): The number of samples in a predict batch. Default: 32. | |||
| dataset (numpy.ndarray): Dataset used to calculate coverage metrics. | |||
| Returns: | |||
| float, the metric of 'neuron coverage'. | |||
| Examples: | |||
| >>> neuron_num = 10 | |||
| >>> segmented_num = 1000 | |||
| >>> batch_size = 32 | |||
| >>> model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images) | |||
| >>> model_fuzz_test.calculate_coverage(test_images, top_k, threshold, batch_size) | |||
| >>> nc = NeuronCoverage(model, threshold=0.1) | |||
| >>> nc_metrics = nc.get_metrics(test_data) | |||
| """ | |||
| dataset = check_numpy_param('dataset', dataset) | |||
| batch_size = check_int_positive('batch_size', batch_size) | |||
| bias_coefficient = check_param_multi_types('bias_coefficient', bias_coefficient, [int, float]) | |||
| self._lower_bounds -= bias_coefficient*self._var | |||
| self._upper_bounds += bias_coefficient*self._var | |||
| intervals = (self._upper_bounds - self._lower_bounds) / self._segmented_num | |||
| batches = dataset.shape[0] // batch_size | |||
| batches = math.ceil(dataset.shape[0] / self.batch_size) | |||
| if not self.incremental or not self._activate_table: | |||
| self._activate_table = self._init_neuron_activate_table(dataset[0:1]) | |||
| for i in range(batches): | |||
| self._sections_hits_count(dataset[i*batch_size: (i + 1)*batch_size], intervals) | |||
| inputs = dataset[i * self.batch_size: (i + 1) * self.batch_size] | |||
| self._model.predict(Tensor(inputs)) | |||
| layer_out = _get_summary_tensor_data() | |||
| for layer, tensor in layer_out.items(): | |||
| value = tensor.asnumpy() | |||
| value = np.mean(value, axis=tuple([i for i in range(2, len(value.shape))])) | |||
| activate = np.sum(value > self.threshold, axis=0) > 0 | |||
| self._activate_table[layer] = np.logical_or(self._activate_table[layer], activate) | |||
| neuron_coverage = self._activate_rate() | |||
| return neuron_coverage | |||
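| A toy check of the activation test above with threshold 0.1: a neuron counts as activated if any sample in the batch drives it over the threshold. | |||
| value = np.array([[0.05, 0.30], | |||
| [0.08, 0.02]])                          # 2 samples x 2 neurons | |||
| print(np.sum(value > 0.1, axis=0) > 0)  # [False  True] | |||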
| class TopKNeuronCoverage(CoverageMetrics): | |||
| """ | |||
| Calculate the top-k activated neuron coverage. A neuron is activated when its output has one of the top k largest | |||
| values in its hidden layer. Top-k neuron coverage equals the proportion of activated neurons to total neurons in | |||
| the network. | |||
| Args: | |||
| model (Model): The pre-trained model to be tested. | |||
| top_k (int): A neuron is activated when its output has one of the top k largest values in its hidden layer. Default: 3. | |||
| incremental (bool): Whether metrics will be calculated incrementally or not. Default: False. | |||
| batch_size (int): The number of samples in a fuzz test batch. Default: 32. | |||
| """ | |||
| def __init__(self, model, top_k=3, incremental=False, batch_size=32): | |||
| super(TopKNeuronCoverage, self).__init__(model, incremental=incremental, batch_size=batch_size) | |||
| self.top_k = check_int_positive('top_k', top_k) | |||
| def calculate_effective_coverage(self, dataset, top_k=3, threshold=0.1, batch_size=32): | |||
| def get_metrics(self, dataset): | |||
| """ | |||
| Calculate the effective testing adequacy of the given dataset. | |||
| In effective neural coverage, neuron is covered when its output has the top k largest value | |||
| in that hidden layers. In neural coverage, neuron is covered when its output is greater than the | |||
| threshold. Coverage equals the covered neurons divided by the total neurons in the network. | |||
| Get the metric of Top K activated neuron coverage. | |||
| Args: | |||
| threshold (float): neuron is covered when its output is greater than the threshold. | |||
| top_k (int): neuron is covered when its output has the top k largest value in that hiddern layer. | |||
| dataset (numpy.ndarray): Data for fuzz test. | |||
| dataset (numpy.ndarray): Dataset used to calculate coverage metrics. | |||
| Returns: | |||
| float, the metrics of 'top k neuron coverage'. | |||
| Examples: | |||
| >>> neuron_num = 10 | |||
| >>> segmented_num = 1000 | |||
| >>> top_k = 3 | |||
| >>> threshold = 0.1 | |||
| >>> model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images) | |||
| >>> model_fuzz_test.calculate_coverage(test_images) | |||
| >>> model_fuzz_test.calculate_effective_coverage(test_images, top_k, threshold) | |||
| >>> tknc = TopKNeuronCoverage(model, top_k=3) | |||
| >>> metrics = tknc.get_metrics(test_data) | |||
| """ | |||
| top_k = check_int_positive('top_k', top_k) | |||
| dataset = check_numpy_param('dataset', dataset) | |||
| batch_size = check_int_positive('batch_size', batch_size) | |||
| batches = dataset.shape[0] // batch_size | |||
| self._set_init_effective_coverage_table(dataset) | |||
| batches = math.ceil(dataset.shape[0] / self.batch_size) | |||
| if not self.incremental or not self._activate_table: | |||
| self._activate_table = self._init_neuron_activate_table(dataset[0:1]) | |||
| for i in range(batches): | |||
| inputs = dataset[i*batch_size: (i + 1)*batch_size] | |||
| self._model.predict(Tensor(inputs)).asnumpy() | |||
| tensors = _get_summary_tensor_data() | |||
| for name, tensor in tensors.items(): | |||
| if 'input' in name: | |||
| continue | |||
| scaled = tensor.asnumpy()[-1] | |||
| if scaled.ndim >= 3: # | |||
| scaled_mean = np.mean(scaled, axis=(1, 2)) | |||
| scaled_rank = np.argsort(scaled_mean) | |||
| self._coverage_update(name, tensor, scaled_mean, scaled_rank, top_k, threshold) | |||
| else: | |||
| scaled_rank = np.argsort(scaled) | |||
| self._coverage_update(name, tensor, scaled, scaled_rank, top_k, threshold) | |||
| inputs = dataset[i * self.batch_size: (i + 1) * self.batch_size] | |||
| self._model.predict(Tensor(inputs)) | |||
| layer_out = _get_summary_tensor_data() | |||
| for layer, tensor in layer_out.items(): | |||
| value = tensor.asnumpy() | |||
| if len(value.shape) > 2: | |||
| value = np.mean(value, axis=tuple([i for i in range(2, len(value.shape))])) | |||
| top_k_value = np.sort(value)[:, -self.top_k].reshape(value.shape[0], 1) | |||
| top_k_value = np.sum((value - top_k_value) >= 0, axis=0) > 0 | |||
| self._activate_table[layer] = np.logical_or(self._activate_table[layer], top_k_value) | |||
| top_k_neuron_coverage = self._activate_rate() | |||
| return top_k_neuron_coverage | |||
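| A worked example of the top-k test above, with top_k = 2: | |||
| value = np.array([[0.9, 0.1, 0.5], | |||
| [0.2, 0.8, 0.7]])                                      # 2 samples x 3 neurons | |||
| top_k_value = np.sort(value)[:, -2].reshape(2, 1)      # 2nd largest value per sample | |||
| print(np.sum((value - top_k_value) >= 0, axis=0) > 0)  # [ True  True  True ] | |||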
| class SuperNeuronActivateCoverage(CoverageMetrics): | |||
| """ | |||
| Get the metric of 'super neuron activation coverage'. :math:`SNAC = |UpperCornerNeuron|/|N|`. SNAC refers to the | |||
| proportion of neurons whose output values in the test set exceed the upper bounds of the corresponding neurons' | |||
| output values in the training set. | |||
| def get_nc(self): | |||
| Args: | |||
| model (Model): The pre-trained model to be tested. | |||
| train_dataset (numpy.ndarray): Training dataset used to determine the neurons' output boundaries. | |||
| incremental (bool): Whether metrics will be calculated incrementally or not. Default: False. | |||
| batch_size (int): The number of samples in a fuzz test batch. Default: 32. | |||
| """ | |||
| def __init__(self, model, train_dataset, incremental=False, batch_size=32): | |||
| super(SuperNeuronActivateCoverage, self).__init__(model, incremental=incremental, batch_size=batch_size) | |||
| train_dataset = check_numpy_param('train_dataset', train_dataset) | |||
| self.upper_bounds, self.lower_bounds = self._get_bounds(train_dataset=train_dataset) | |||
| def get_metrics(self, dataset): | |||
| """ | |||
| Get the metric of 'neuron coverage'. | |||
| Get the metric of 'strong neuron activation coverage'. | |||
| Args: | |||
| dataset (numpy.ndarray): Dataset used to calculate coverage metrics. | |||
| Returns: | |||
| float, the metric of 'neuron coverage'. | |||
| float, the metric of 'strong neuron activation coverage'. | |||
| Examples: | |||
| >>> model_fuzz_test.get_nc() | |||
| >>> snac = SuperNeuronActivateCoverage(model, train_dataset) | |||
| >>> metrics = snac.get_metrics(test_data) | |||
| """ | |||
| covered_neurons = len([v for v in self._model_layer_dict.values() if v]) | |||
| total_neurons = len(self._model_layer_dict) | |||
| nc = covered_neurons / float(total_neurons) | |||
| return nc | |||
| dataset = check_numpy_param('dataset', dataset) | |||
| if not self.incremental or not self._activate_table: | |||
| self._activate_table = self._init_neuron_activate_table(dataset[0:1]) | |||
| batches = math.ceil(dataset.shape[0] / self.batch_size) | |||
| def get_effective_nc(self): | |||
| """ | |||
| Get the metric of 'effective neuron coverage'. | |||
| for i in range(batches): | |||
| inputs = dataset[i * self.batch_size: (i + 1) * self.batch_size] | |||
| self._model.predict(Tensor(inputs)) | |||
| layer_out = _get_summary_tensor_data() | |||
| for layer, tensor in layer_out.items(): | |||
| value = tensor.asnumpy() | |||
| if len(value.shape) > 2: | |||
| value = np.mean(value, axis=tuple([i for i in range(2, len(value.shape))])) | |||
| activate = np.sum(value > self.upper_bounds[layer], axis=0) > 0 | |||
| self._activate_table[layer] = np.logical_or(self._activate_table[layer], activate) | |||
| snac = self._activate_rate() | |||
| return snac | |||
| Returns: | |||
| float, the metric of 'the effective neuron coverage'. | |||
| Examples: | |||
| >>> model_fuzz_test.get_effective_nc() | |||
| """ | |||
| covered_neurons = len([v for v in self._effective_model_layer_dict.values() if v]) | |||
| total_neurons = len(self._effective_model_layer_dict) | |||
| effective_nc = covered_neurons / float(total_neurons) | |||
| return effective_nc | |||
| class NeuronBoundsCoverage(SuperNeuronActivateCoverage): | |||
| """ | |||
| Get the metric of 'neuron boundary coverage' :math:`NBC = (|UpperCornerNeuron| + |LowerCornerNeuron|)/(2*|N|)`, | |||
| where :math:`|N|` is the number of neurons. NBC refers to the proportion of neurons whose output values in the | |||
| test dataset exceed the upper and lower bounds of the corresponding neurons' output values in the training | |||
| dataset. | |||
| def get_kmnc(self): | |||
| """ | |||
| Get the metric of 'k-multisection neuron coverage'. KMNC measures how | |||
| thoroughly the given set of test inputs covers the range of neurons | |||
| output values derived from training dataset. | |||
| Args: | |||
| model (Model): The pre-trained model to be tested. | |||
| train_dataset (numpy.ndarray): Training dataset used to determine the neurons' output boundaries. | |||
| incremental (bool): Whether metrics will be calculated incrementally or not. Default: False. | |||
| batch_size (int): The number of samples in a fuzz test batch. Default: 32. | |||
| """ | |||
| Returns: | |||
| float, the metric of 'k-multisection neuron coverage'. | |||
| def __init__(self, model, train_dataset, incremental=False, batch_size=32): | |||
| super(NeuronBoundsCoverage, self).__init__(model, train_dataset, incremental=incremental, batch_size=batch_size) | |||
| Examples: | |||
| >>> model_fuzz_test.get_kmnc() | |||
| def get_metrics(self, dataset): | |||
| """ | |||
| kmnc = np.sum(self._main_section_hits) / (self._neuron_num*self._segmented_num) | |||
| return kmnc | |||
| Get the metric of 'neuron boundary coverage'. | |||
| def get_nbc(self): | |||
| """ | |||
| Get the metric of 'neuron boundary coverage' :math:`NBC = (|UpperCornerNeuron| | |||
| + |LowerCornerNeuron|)/(2*|N|)`, where :math`|N|` is the number of neurons, | |||
| NBC refers to the proportion of neurons whose neurons output value in | |||
| the test dataset exceeds the upper and lower bounds of the corresponding | |||
| neurons output value in the training dataset. | |||
| Args: | |||
| dataset (numpy.ndarray): Dataset used to calculate coverage metrics. | |||
| Returns: | |||
| float, the metric of 'neuron boundary coverage'. | |||
| Examples: | |||
| >>> model_fuzz_test.get_nbc() | |||
| >>> nbc = NeuronBoundsCoverage(model, train_dataset) | |||
| >>> metrics = nbc.get_metrics(test_data) | |||
| """ | |||
| nbc = (np.sum(self._lower_corner_hits) + np.sum(self._upper_corner_hits)) / (2*self._neuron_num) | |||
| dataset = check_numpy_param('dataset', dataset) | |||
| if not self.incremental or not self._activate_table: | |||
| self._activate_table = self._init_neuron_activate_table(dataset[0:1]) | |||
| batches = math.ceil(dataset.shape[0] / self.batch_size) | |||
| for i in range(batches): | |||
| inputs = dataset[i * self.batch_size: (i + 1) * self.batch_size] | |||
| self._model.predict(Tensor(inputs)) | |||
| layer_out = _get_summary_tensor_data() | |||
| for layer, tensor in layer_out.items(): | |||
| value = tensor.asnumpy() | |||
| if len(value.shape) > 2: | |||
| value = np.mean(value, axis=tuple([i for i in range(2, len(value.shape))])) | |||
| outer = np.logical_or(value > self.upper_bounds[layer], value < self.lower_bounds[layer]) | |||
| activate = np.sum(outer, axis=0) > 0 | |||
| self._activate_table[layer] = np.logical_or(self._activate_table[layer], activate) | |||
| nbc = self._activate_rate() | |||
| return nbc | |||
| def get_snac(self): | |||
| class KMultisectionNeuronCoverage(SuperNeuronActivateCoverage): | |||
| """ | |||
| Get the metric of 'k-multisection neuron coverage'. KMNC measures how thoroughly the given set of test inputs | |||
| covers the range of neurons' output values derived from the training dataset. | |||
| Args: | |||
| model (Model): The pre-trained model to be tested. | |||
| train_dataset (numpy.ndarray): Training dataset used to determine the neurons' output boundaries. | |||
| segmented_num (int): The number of segmented sections of neurons' output intervals. Default: 100. | |||
| incremental (bool): Whether metrics will be calculated incrementally or not. Default: False. | |||
| batch_size (int): The number of samples in a fuzz test batch. Default: 32. | |||
| """ | |||
| def __init__(self, model, train_dataset, segmented_num=100, incremental=False, batch_size=32): | |||
| super(KMultisectionNeuronCoverage, self).__init__(model, train_dataset, incremental=incremental, | |||
| batch_size=batch_size) | |||
| self.segmented_num = check_int_positive('segmented_num', segmented_num) | |||
| self.intervals = defaultdict(list) | |||
| for keys in self.upper_bounds.keys(): | |||
| self.intervals[keys] = (self.upper_bounds[keys] - self.lower_bounds[keys]) / self.segmented_num | |||
| def _init_k_multisection_table(self, data): | |||
| """ Initial the activate table.""" | |||
| self._model.predict(Tensor(data)) | |||
| layer_out = _get_summary_tensor_data() | |||
| activate_section_table = defaultdict() | |||
| for layer, value in layer_out.items(): | |||
| activate_section_table[layer] = np.zeros((value.shape[1], self.segmented_num), np.bool) | |||
| return activate_section_table | |||
    def get_metrics(self, dataset):
        """
        Get the metric of 'k-multisection neuron coverage': the proportion of
        the k sections of each neuron's training-time output range that are
        hit by the test dataset, averaged over all neurons.

        Args:
            dataset (numpy.ndarray): Dataset used to calculate coverage metrics.

        Returns:
            float, the metric of 'k-multisection neuron coverage'.

        Examples:
            >>> kmnc = KMultisectionNeuronCoverage(model, train_dataset, segmented_num=100)
            >>> metrics = kmnc.get_metrics(test_data)
        """
        dataset = check_numpy_param('dataset', dataset)
        if not self.incremental or not self._activate_table:
            self._activate_table = self._init_k_multisection_table(dataset[0:1])
        batches = math.ceil(dataset.shape[0] / self.batch_size)
        for i in range(batches):
            inputs = dataset[i * self.batch_size: (i + 1) * self.batch_size]
            self._model.predict(Tensor(inputs))
            layer_out = _get_summary_tensor_data()
            for layer, tensor in layer_out.items():
                value = tensor.asnumpy()
                value = np.mean(value, axis=tuple([i for i in range(2, len(value.shape))]))
                # Map each output value to the index of the section it falls into.
                hits = np.floor((value - self.lower_bounds[layer]) / self.intervals[layer]).astype(int)
                hits = np.transpose(hits, [1, 0])
                for n in range(len(hits)):
                    for sec in hits[n]:
                        # Values outside the training-time range hit no section.
                        if sec >= self.segmented_num or sec < 0:
                            continue
                        self._activate_table[layer][n][sec] = True
        kmnc = self._activate_rate() / self.segmented_num
        return kmnc
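
# --- Editor's sketch (not part of the library): the KMNC bookkeeping above,
# condensed to plain numpy for a single layer. Bounds, k and activations are
# made up; only the section arithmetic mirrors get_metrics().
import numpy as np

k = 4
lower, upper = np.array([0.0, 0.0]), np.array([2.0, 4.0])  # two neurons
interval = (upper - lower) / k                             # section widths: 0.5 and 1.0
outputs = np.array([[0.1, 3.9],                            # per-sample neuron outputs
                    [0.7, 0.2],
                    [1.9, 2.5]])

table = np.zeros((2, k), dtype=bool)                       # neuron x section hit table
sections = np.floor((outputs - lower) / interval).astype(int)
for sample in sections:
    for n, sec in enumerate(sample):
        if 0 <= sec < k:                                   # out-of-range values hit no section
            table[n, sec] = True
kmnc = table.sum() / (k * table.shape[0])
print(kmnc)  # neuron 0 hits sections {0,1,3}, neuron 1 hits {0,2,3} -> 6/8 = 0.75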
@@ -39,9 +39,7 @@ def _check_array_not_empty(arg_name, arg_value):

def check_param_type(arg_name, arg_value, valid_type):
    """Check parameter type."""
    if not isinstance(arg_value, valid_type):
        msg = '{} must be {}, but got {}'.format(arg_name, valid_type, type(arg_value).__name__)
        LOGGER.error(TAG, msg)
        raise TypeError(msg)
@@ -51,8 +49,7 @@ def check_param_type(arg_name, arg_value, valid_type):

def check_param_multi_types(arg_name, arg_value, valid_types):
    """Check that the parameter matches one of several valid types."""
    if not isinstance(arg_value, tuple(valid_types)):
        msg = 'type of {} must be in {}, but got {}'.format(arg_name, valid_types, type(arg_value).__name__)
        LOGGER.error(TAG, msg)
        raise TypeError(msg)
@@ -68,8 +65,7 @@ def check_int_positive(arg_name, arg_value):
        raise ValueError(msg)
    arg_value = check_param_type(arg_name, arg_value, int)
    if arg_value <= 0:
        msg = '{} must be greater than 0, but got {}'.format(arg_name, arg_value)
        LOGGER.error(TAG, msg)
        raise ValueError(msg)
    return arg_value
@@ -79,8 +75,7 @@ def check_value_non_negative(arg_name, arg_value):
    """Check non-negative value."""
    arg_value = check_param_multi_types(arg_name, arg_value, (int, float))
    if float(arg_value) < 0.0:
        msg = '{} must not be less than 0, but got {}'.format(arg_name, arg_value)
        LOGGER.error(TAG, msg)
        raise ValueError(msg)
    return arg_value
@@ -90,8 +85,7 @@ def check_value_positive(arg_name, arg_value):
    """Check positive value."""
    arg_value = check_param_multi_types(arg_name, arg_value, (int, float))
    if float(arg_value) <= 0.0:
        msg = '{} must be greater than zero, but got {}'.format(arg_name, arg_value)
        LOGGER.error(TAG, msg)
        raise ValueError(msg)
    return arg_value
@@ -102,10 +96,7 @@ def check_param_in_range(arg_name, arg_value, lower, upper):
    """
    Check that the parameter lies strictly inside the given range.
    """
    if arg_value <= lower or arg_value >= upper:
        msg = '{} must be between {} and {}, but got {}'.format(arg_name, lower, upper, arg_value)
        LOGGER.error(TAG, msg)
        raise ValueError(msg)
@@ -129,10 +120,7 @@ def check_model(model_name, model, model_type):
    """
    if isinstance(model, model_type):
        return model
    msg = '{} should be an instance of {}, but got {}'.format(model_name, model_type, type(model).__name__)
    LOGGER.error(TAG, msg)
    raise TypeError(msg)
@@ -175,11 +163,9 @@ def check_pair_numpy_param(inputs_name, inputs, labels_name, labels):
        labels (numpy.ndarray): Labels of `inputs`.

    Returns:
        - numpy.ndarray, if the first dimension of `inputs` equals that of `labels`, return `inputs` as numpy.ndarray.

        - numpy.ndarray, if the first dimension of `inputs` equals that of `labels`, return `labels` as numpy.ndarray.

    Raises:
        ValueError: If inputs.shape[0] is not equal to labels.shape[0].

@@ -188,8 +174,7 @@ def check_pair_numpy_param(inputs_name, inputs, labels_name, labels):
    labels = check_numpy_param(labels_name, labels)
    if inputs.shape[0] != labels.shape[0]:
        msg = '{} shape[0] must equal {} shape[0], but got shape of ' \
              'inputs {}, shape of labels {}'.format(inputs_name, labels_name, inputs.shape, labels.shape)
        LOGGER.error(TAG, msg)
        raise ValueError(msg)
    return inputs, labels
@@ -198,10 +183,8 @@ def check_pair_numpy_param(inputs_name, inputs, labels_name, labels):

def check_equal_length(para_name1, value1, para_name2, value2):
    """Check whether the two parameters have equal length."""
    if len(value1) != len(value2):
        msg = 'The length of {0} must equal that of {1}, but got {0} is {2}, {1} is {3}'\
            .format(para_name1, para_name2, len(value1), len(value2))
        LOGGER.error(TAG, msg)
        raise ValueError(msg)
    return value1, value2
@@ -210,10 +193,8 @@ def check_equal_length(para_name1, value1, para_name2, value2):

def check_equal_shape(para_name1, value1, para_name2, value2):
    """Check whether the two parameters have equal shape."""
    if value1.shape != value2.shape:
        msg = 'The shape of {0} must equal that of {1}, but got {0} is {2}, {1} is {3}'.\
            format(para_name1, para_name2, value1.shape, value2.shape)
        LOGGER.error(TAG, msg)
        raise ValueError(msg)
    return value1, value2
@@ -225,8 +206,7 @@ def check_norm_level(norm_level):
    msg = 'Type of norm_level must be in [int, str], but got {}'.format(type(norm_level))
    accept_norm = [1, 2, '1', '2', 'l1', 'l2', 'inf', 'linf', np.inf]
    if norm_level not in accept_norm:
        msg = 'norm_level must be in {}, but got {}'.format(accept_norm, norm_level)
        LOGGER.error(TAG, msg)
        raise ValueError(msg)
    return norm_level
@@ -252,20 +232,16 @@ def normalize_value(value, norm_level):
    value_reshape = value.reshape((value.shape[0], -1))
    avoid_zero_div = 1e-12
    if norm_level in (1, '1', 'l1'):
        norm = np.linalg.norm(value_reshape, ord=1, axis=1, keepdims=True) + avoid_zero_div
        norm_value = value_reshape / norm
    elif norm_level in (2, '2', 'l2'):
        norm = np.linalg.norm(value_reshape, ord=2, axis=1, keepdims=True) + avoid_zero_div
        norm_value = value_reshape / norm
    elif norm_level in (np.inf, 'inf'):
        norm = np.max(abs(value_reshape), axis=1, keepdims=True) + avoid_zero_div
        norm_value = value_reshape / norm
    else:
        msg = 'Values of `norm_level` different from 1, 2 and `np.inf` are currently not supported, but got {}.' \
            .format(norm_level)
        LOGGER.error(TAG, msg)
        raise NotImplementedError(msg)
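
# --- Editor's sketch (not part of the library): sanity-checking the L2 branch
# of normalize_value with plain numpy. Shapes and values are made up; the
# `avoid_zero_div` constant matches the code above.
import numpy as np

value = np.array([[3.0, 4.0], [0.0, 2.0]])             # batch of 2 flattened samples
norm = np.linalg.norm(value, ord=2, axis=1, keepdims=True) + 1e-12
normalized = value / norm
print(np.linalg.norm(normalized, axis=1))              # ~[1. 1.]: each row now has unit L2 norm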
@@ -339,13 +315,30 @@ def check_inputs_labels(inputs, labels):
    inputs_image = inputs[0] if isinstance(inputs, tuple) else inputs
    if isinstance(inputs, tuple):
        for i, inputs_item in enumerate(inputs):
            _ = check_pair_numpy_param('inputs_image', inputs_image, 'inputs[{}]'.format(i), inputs_item)
    if isinstance(labels, tuple):
        for i, labels_item in enumerate(labels):
            _ = check_pair_numpy_param('inputs', inputs_image, 'labels[{}]'.format(i), labels_item)
    else:
        _ = check_pair_numpy_param('inputs', inputs_image, 'labels', labels)
    return inputs_image, inputs, labels

def check_param_bounds(arg_name, arg_value):
    """Check that bounds are a valid (lower, upper) pair of numbers."""
    arg_value = check_param_multi_types(arg_name, arg_value, [tuple, list])
    if len(arg_value) != 2:
        msg = 'length of {0} must be 2, but got length of {0} is {1}'.format(arg_name, len(arg_value))
        LOGGER.error(TAG, msg)
        raise ValueError(msg)
    for i, b in enumerate(arg_value):
        if not isinstance(b, (float, int)):
            msg = 'each value in {} must be int or float, but got the {}th value is {}'.format(arg_name, i, b)
            LOGGER.error(TAG, msg)
            raise ValueError(msg)
    if arg_value[0] > arg_value[1]:
        msg = "lower boundary cannot be greater than upper boundary, corresponding values in {} are {} and {}". \
            format(arg_name, arg_value[0], arg_value[1])
        LOGGER.error(TAG, msg)
        raise ValueError(msg)
    return arg_value
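
# --- Editor's sketch (not part of the library): how check_param_bounds behaves,
# assuming the function as defined above. Illustrative values only.
bounds = check_param_bounds('bounds', (0.0, 1.0))   # valid pair: returned unchanged
try:
    check_param_bounds('bounds', (1.0, 0.0))        # lower > upper
except ValueError as err:
    print(err)  # 'lower boundary cannot be greater than upper boundary, ...'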
@@ -25,7 +25,8 @@ from mindspore.ops import TensorSummary
from mindarmour.adv_robustness.attacks import FastGradientSignMethod
from mindarmour.utils.logger import LogUtil
from mindarmour.fuzz_testing import NeuronCoverage, TopKNeuronCoverage, SuperNeuronActivateCoverage, \
    NeuronBoundsCoverage, KMultisectionNeuronCoverage

LOGGER = LogUtil.get_instance()
TAG = 'Neuron coverage test'
@@ -74,39 +75,48 @@ def test_lenet_mnist_coverage_cpu():
    model = Model(net)

    # initialize fuzz test with training dataset
    training_data = (np.random.random((10000, 10))*20).astype(np.float32)

    # fuzz test with original test data
    # get test data
    test_data = (np.random.random((2000, 10))*20).astype(np.float32)
    test_labels = np.random.randint(0, 10, 2000).astype(np.int32)

    nc = NeuronCoverage(model, threshold=0.1)
    nc_metric = nc.get_metrics(test_data)

    tknc = TopKNeuronCoverage(model, top_k=3)
    tknc_metrics = tknc.get_metrics(test_data)

    snac = SuperNeuronActivateCoverage(model, training_data)
    snac_metrics = snac.get_metrics(test_data)

    nbc = NeuronBoundsCoverage(model, training_data)
    nbc_metrics = nbc.get_metrics(test_data)

    kmnc = KMultisectionNeuronCoverage(model, training_data, segmented_num=100)
    kmnc_metrics = kmnc.get_metrics(test_data)

    print('KMNC of this test is: ', kmnc_metrics)
    print('NBC of this test is: ', nbc_metrics)
    print('SNAC of this test is: ', snac_metrics)
    print('NC of this test is: ', nc_metric)
    print('TKNC of this test is: ', tknc_metrics)

    # generate adv_data
    loss = SoftmaxCrossEntropyWithLogits(sparse=True)
    attack = FastGradientSignMethod(net, eps=0.3, loss_fn=loss)
    adv_data = attack.batch_generate(test_data, test_labels, batch_size=32)

    nc_metric = nc.get_metrics(adv_data)
    tknc_metrics = tknc.get_metrics(adv_data)
    snac_metrics = snac.get_metrics(adv_data)
    nbc_metrics = nbc.get_metrics(adv_data)
    kmnc_metrics = kmnc.get_metrics(adv_data)

    print('KMNC of adv data is: ', kmnc_metrics)
    print('NBC of adv data is: ', nbc_metrics)
    print('SNAC of adv data is: ', snac_metrics)
    print('NC of adv data is: ', nc_metric)
    print('TKNC of adv data is: ', tknc_metrics)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@@ -120,35 +130,28 @@ def test_lenet_mnist_coverage_ascend():
    model = Model(net)

    # initialize fuzz test with training dataset
    training_data = (np.random.random((10000, 10))*20).astype(np.float32)

    # fuzz test with original test data
    # get test data
    test_data = (np.random.random((2000, 10))*20).astype(np.float32)
    test_labels = np.random.randint(0, 10, 2000)
    test_labels = (np.eye(10)[test_labels]).astype(np.float32)

    nc = NeuronCoverage(model, threshold=0.1)
    nc_metric = nc.get_metrics(test_data)

    tknc = TopKNeuronCoverage(model, top_k=3)
    tknc_metrics = tknc.get_metrics(test_data)

    # generate adv_data
    attack = FastGradientSignMethod(net, eps=0.3, loss_fn=nn.SoftmaxCrossEntropyWithLogits(sparse=False))
    adv_data = attack.batch_generate(test_data, test_labels, batch_size=32)

    snac = SuperNeuronActivateCoverage(model, training_data)
    snac_metrics = snac.get_metrics(test_data)

    nbc = NeuronBoundsCoverage(model, training_data)
    nbc_metrics = nbc.get_metrics(test_data)

    kmnc = KMultisectionNeuronCoverage(model, training_data, segmented_num=100)
    kmnc_metrics = kmnc.get_metrics(test_data)

    print('KMNC of this test is: ', kmnc_metrics)
    print('NBC of this test is: ', nbc_metrics)
    print('SNAC of this test is: ', snac_metrics)
    print('NC of this test is: ', nc_metric)
    print('TKNC of this test is: ', tknc_metrics)
@@ -21,9 +21,10 @@ from mindspore import nn
from mindspore.common.initializer import TruncatedNormal
from mindspore.ops import operations as P
from mindspore.train import Model
from mindspore.ops import TensorSummary

from mindarmour.fuzz_testing import Fuzzer
from mindarmour.fuzz_testing import KMultisectionNeuronCoverage
from mindarmour.utils.logger import LogUtil

LOGGER = LogUtil.get_instance()
@@ -52,30 +53,37 @@ class Net(nn.Cell):
    """
    LeNet network
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = conv(1, 6, 5)
        self.conv2 = conv(6, 16, 5)
        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
        self.fc2 = fc_with_initialize(120, 84)
        self.fc3 = fc_with_initialize(84, 10)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.reshape = P.Reshape()
        self.summary = TensorSummary()

    def construct(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        self.summary('conv1', x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        self.summary('conv2', x)
        x = self.max_pool2d(x)
        x = self.reshape(x, (-1, 16 * 5 * 5))
        x = self.fc1(x)
        x = self.relu(x)
        self.summary('fc1', x)
        x = self.fc2(x)
        x = self.relu(x)
        self.summary('fc2', x)
        x = self.fc3(x)
        self.summary('fc3', x)
        return x
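
# --- Editor's note: the TensorSummary calls in construct() are what expose
# each layer's activations; the coverage classes read them back via
# _get_summary_tensor_data() after model.predict(), so a layer without a
# summary call is invisible to NC/TKNC/SNAC/NBC/KMNC.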
@@ -100,12 +108,8 @@ def test_fuzzing_ascend():
        {'method': 'FGSM',
         'params': {'eps': [0.1, 0.2, 0.3], 'alpha': [0.1]}}
    ]

    # initialize fuzz test with training dataset
    train_images = np.random.rand(32, 1, 32, 32).astype(np.float32)

    # fuzz test with original test data
    # get test data
    test_images = np.random.rand(batch_size, 1, 32, 32).astype(np.float32)
@@ -118,13 +122,12 @@ def test_fuzzing_ascend():
        initial_seeds.append([img, label])
    initial_seeds = initial_seeds[:100]

    kmnc = KMultisectionNeuronCoverage(model, train_images, segmented_num=100)
    kmnc_metrics = kmnc.get_metrics(test_images[:100])
    print('KMNC of initial seeds is: ', kmnc_metrics)

    model_fuzz_test = Fuzzer(model)
    _, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds, kmnc, max_iters=100)
    print(metrics)
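
# --- Editor's note (assumption based on the Fuzzer usage shown here, not
# stated in the diff): fuzzing() returns five values, of which only the final
# metrics report is used in these tests; the ignored values are presumably the
# fuzzed samples, their ground-truth labels, model predictions and the
# mutation strategies applied.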
@@ -139,8 +142,6 @@ def test_fuzzing_cpu():
    model = Model(net)
    batch_size = 8
    num_classe = 10
    mutate_config = [{'method': 'Blur',
                      'params': {'auto_param': [True]}},
                     {'method': 'Contrast',
@@ -152,7 +153,6 @@ def test_fuzzing_cpu():
                    ]

    # initialize fuzz test with training dataset
    train_images = np.random.rand(32, 1, 32, 32).astype(np.float32)

    # fuzz test with original test data
    # get test data
@@ -166,11 +166,9 @@ def test_fuzzing_cpu():
        initial_seeds.append([img, label])
    initial_seeds = initial_seeds[:100]

    kmnc = KMultisectionNeuronCoverage(model, train_images, segmented_num=100)
    kmnc_metrics = kmnc.get_metrics(test_images[:100])
    print('KMNC of initial seeds is: ', kmnc_metrics)

    model_fuzz_test = Fuzzer(model)
    _, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds, kmnc, max_iters=100)
    print(metrics)