Merge pull request !254 from RyanZ/test_coverage3
@@ -1,24 +1,32 @@
# Application demos of model fuzzing
## Introduction
Just as with traditional software fuzz testing, we can design fuzz tests for AI models. In place of the
branch coverage or line coverage used for traditional software, researchers have proposed the concept of
'neuron coverage' based on the unique structure of deep neural networks. Neuron coverage can then be used
as a guide to search for more metamorphic inputs with which to test our models.
## 1. Calculation of neuron coverage
There are five metrics proposed for evaluating the neuron coverage of a test: NC, Effective NC, KMNC, NBC and SNAC.
Usually we first feed the whole training dataset into the model and record the output range of every neuron
(for KMNC, NBC and SNAC, however, our method only records the last layer of neurons). In the testing phase,
we feed test samples into the model and calculate the five metrics above from those neurons' output distribution.
```sh
cd examples/ai_fuzzer/
python lenet5_mnist_coverage.py
```
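For reference, here is a minimal sketch of what `lenet5_mnist_coverage.py` does with the `ModelCoverageMetrics` API shown later in this change. The dataset paths, the iterator settings and the checkpoint handling are illustrative assumptions, not the script's exact code:

```python
import numpy as np
from mindspore import Model
from mindarmour.fuzz_testing import ModelCoverageMetrics
from examples.common.dataset.data_processing import generate_mnist_dataset
from examples.common.networks.lenet5.lenet5_net_for_fuzzing import LeNet5

# assumes a LeNet5 checkpoint has already been trained and loaded
model = Model(LeNet5())

def as_array(path):
    # flatten a generated MNIST dataset into one numpy array of images
    images = [data[0].astype(np.float32) for data in
              generate_mnist_dataset(path).create_tuple_iterator(output_numpy=True)]
    return np.concatenate(images, axis=0)

train_images = as_array('./MNIST/train')  # illustrative path
test_images = as_array('./MNIST/test')

# profile 10 last-layer neurons, splitting each output range into 1000 sections
coverage = ModelCoverageMetrics(model, 10, 1000, train_images)
coverage.calculate_coverage(test_images)
print('KMNC:', coverage.get_kmnc())
print('NBC :', coverage.get_nbc())
print('SNAC:', coverage.get_snac())
```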
## 2. Fuzz test for AI models
We provide several kinds of methods for generating metamorphic inputs: affine transformations, pixel
transformations and adversarial attacks. Usually we feed the original samples into the fuzz function as seeds,
and metamorphic samples are then generated through iterative manipulation.
```sh
cd examples/ai_fuzzer/
python lenet5_mnist_fuzzing.py
```
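Conceptually, the fuzzer mutates seeds and keeps the mutants that raise coverage. The loop below is our simplified illustration of that idea using only the `ModelCoverageMetrics` API from this change; it is not the script's actual implementation, and the plain pixel-noise mutation stands in for the affine, pixel and adversarial manipulations the real fuzzer uses:

```python
import numpy as np

def fuzz(coverage, seeds, labels, iterations=1000, seed=0):
    """Keep a mutated sample whenever it increases KMNC (illustration only).

    coverage is an already-initialized ModelCoverageMetrics instance.
    """
    rng = np.random.default_rng(seed)
    kept, best = [], coverage.get_kmnc()
    for _ in range(iterations):
        idx = int(rng.integers(len(seeds)))
        sample = seeds[idx:idx + 1]
        # a trivial pixel transformation; the real fuzzer also applies affine
        # transformations and adversarial attacks such as FGSM
        noise = rng.normal(0, 0.1, sample.shape).astype(np.float32)
        mutated = np.clip(sample + noise, 0.0, 1.0)
        coverage.calculate_coverage(mutated, batch_size=1)  # hit counts accumulate
        kmnc = coverage.get_kmnc()
        if kmnc > best:
            best = kmnc
            kept.append((mutated, labels[idx]))
    return kept, best
```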
@@ -31,7 +31,7 @@ from mindarmour.fuzz_testing import ModelCoverageMetrics
from mindarmour.utils.logger import LogUtil
from examples.common.dataset.data_processing import generate_mnist_dataset
from examples.common.networks.lenet5.lenet5_net_for_fuzzing import LeNet5

LOGGER = LogUtil.get_instance()
TAG = 'Fuzz_testing and enhance model'
@@ -75,9 +75,11 @@ def example_lenet_mnist_fuzzing():
        images = data[0].astype(np.float32)
        train_images.append(images)
    train_images = np.concatenate(train_images, axis=0)
    neuron_num = 10
    segmented_num = 1000

    # initialize fuzz test with training dataset
    model_coverage_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images)

    # fuzz test with original test data
    # get test data
@@ -22,7 +22,7 @@ from mindarmour.fuzz_testing import ModelCoverageMetrics
from mindarmour.utils.logger import LogUtil
from examples.common.dataset.data_processing import generate_mnist_dataset
from examples.common.networks.lenet5.lenet5_net_for_fuzzing import LeNet5

LOGGER = LogUtil.get_instance()
TAG = 'Neuron coverage test'
@@ -46,9 +46,13 @@ def test_lenet_mnist_coverage():
        images = data[0].astype(np.float32)
        train_images.append(images)
    train_images = np.concatenate(train_images, axis=0)
    neuron_num = 10
    segmented_num = 1000
    top_k = 3
    threshold = 0.1

    # initialize fuzz test with training dataset
    model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images)

    # fuzz test with original test data
    # get test data
@@ -69,6 +73,10 @@ def test_lenet_mnist_coverage():
    LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc())
    LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac())
    model_fuzz_test.calculate_effective_coverage(test_images, top_k, threshold)
    LOGGER.info(TAG, 'NC of this test is : %s', model_fuzz_test.get_nc())
    LOGGER.info(TAG, 'Effective_NC of this test is : %s', model_fuzz_test.get_effective_nc())

    # generate adv_data
    loss = SoftmaxCrossEntropyWithLogits(sparse=True)
    attack = FastGradientSignMethod(net, eps=0.3, loss_fn=loss)
@@ -78,6 +86,10 @@ def test_lenet_mnist_coverage():
    LOGGER.info(TAG, 'NBC of this adv data is : %s', model_fuzz_test.get_nbc())
    LOGGER.info(TAG, 'SNAC of this adv data is : %s', model_fuzz_test.get_snac())
    model_fuzz_test.calculate_effective_coverage(adv_data, top_k, threshold)
    LOGGER.info(TAG, 'NC of this adv data is : %s', model_fuzz_test.get_nc())
    LOGGER.info(TAG, 'Effective_NC of this adv data is : %s', model_fuzz_test.get_effective_nc())
if __name__ == '__main__':
    # device_target can be "CPU", "GPU" or "Ascend"
@@ -21,7 +21,7 @@ from mindarmour.fuzz_testing import ModelCoverageMetrics
from mindarmour.utils.logger import LogUtil
from examples.common.dataset.data_processing import generate_mnist_dataset
from examples.common.networks.lenet5.lenet5_net_for_fuzzing import LeNet5

LOGGER = LogUtil.get_instance()
TAG = 'Fuzz_test'
@@ -0,0 +1,99 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
LeNet-5 network with TensorSummary operators inserted to record each layer's output.
"""
from mindspore import nn
from mindspore.common.initializer import TruncatedNormal
from mindspore.ops import TensorSummary
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    """Build a valid-padded Conv2d layer with truncated-normal weight init."""
    weight = weight_variable()
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=kernel_size, stride=stride, padding=padding,
                     weight_init=weight, has_bias=False, pad_mode="valid")

def fc_with_initialize(input_channels, out_channels):
    """Build a Dense layer with truncated-normal weight and bias init."""
    weight = weight_variable()
    bias = weight_variable()
    return nn.Dense(input_channels, out_channels, weight, bias)

def weight_variable():
    """Truncated-normal initializer shared by all layers."""
    return TruncatedNormal(0.05)
class LeNet5(nn.Cell):
    """
    LeNet-5 network.
    """
    def __init__(self):
        super(LeNet5, self).__init__()
        self.conv1 = conv(1, 6, 5)
        self.conv2 = conv(6, 16, 5)
        self.fc1 = fc_with_initialize(16*5*5, 120)
        self.fc2 = fc_with_initialize(120, 84)
        self.fc3 = fc_with_initialize(84, 10)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()
        self.summary = TensorSummary()
    def construct(self, x):
        """
        Forward pass of the network, recording each layer's output via TensorSummary.

        Returns:
            x (tensor): network output.
        """
        self.summary('input', x)
        x = self.conv1(x)
        self.summary('1', x)
        x = self.relu(x)
        self.summary('2', x)
        x = self.max_pool2d(x)
        self.summary('3', x)
        x = self.conv2(x)
        self.summary('4', x)
        x = self.relu(x)
        self.summary('5', x)
        x = self.max_pool2d(x)
        self.summary('6', x)
        x = self.flatten(x)
        self.summary('7', x)
        x = self.fc1(x)
        self.summary('8', x)
        x = self.relu(x)
        self.summary('9', x)
        x = self.fc2(x)
        self.summary('10', x)
        x = self.relu(x)
        self.summary('11', x)
        x = self.fc3(x)
        self.summary('output', x)
        return x
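# Note (illustrative, not part of this file): the TensorSummary taps above are
# what ModelCoverageMetrics harvests through MindSpore's summary cache, e.g.:
#
#     model = Model(LeNet5())
#     model.predict(Tensor(images[0:1]))    # runs the summary ops
#     tensors = _get_summary_tensor_data()  # maps summary tag -> recorded tensor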
@@ -15,10 +15,12 @@
Model-Test Coverage Metrics.
"""
from collections import defaultdict

import numpy as np

from mindspore import Tensor
from mindspore import Model
from mindspore.train.summary.summary_record import _get_summary_tensor_data

from mindarmour.utils._check_param import check_model, check_numpy_param, \
    check_int_positive, check_param_multi_types
@@ -63,6 +65,9 @@ class ModelCoverageMetrics:
        >>> print('KMNC of this test is : %s', model_fuzz_test.get_kmnc())
        >>> print('NBC of this test is : %s', model_fuzz_test.get_nbc())
        >>> print('SNAC of this test is : %s', model_fuzz_test.get_snac())
        >>> model_fuzz_test.calculate_effective_coverage(test_images, top_k, threshold)
        >>> print('NC of this test is : %s', model_fuzz_test.get_nc())
        >>> print('Effective_NC of this test is : %s', model_fuzz_test.get_effective_nc())
    """

    def __init__(self, model, neuron_num, segmented_num, train_dataset):
@@ -81,6 +86,24 @@ class ModelCoverageMetrics:
        self._lower_corner_hits = [0]*self._neuron_num
        self._upper_corner_hits = [0]*self._neuron_num
        self._bounds_get(train_dataset)
        self._model_layer_dict = defaultdict(bool)
        self._effective_model_layer_dict = defaultdict(bool)
    def _set_init_effective_coverage_table(self, dataset):
        """
        Initialize the coverage table for each neuron in the model.

        Args:
            dataset (numpy.ndarray): Dataset used for initializing the coverage table.
        """
        self._model.predict(Tensor(dataset[0:1]))
        tensors = _get_summary_tensor_data()
        for name, tensor in tensors.items():
            if 'input' in name:
                continue
            for num_neuron in range(tensor.shape[1]):
                self._model_layer_dict[(name, num_neuron)] = False
                self._effective_model_layer_dict[(name, num_neuron)] = False
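    # Illustration (ours): after initialization both tables map
    # (summary_tag, channel_index) -> covered?, one entry per recorded neuron,
    # e.g. {('1', 0): False, ('1', 1): False, ...}; the exact tag format
    # depends on what _get_summary_tensor_data returns.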
    def _bounds_get(self, train_dataset, batch_size=32):
        """
@@ -130,6 +153,27 @@ class ModelCoverageMetrics:
            else:
                self._main_section_hits[i][int(section_indexes[i])] = 1
    def _coverage_update(self, name, tensor, scaled_mean, scaled_rank, top_k, threshold):
        """
        Update the coverage matrices of neuron coverage and effective neuron coverage.

        Args:
            name (str): Name of the tensor.
            tensor (Tensor): Tensor recorded in the network.
            scaled_mean (numpy.ndarray): Feature-map means of the tensor.
            scaled_rank (numpy.ndarray): Ascending rank of the tensor values.
            top_k (int): A neuron is covered when its output is among the top_k largest values in its hidden layer.
            threshold (float): A neuron is covered when its output is greater than the threshold.
        """
        for num_neuron in range(tensor.shape[1]):
            if num_neuron >= (len(scaled_rank) - top_k) and not \
                    self._effective_model_layer_dict[(name, scaled_rank[num_neuron])]:
                self._effective_model_layer_dict[(name, scaled_rank[num_neuron])] = True
            if scaled_mean[num_neuron] > threshold and not \
                    self._model_layer_dict[(name, num_neuron)]:
                self._model_layer_dict[(name, num_neuron)] = True
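    # Toy walk-through of the rule above (ours): with
    # scaled_mean = [0.05, 0.42, 0.17, 0.90], scaled_rank = argsort(scaled_mean)
    # = [0, 2, 1, 3]; for top_k=2 the effective-NC-covered neurons are
    # scaled_rank[-2:] = {1, 3}, and for threshold=0.1 the plain-NC-covered
    # neurons are {1, 2, 3}.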
    def calculate_coverage(self, dataset, bias_coefficient=0, batch_size=32):
        """
        Calculate the testing adequacy of the given dataset.
@@ -143,8 +187,9 @@ class ModelCoverageMetrics:
        Examples:
            >>> neuron_num = 10
            >>> segmented_num = 1000
            >>> batch_size = 32
            >>> model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images)
            >>> model_fuzz_test.calculate_coverage(test_images, batch_size=batch_size)
        """

        dataset = check_numpy_param('dataset', dataset)
@@ -157,6 +202,79 @@ class ModelCoverageMetrics:
        for i in range(batches):
            self._sections_hits_count(dataset[i*batch_size: (i + 1)*batch_size], intervals)
    def calculate_effective_coverage(self, dataset, top_k=3, threshold=0.1, batch_size=32):
        """
        Calculate the effective testing adequacy of the given dataset.

        In effective neuron coverage, a neuron is covered when its output is
        among the top_k largest values in its hidden layer. In neuron coverage,
        a neuron is covered when its output is greater than the threshold.
        Coverage equals the number of covered neurons divided by the total
        number of neurons in the network.

        Args:
            dataset (numpy.ndarray): Data for fuzz test.
            top_k (int): A neuron is covered when its output is among the top_k largest values in its hidden layer.
            threshold (float): A neuron is covered when its output is greater than the threshold.
            batch_size (int): Batch size for prediction.

        Examples:
            >>> neuron_num = 10
            >>> segmented_num = 1000
            >>> top_k = 3
            >>> threshold = 0.1
            >>> model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images)
            >>> model_fuzz_test.calculate_coverage(test_images)
            >>> model_fuzz_test.calculate_effective_coverage(test_images, top_k, threshold)
        """
        top_k = check_int_positive('top_k', top_k)
        dataset = check_numpy_param('dataset', dataset)
        batch_size = check_int_positive('batch_size', batch_size)
        batches = dataset.shape[0] // batch_size
        self._set_init_effective_coverage_table(dataset)
        for i in range(batches):
            inputs = dataset[i*batch_size: (i + 1)*batch_size]
            self._model.predict(Tensor(inputs)).asnumpy()
            tensors = _get_summary_tensor_data()
            for name, tensor in tensors.items():
                if 'input' in name:
                    continue
                scaled = tensor.asnumpy()[-1]
                if scaled.ndim >= 3:
                    # convolutional feature maps: average over the spatial dimensions
                    scaled_mean = np.mean(scaled, axis=(1, 2))
                    scaled_rank = np.argsort(scaled_mean)
                    self._coverage_update(name, tensor, scaled_mean, scaled_rank, top_k, threshold)
                else:
                    scaled_rank = np.argsort(scaled)
                    self._coverage_update(name, tensor, scaled, scaled_rank, top_k, threshold)
    def get_nc(self):
        """
        Get the metric of 'neuron coverage'.

        Returns:
            float, the metric of 'neuron coverage'.

        Examples:
            >>> model_fuzz_test.get_nc()
        """
        covered_neurons = len([v for v in self._model_layer_dict.values() if v])
        total_neurons = len(self._model_layer_dict)
        nc = covered_neurons / float(total_neurons)
        return nc
    def get_effective_nc(self):
        """
        Get the metric of 'effective neuron coverage'.

        Returns:
            float, the metric of 'effective neuron coverage'.

        Examples:
            >>> model_fuzz_test.get_effective_nc()
        """
        covered_neurons = len([v for v in self._effective_model_layer_dict.values() if v])
        total_neurons = len(self._effective_model_layer_dict)
        effective_nc = covered_neurons / float(total_neurons)
        return effective_nc
    def get_kmnc(self):
        """
        Get the metric of 'k-multisection neuron coverage'. KMNC measures how
@@ -21,6 +21,7 @@ from mindspore import nn
from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits
from mindspore.train import Model
from mindspore import context
from mindspore.ops import TensorSummary

from mindarmour.adv_robustness.attacks import FastGradientSignMethod
from mindarmour.utils.logger import LogUtil
@@ -46,6 +47,7 @@ class Net(Cell):
        """
        super(Net, self).__init__()
        self._relu = nn.ReLU()
        self.summary = TensorSummary()

    def construct(self, inputs):
        """
@@ -54,7 +56,10 @@ class Net(Cell):
        Args:
            inputs (Tensor): Input data.
        """
        self.summary('input', inputs)
        out = self._relu(inputs)
        self.summary('1', out)
        return out
@@ -71,7 +76,10 @@ def test_lenet_mnist_coverage_cpu():
    # initialize fuzz test with training dataset
    neuron_num = 10
    segmented_num = 1000
    top_k = 3
    threshold = 0.1
    training_data = (np.random.random((10000, 10))*20).astype(np.float32)
    model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, training_data)

    # fuzz test with original test data
@@ -83,6 +91,10 @@ def test_lenet_mnist_coverage_cpu():
    LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc())
    LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac())
    model_fuzz_test.calculate_effective_coverage(test_data, top_k, threshold)
    LOGGER.info(TAG, 'NC of this test is : %s', model_fuzz_test.get_nc())
    LOGGER.info(TAG, 'Effective_NC of this test is : %s', model_fuzz_test.get_effective_nc())

    # generate adv_data
    loss = SoftmaxCrossEntropyWithLogits(sparse=True)
    attack = FastGradientSignMethod(net, eps=0.3, loss_fn=loss)
@@ -92,6 +104,9 @@ def test_lenet_mnist_coverage_cpu():
    LOGGER.info(TAG, 'NBC of this adv data is : %s', model_fuzz_test.get_nbc())
    LOGGER.info(TAG, 'SNAC of this adv data is : %s', model_fuzz_test.get_snac())
    model_fuzz_test.calculate_effective_coverage(adv_data, top_k, threshold)
    LOGGER.info(TAG, 'NC of this adv data is : %s', model_fuzz_test.get_nc())
    LOGGER.info(TAG, 'Effective_NC of this adv data is : %s', model_fuzz_test.get_effective_nc())
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@@ -107,9 +122,10 @@ def test_lenet_mnist_coverage_ascend():
    # initialize fuzz test with training dataset
    neuron_num = 10
    segmented_num = 1000
    top_k = 3
    threshold = 0.1
    training_data = (np.random.random((10000, 10))*20).astype(np.float32)
    model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, training_data)

    # fuzz test with original test data
    # get test data
@@ -121,6 +137,10 @@ def test_lenet_mnist_coverage_ascend():
    LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc())
    LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac())
    model_fuzz_test.calculate_effective_coverage(test_data, top_k, threshold)
    LOGGER.info(TAG, 'NC of this test is : %s', model_fuzz_test.get_nc())
    LOGGER.info(TAG, 'Effective_NC of this test is : %s', model_fuzz_test.get_effective_nc())

    # generate adv_data
    attack = FastGradientSignMethod(net, eps=0.3, loss_fn=nn.SoftmaxCrossEntropyWithLogits(sparse=False))
    adv_data = attack.batch_generate(test_data, test_labels, batch_size=32)
@@ -128,3 +148,7 @@ def test_lenet_mnist_coverage_ascend():
    LOGGER.info(TAG, 'KMNC of this adv data is : %s', model_fuzz_test.get_kmnc())
    LOGGER.info(TAG, 'NBC of this adv data is : %s', model_fuzz_test.get_nbc())
    LOGGER.info(TAG, 'SNAC of this adv data is : %s', model_fuzz_test.get_snac())
    model_fuzz_test.calculate_effective_coverage(adv_data, top_k, threshold)
    LOGGER.info(TAG, 'NC of this adv data is : %s', model_fuzz_test.get_nc())
    LOGGER.info(TAG, 'Effective_NC of this adv data is : %s', model_fuzz_test.get_effective_nc())