Merge pull request !254 from RyanZ/test_coverage3tags/v1.5.0-rc1
| @@ -1,24 +1,32 @@ | |||||
| # Application demos of model fuzzing | # Application demos of model fuzzing | ||||
| ## Introduction | ## Introduction | ||||
| The same as the traditional software fuzz testing, we can also design fuzz test for AI models. Compared to | The same as the traditional software fuzz testing, we can also design fuzz test for AI models. Compared to | ||||
| branch coverage or line coverage of traditional software, some people propose the | branch coverage or line coverage of traditional software, some people propose the | ||||
| concept of 'neuron coverage' based on the unique structure of deep neural network. We can use the neuron coverage | concept of 'neuron coverage' based on the unique structure of deep neural network. We can use the neuron coverage | ||||
| as a guide to search more metamorphic inputs to test our models. | as a guide to search more metamorphic inputs to test our models. | ||||
| ## 1. calculation of neuron coverage | |||||
| There are three metrics proposed for evaluating the neuron coverage of a test:KMNC, NBC and SNAC. Usually we need to | |||||
| feed all the training dataset into the model first, and record the output range of all neurons (however, only the last | |||||
| layer of neurons are recorded in our method). In the testing phase, we feed test samples into the model, and | |||||
| calculate those three metrics mentioned above according to those neurons' output distribution. | |||||
| ## 1. calculation of neuron coverage | |||||
| There are five metrics proposed for evaluating the neuron coverage of a test:NC, Effective NC, KMNC, NBC and SNAC. | |||||
| Usually we need to feed all the training dataset into the model first, and record the output range of all neurons | |||||
| (however, in KMNC, NBC and SNAC, only the last layer of neurons are recorded in our method). In the testing phase, | |||||
| we feed test samples into the model, and calculate those three metrics mentioned above according to those neurons' | |||||
| output distribution. | |||||
| ```sh | ```sh | ||||
| $ cd examples/ai_fuzzer/ | |||||
| $ python lenet5_mnist_coverage.py | |||||
| cd examples/ai_fuzzer/ | |||||
| python lenet5_mnist_coverage.py | |||||
| ``` | ``` | ||||
| ## 2. fuzz test for AI model | |||||
| ## 2. fuzz test for AI model | |||||
| We have provided several types of methods for manipulating metamorphic inputs: affine transformation, pixel | We have provided several types of methods for manipulating metamorphic inputs: affine transformation, pixel | ||||
| transformation and adversarial attacks. Usually we feed the original samples into the fuzz function as seeds, and | transformation and adversarial attacks. Usually we feed the original samples into the fuzz function as seeds, and | ||||
| then metamorphic samples are generated through iterative manipulations. | then metamorphic samples are generated through iterative manipulations. | ||||
| ```sh | ```sh | ||||
| $ cd examples/ai_fuzzer/ | |||||
| $ python lenet5_mnist_fuzzing.py | |||||
| cd examples/ai_fuzzer/ | |||||
| python lenet5_mnist_fuzzing.py | |||||
| ``` | ``` | ||||
| @@ -31,7 +31,7 @@ from mindarmour.fuzz_testing import ModelCoverageMetrics | |||||
| from mindarmour.utils.logger import LogUtil | from mindarmour.utils.logger import LogUtil | ||||
| from examples.common.dataset.data_processing import generate_mnist_dataset | from examples.common.dataset.data_processing import generate_mnist_dataset | ||||
| from examples.common.networks.lenet5.lenet5_net import LeNet5 | |||||
| from examples.common.networks.lenet5.lenet5_net_for_fuzzing import LeNet5 | |||||
| LOGGER = LogUtil.get_instance() | LOGGER = LogUtil.get_instance() | ||||
| TAG = 'Fuzz_testing and enhance model' | TAG = 'Fuzz_testing and enhance model' | ||||
| @@ -75,9 +75,11 @@ def example_lenet_mnist_fuzzing(): | |||||
| images = data[0].astype(np.float32) | images = data[0].astype(np.float32) | ||||
| train_images.append(images) | train_images.append(images) | ||||
| train_images = np.concatenate(train_images, axis=0) | train_images = np.concatenate(train_images, axis=0) | ||||
| neuron_num = 10 | |||||
| segmented_num = 1000 | |||||
| # initialize fuzz test with training dataset | # initialize fuzz test with training dataset | ||||
| model_coverage_test = ModelCoverageMetrics(model, 10, 1000, train_images) | |||||
| model_coverage_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images) | |||||
| # fuzz test with original test data | # fuzz test with original test data | ||||
| # get test data | # get test data | ||||
| @@ -22,7 +22,7 @@ from mindarmour.fuzz_testing import ModelCoverageMetrics | |||||
| from mindarmour.utils.logger import LogUtil | from mindarmour.utils.logger import LogUtil | ||||
| from examples.common.dataset.data_processing import generate_mnist_dataset | from examples.common.dataset.data_processing import generate_mnist_dataset | ||||
| from examples.common.networks.lenet5.lenet5_net import LeNet5 | |||||
| from examples.common.networks.lenet5.lenet5_net_for_fuzzing import LeNet5 | |||||
| LOGGER = LogUtil.get_instance() | LOGGER = LogUtil.get_instance() | ||||
| TAG = 'Neuron coverage test' | TAG = 'Neuron coverage test' | ||||
| @@ -46,9 +46,13 @@ def test_lenet_mnist_coverage(): | |||||
| images = data[0].astype(np.float32) | images = data[0].astype(np.float32) | ||||
| train_images.append(images) | train_images.append(images) | ||||
| train_images = np.concatenate(train_images, axis=0) | train_images = np.concatenate(train_images, axis=0) | ||||
| neuron_num = 10 | |||||
| segmented_num = 1000 | |||||
| top_k = 3 | |||||
| threshold = 0.1 | |||||
| # initialize fuzz test with training dataset | # initialize fuzz test with training dataset | ||||
| model_fuzz_test = ModelCoverageMetrics(model, 10, 1000, train_images) | |||||
| model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images) | |||||
| # fuzz test with original test data | # fuzz test with original test data | ||||
| # get test data | # get test data | ||||
| @@ -69,6 +73,10 @@ def test_lenet_mnist_coverage(): | |||||
| LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc()) | LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc()) | ||||
| LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac()) | LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac()) | ||||
| model_fuzz_test.calculate_effective_coverage(test_images, top_k, threshold) | |||||
| LOGGER.info(TAG, 'NC of this test is : %s', model_fuzz_test.get_nc()) | |||||
| LOGGER.info(TAG, 'Effective_NC of this test is : %s', model_fuzz_test.get_effective_nc()) | |||||
| # generate adv_data | # generate adv_data | ||||
| loss = SoftmaxCrossEntropyWithLogits(sparse=True) | loss = SoftmaxCrossEntropyWithLogits(sparse=True) | ||||
| attack = FastGradientSignMethod(net, eps=0.3, loss_fn=loss) | attack = FastGradientSignMethod(net, eps=0.3, loss_fn=loss) | ||||
| @@ -78,6 +86,10 @@ def test_lenet_mnist_coverage(): | |||||
| LOGGER.info(TAG, 'NBC of this adv data is : %s', model_fuzz_test.get_nbc()) | LOGGER.info(TAG, 'NBC of this adv data is : %s', model_fuzz_test.get_nbc()) | ||||
| LOGGER.info(TAG, 'SNAC of this adv data is : %s', model_fuzz_test.get_snac()) | LOGGER.info(TAG, 'SNAC of this adv data is : %s', model_fuzz_test.get_snac()) | ||||
| model_fuzz_test.calculate_effective_coverage(adv_data, top_k, threshold) | |||||
| LOGGER.info(TAG, 'NC of this test is : %s', model_fuzz_test.get_nc()) | |||||
| LOGGER.info(TAG, 'Effective_NC of this test is : %s', model_fuzz_test.get_effective_nc()) | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| # device_target can be "CPU", "GPU" or "Ascend" | # device_target can be "CPU", "GPU" or "Ascend" | ||||
| @@ -21,7 +21,7 @@ from mindarmour.fuzz_testing import ModelCoverageMetrics | |||||
| from mindarmour.utils.logger import LogUtil | from mindarmour.utils.logger import LogUtil | ||||
| from examples.common.dataset.data_processing import generate_mnist_dataset | from examples.common.dataset.data_processing import generate_mnist_dataset | ||||
| from examples.common.networks.lenet5.lenet5_net import LeNet5 | |||||
| from examples.common.networks.lenet5.lenet5_net_for_fuzzing import LeNet5 | |||||
| LOGGER = LogUtil.get_instance() | LOGGER = LogUtil.get_instance() | ||||
| TAG = 'Fuzz_test' | TAG = 'Fuzz_test' | ||||
| @@ -0,0 +1,99 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """ | |||||
| lenet network with summary | |||||
| """ | |||||
| from mindspore import nn | |||||
| from mindspore.common.initializer import TruncatedNormal | |||||
| from mindspore.ops import TensorSummary | |||||
| def conv(in_channels, out_channels, kernel_size, stride=1, padding=0): | |||||
| weight = weight_variable() | |||||
| return nn.Conv2d(in_channels, out_channels, | |||||
| kernel_size=kernel_size, stride=stride, padding=padding, | |||||
| weight_init=weight, has_bias=False, pad_mode="valid") | |||||
| def fc_with_initialize(input_channels, out_channels): | |||||
| weight = weight_variable() | |||||
| bias = weight_variable() | |||||
| return nn.Dense(input_channels, out_channels, weight, bias) | |||||
| def weight_variable(): | |||||
| return TruncatedNormal(0.05) | |||||
| class LeNet5(nn.Cell): | |||||
| """ | |||||
| Lenet network | |||||
| """ | |||||
| def __init__(self): | |||||
| super(LeNet5, self).__init__() | |||||
| self.conv1 = conv(1, 6, 5) | |||||
| self.conv2 = conv(6, 16, 5) | |||||
| self.fc1 = fc_with_initialize(16*5*5, 120) | |||||
| self.fc2 = fc_with_initialize(120, 84) | |||||
| self.fc3 = fc_with_initialize(84, 10) | |||||
| self.relu = nn.ReLU() | |||||
| self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) | |||||
| self.flatten = nn.Flatten() | |||||
| self.summary = TensorSummary() | |||||
| def construct(self, x): | |||||
| """ | |||||
| construct the network architecture | |||||
| Returns: | |||||
| x (tensor): network output | |||||
| """ | |||||
| self.summary('input', x) | |||||
| x = self.conv1(x) | |||||
| self.summary('1', x) | |||||
| x = self.relu(x) | |||||
| self.summary('2', x) | |||||
| x = self.max_pool2d(x) | |||||
| self.summary('3', x) | |||||
| x = self.conv2(x) | |||||
| self.summary('4', x) | |||||
| x = self.relu(x) | |||||
| self.summary('5', x) | |||||
| x = self.max_pool2d(x) | |||||
| self.summary('6', x) | |||||
| x = self.flatten(x) | |||||
| self.summary('7', x) | |||||
| x = self.fc1(x) | |||||
| self.summary('8', x) | |||||
| x = self.relu(x) | |||||
| self.summary('9', x) | |||||
| x = self.fc2(x) | |||||
| self.summary('10', x) | |||||
| x = self.relu(x) | |||||
| self.summary('11', x) | |||||
| x = self.fc3(x) | |||||
| self.summary('output', x) | |||||
| return x | |||||
| @@ -15,10 +15,12 @@ | |||||
| Model-Test Coverage Metrics. | Model-Test Coverage Metrics. | ||||
| """ | """ | ||||
| from collections import defaultdict | |||||
| import numpy as np | import numpy as np | ||||
| from mindspore import Tensor | from mindspore import Tensor | ||||
| from mindspore import Model | from mindspore import Model | ||||
| from mindspore.train.summary.summary_record import _get_summary_tensor_data | |||||
| from mindarmour.utils._check_param import check_model, check_numpy_param, \ | from mindarmour.utils._check_param import check_model, check_numpy_param, \ | ||||
| check_int_positive, check_param_multi_types | check_int_positive, check_param_multi_types | ||||
| @@ -63,6 +65,9 @@ class ModelCoverageMetrics: | |||||
| >>> print('KMNC of this test is : %s', model_fuzz_test.get_kmnc()) | >>> print('KMNC of this test is : %s', model_fuzz_test.get_kmnc()) | ||||
| >>> print('NBC of this test is : %s', model_fuzz_test.get_nbc()) | >>> print('NBC of this test is : %s', model_fuzz_test.get_nbc()) | ||||
| >>> print('SNAC of this test is : %s', model_fuzz_test.get_snac()) | >>> print('SNAC of this test is : %s', model_fuzz_test.get_snac()) | ||||
| >>> model_fuzz_test.calculate_effective_coverage(test_images, top_k, threshold) | |||||
| >>> print('NC of this test is : %s', model_fuzz_test.get_nc()) | |||||
| >>> print('Effective_NC of this test is : %s', model_fuzz_test.get_effective_nc()) | |||||
| """ | """ | ||||
| def __init__(self, model, neuron_num, segmented_num, train_dataset): | def __init__(self, model, neuron_num, segmented_num, train_dataset): | ||||
| @@ -81,6 +86,24 @@ class ModelCoverageMetrics: | |||||
| self._lower_corner_hits = [0]*self._neuron_num | self._lower_corner_hits = [0]*self._neuron_num | ||||
| self._upper_corner_hits = [0]*self._neuron_num | self._upper_corner_hits = [0]*self._neuron_num | ||||
| self._bounds_get(train_dataset) | self._bounds_get(train_dataset) | ||||
| self._model_layer_dict = defaultdict(bool) | |||||
| self._effective_model_layer_dict = defaultdict(bool) | |||||
| def _set_init_effective_coverage_table(self, dataset): | |||||
| """ | |||||
| Initialise the coverage table of each neuron in the model. | |||||
| Args: | |||||
| dataset (numpy.ndarray): Dataset used for initialising the coverage table. | |||||
| """ | |||||
| self._model.predict(Tensor(dataset[0:1])) | |||||
| tensors = _get_summary_tensor_data() | |||||
| for name, tensor in tensors.items(): | |||||
| if 'input' in name: | |||||
| continue | |||||
| for num_neuron in range(tensor.shape[1]): | |||||
| self._model_layer_dict[(name, num_neuron)] = False | |||||
| self._effective_model_layer_dict[(name, num_neuron)] = False | |||||
| def _bounds_get(self, train_dataset, batch_size=32): | def _bounds_get(self, train_dataset, batch_size=32): | ||||
| """ | """ | ||||
| @@ -130,6 +153,27 @@ class ModelCoverageMetrics: | |||||
| else: | else: | ||||
| self._main_section_hits[i][int(section_indexes[i])] = 1 | self._main_section_hits[i][int(section_indexes[i])] = 1 | ||||
| def _coverage_update(self, name, tensor, scaled_mean, scaled_rank, top_k, threshold): | |||||
| """ | |||||
| Update the coverage matrix of neural coverage and effective neural coverage. | |||||
| Args: | |||||
| name (string): the name of the tensor. | |||||
| tensor (tensor): the tensor in the network. | |||||
| scaled_mean (numpy.ndarray): feature map of the tensor. | |||||
| scaled_rank (numpy.ndarray): rank of tensor value. | |||||
| top_k (int): neuron is covered when its output has the top k largest value in that hidden layer. | |||||
| threshold (float): neuron is covered when its output is greater than the threshold. | |||||
| """ | |||||
| for num_neuron in range(tensor.shape[1]): | |||||
| if num_neuron >= (len(scaled_rank) - top_k) and not \ | |||||
| self._effective_model_layer_dict[(name, scaled_rank[num_neuron])]: | |||||
| self._effective_model_layer_dict[(name, scaled_rank[num_neuron])] = True | |||||
| if scaled_mean[num_neuron] > threshold and not \ | |||||
| self._model_layer_dict[(name, num_neuron)]: | |||||
| self._model_layer_dict[(name, num_neuron)] = True | |||||
| def calculate_coverage(self, dataset, bias_coefficient=0, batch_size=32): | def calculate_coverage(self, dataset, bias_coefficient=0, batch_size=32): | ||||
| """ | """ | ||||
| Calculate the testing adequacy of the given dataset. | Calculate the testing adequacy of the given dataset. | ||||
| @@ -143,8 +187,9 @@ class ModelCoverageMetrics: | |||||
| Examples: | Examples: | ||||
| >>> neuron_num = 10 | >>> neuron_num = 10 | ||||
| >>> segmented_num = 1000 | >>> segmented_num = 1000 | ||||
| >>> batch_size = 32 | |||||
| >>> model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images) | >>> model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images) | ||||
| >>> model_fuzz_test.calculate_coverage(test_images) | |||||
| >>> model_fuzz_test.calculate_coverage(test_images, top_k, threshold, batch_size) | |||||
| """ | """ | ||||
| dataset = check_numpy_param('dataset', dataset) | dataset = check_numpy_param('dataset', dataset) | ||||
| @@ -157,6 +202,79 @@ class ModelCoverageMetrics: | |||||
| for i in range(batches): | for i in range(batches): | ||||
| self._sections_hits_count(dataset[i*batch_size: (i + 1)*batch_size], intervals) | self._sections_hits_count(dataset[i*batch_size: (i + 1)*batch_size], intervals) | ||||
| def calculate_effective_coverage(self, dataset, top_k=3, threshold=0.1, batch_size=32): | |||||
| """ | |||||
| Calculate the effective testing adequacy of the given dataset. | |||||
| In effective neural coverage, neuron is covered when its output has the top k largest value | |||||
| in that hidden layers. In neural coverage, neuron is covered when its output is greater than the | |||||
| threshold. Coverage equals the covered neurons divided by the total neurons in the network. | |||||
| Args: | |||||
| threshold (float): neuron is covered when its output is greater than the threshold. | |||||
| top_k (int): neuron is covered when its output has the top k largest value in that hiddern layer. | |||||
| dataset (numpy.ndarray): Data for fuzz test. | |||||
| Examples: | |||||
| >>> neuron_num = 10 | |||||
| >>> segmented_num = 1000 | |||||
| >>> top_k = 3 | |||||
| >>> threshold = 0.1 | |||||
| >>> model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images) | |||||
| >>> model_fuzz_test.calculate_coverage(test_images) | |||||
| >>> model_fuzz_test.calculate_effective_coverage(test_images, top_k, threshold) | |||||
| """ | |||||
| top_k = check_int_positive('top_k', top_k) | |||||
| dataset = check_numpy_param('dataset', dataset) | |||||
| batch_size = check_int_positive('batch_size', batch_size) | |||||
| batches = dataset.shape[0] // batch_size | |||||
| self._set_init_effective_coverage_table(dataset) | |||||
| for i in range(batches): | |||||
| inputs = dataset[i*batch_size: (i + 1)*batch_size] | |||||
| self._model.predict(Tensor(inputs)).asnumpy() | |||||
| tensors = _get_summary_tensor_data() | |||||
| for name, tensor in tensors.items(): | |||||
| if 'input' in name: | |||||
| continue | |||||
| scaled = tensor.asnumpy()[-1] | |||||
| if scaled.ndim >= 3: # | |||||
| scaled_mean = np.mean(scaled, axis=(1, 2)) | |||||
| scaled_rank = np.argsort(scaled_mean) | |||||
| self._coverage_update(name, tensor, scaled_mean, scaled_rank, top_k, threshold) | |||||
| else: | |||||
| scaled_rank = np.argsort(scaled) | |||||
| self._coverage_update(name, tensor, scaled, scaled_rank, top_k, threshold) | |||||
| def get_nc(self): | |||||
| """ | |||||
| Get the metric of 'neuron coverage'. | |||||
| Returns: | |||||
| float, the metric of 'neuron coverage'. | |||||
| Examples: | |||||
| >>> model_fuzz_test.get_nc() | |||||
| """ | |||||
| covered_neurons = len([v for v in self._model_layer_dict.values() if v]) | |||||
| total_neurons = len(self._model_layer_dict) | |||||
| nc = covered_neurons / float(total_neurons) | |||||
| return nc | |||||
| def get_effective_nc(self): | |||||
| """ | |||||
| Get the metric of 'effective neuron coverage'. | |||||
| Returns: | |||||
| float, the metric of 'the effective neuron coverage'. | |||||
| Examples: | |||||
| >>> model_fuzz_test.get_effective_nc() | |||||
| """ | |||||
| covered_neurons = len([v for v in self._effective_model_layer_dict.values() if v]) | |||||
| total_neurons = len(self._effective_model_layer_dict) | |||||
| effective_nc = covered_neurons / float(total_neurons) | |||||
| return effective_nc | |||||
| def get_kmnc(self): | def get_kmnc(self): | ||||
| """ | """ | ||||
| Get the metric of 'k-multisection neuron coverage'. KMNC measures how | Get the metric of 'k-multisection neuron coverage'. KMNC measures how | ||||
| @@ -21,6 +21,7 @@ from mindspore import nn | |||||
| from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits | from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits | ||||
| from mindspore.train import Model | from mindspore.train import Model | ||||
| from mindspore import context | from mindspore import context | ||||
| from mindspore.ops import TensorSummary | |||||
| from mindarmour.adv_robustness.attacks import FastGradientSignMethod | from mindarmour.adv_robustness.attacks import FastGradientSignMethod | ||||
| from mindarmour.utils.logger import LogUtil | from mindarmour.utils.logger import LogUtil | ||||
| @@ -46,6 +47,7 @@ class Net(Cell): | |||||
| """ | """ | ||||
| super(Net, self).__init__() | super(Net, self).__init__() | ||||
| self._relu = nn.ReLU() | self._relu = nn.ReLU() | ||||
| self.summary = TensorSummary() | |||||
| def construct(self, inputs): | def construct(self, inputs): | ||||
| """ | """ | ||||
| @@ -54,7 +56,10 @@ class Net(Cell): | |||||
| Args: | Args: | ||||
| inputs (Tensor): Input data. | inputs (Tensor): Input data. | ||||
| """ | """ | ||||
| self.summary('input', inputs) | |||||
| out = self._relu(inputs) | out = self._relu(inputs) | ||||
| self.summary('1', out) | |||||
| return out | return out | ||||
| @@ -71,7 +76,10 @@ def test_lenet_mnist_coverage_cpu(): | |||||
| # initialize fuzz test with training dataset | # initialize fuzz test with training dataset | ||||
| neuron_num = 10 | neuron_num = 10 | ||||
| segmented_num = 1000 | segmented_num = 1000 | ||||
| top_k = 3 | |||||
| threshold = 0.1 | |||||
| training_data = (np.random.random((10000, 10))*20).astype(np.float32) | training_data = (np.random.random((10000, 10))*20).astype(np.float32) | ||||
| model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, training_data) | model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, training_data) | ||||
| # fuzz test with original test data | # fuzz test with original test data | ||||
| @@ -83,6 +91,10 @@ def test_lenet_mnist_coverage_cpu(): | |||||
| LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc()) | LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc()) | ||||
| LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac()) | LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac()) | ||||
| model_fuzz_test.calculate_effective_coverage(test_data, top_k, threshold) | |||||
| LOGGER.info(TAG, 'NC of this test is : %s', model_fuzz_test.get_nc()) | |||||
| LOGGER.info(TAG, 'Effective_NC of this test is : %s', model_fuzz_test.get_effective_nc()) | |||||
| # generate adv_data | # generate adv_data | ||||
| loss = SoftmaxCrossEntropyWithLogits(sparse=True) | loss = SoftmaxCrossEntropyWithLogits(sparse=True) | ||||
| attack = FastGradientSignMethod(net, eps=0.3, loss_fn=loss) | attack = FastGradientSignMethod(net, eps=0.3, loss_fn=loss) | ||||
| @@ -92,6 +104,9 @@ def test_lenet_mnist_coverage_cpu(): | |||||
| LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc()) | LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc()) | ||||
| LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac()) | LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac()) | ||||
| model_fuzz_test.calculate_effective_coverage(adv_data, top_k, threshold) | |||||
| LOGGER.info(TAG, 'NC of this test is : %s', model_fuzz_test.get_nc()) | |||||
| LOGGER.info(TAG, 'Effective_NC of this test is : %s', model_fuzz_test.get_effective_nc()) | |||||
| @pytest.mark.level0 | @pytest.mark.level0 | ||||
| @pytest.mark.platform_arm_ascend_training | @pytest.mark.platform_arm_ascend_training | ||||
| @@ -107,9 +122,10 @@ def test_lenet_mnist_coverage_ascend(): | |||||
| # initialize fuzz test with training dataset | # initialize fuzz test with training dataset | ||||
| neuron_num = 10 | neuron_num = 10 | ||||
| segmented_num = 1000 | segmented_num = 1000 | ||||
| top_k = 3 | |||||
| threshold = 0.1 | |||||
| training_data = (np.random.random((10000, 10))*20).astype(np.float32) | training_data = (np.random.random((10000, 10))*20).astype(np.float32) | ||||
| model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, training_data,) | |||||
| model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, training_data) | |||||
| # fuzz test with original test data | # fuzz test with original test data | ||||
| # get test data | # get test data | ||||
| @@ -121,6 +137,10 @@ def test_lenet_mnist_coverage_ascend(): | |||||
| LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc()) | LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc()) | ||||
| LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac()) | LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac()) | ||||
| model_fuzz_test.calculate_effective_coverage(test_data, top_k, threshold) | |||||
| LOGGER.info(TAG, 'NC of this test is : %s', model_fuzz_test.get_nc()) | |||||
| LOGGER.info(TAG, 'Effective_NC of this test is : %s', model_fuzz_test.get_effective_nc()) | |||||
| # generate adv_data | # generate adv_data | ||||
| attack = FastGradientSignMethod(net, eps=0.3, loss_fn=nn.SoftmaxCrossEntropyWithLogits(sparse=False)) | attack = FastGradientSignMethod(net, eps=0.3, loss_fn=nn.SoftmaxCrossEntropyWithLogits(sparse=False)) | ||||
| adv_data = attack.batch_generate(test_data, test_labels, batch_size=32) | adv_data = attack.batch_generate(test_data, test_labels, batch_size=32) | ||||
| @@ -128,3 +148,7 @@ def test_lenet_mnist_coverage_ascend(): | |||||
| LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc()) | LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc()) | ||||
| LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc()) | LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc()) | ||||
| LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac()) | LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac()) | ||||
| model_fuzz_test.calculate_effective_coverage(adv_data, top_k, threshold) | |||||
| LOGGER.info(TAG, 'NC of this test is : %s', model_fuzz_test.get_nc()) | |||||
| LOGGER.info(TAG, 'Effective_NC of this test is : %s', model_fuzz_test.get_effective_nc()) | |||||