Merge pull request !254 from RyanZ/test_coverage3tags/v1.5.0-rc1
@@ -1,24 +1,32 @@ | |||||
# Application demos of model fuzzing | # Application demos of model fuzzing | ||||
## Introduction | ## Introduction | ||||
The same as the traditional software fuzz testing, we can also design fuzz test for AI models. Compared to | The same as the traditional software fuzz testing, we can also design fuzz test for AI models. Compared to | ||||
branch coverage or line coverage of traditional software, some people propose the | branch coverage or line coverage of traditional software, some people propose the | ||||
concept of 'neuron coverage' based on the unique structure of deep neural network. We can use the neuron coverage | concept of 'neuron coverage' based on the unique structure of deep neural network. We can use the neuron coverage | ||||
as a guide to search more metamorphic inputs to test our models. | as a guide to search more metamorphic inputs to test our models. | ||||
## 1. calculation of neuron coverage | |||||
There are three metrics proposed for evaluating the neuron coverage of a test:KMNC, NBC and SNAC. Usually we need to | |||||
feed all the training dataset into the model first, and record the output range of all neurons (however, only the last | |||||
layer of neurons are recorded in our method). In the testing phase, we feed test samples into the model, and | |||||
calculate those three metrics mentioned above according to those neurons' output distribution. | |||||
## 1. calculation of neuron coverage | |||||
There are five metrics proposed for evaluating the neuron coverage of a test:NC, Effective NC, KMNC, NBC and SNAC. | |||||
Usually we need to feed all the training dataset into the model first, and record the output range of all neurons | |||||
(however, in KMNC, NBC and SNAC, only the last layer of neurons are recorded in our method). In the testing phase, | |||||
we feed test samples into the model, and calculate those three metrics mentioned above according to those neurons' | |||||
output distribution. | |||||
```sh | ```sh | ||||
$ cd examples/ai_fuzzer/ | |||||
$ python lenet5_mnist_coverage.py | |||||
cd examples/ai_fuzzer/ | |||||
python lenet5_mnist_coverage.py | |||||
``` | ``` | ||||
## 2. fuzz test for AI model | |||||
## 2. fuzz test for AI model | |||||
We have provided several types of methods for manipulating metamorphic inputs: affine transformation, pixel | We have provided several types of methods for manipulating metamorphic inputs: affine transformation, pixel | ||||
transformation and adversarial attacks. Usually we feed the original samples into the fuzz function as seeds, and | transformation and adversarial attacks. Usually we feed the original samples into the fuzz function as seeds, and | ||||
then metamorphic samples are generated through iterative manipulations. | then metamorphic samples are generated through iterative manipulations. | ||||
```sh | ```sh | ||||
$ cd examples/ai_fuzzer/ | |||||
$ python lenet5_mnist_fuzzing.py | |||||
cd examples/ai_fuzzer/ | |||||
python lenet5_mnist_fuzzing.py | |||||
``` | ``` |
@@ -31,7 +31,7 @@ from mindarmour.fuzz_testing import ModelCoverageMetrics | |||||
from mindarmour.utils.logger import LogUtil | from mindarmour.utils.logger import LogUtil | ||||
from examples.common.dataset.data_processing import generate_mnist_dataset | from examples.common.dataset.data_processing import generate_mnist_dataset | ||||
from examples.common.networks.lenet5.lenet5_net import LeNet5 | |||||
from examples.common.networks.lenet5.lenet5_net_for_fuzzing import LeNet5 | |||||
LOGGER = LogUtil.get_instance() | LOGGER = LogUtil.get_instance() | ||||
TAG = 'Fuzz_testing and enhance model' | TAG = 'Fuzz_testing and enhance model' | ||||
@@ -75,9 +75,11 @@ def example_lenet_mnist_fuzzing(): | |||||
images = data[0].astype(np.float32) | images = data[0].astype(np.float32) | ||||
train_images.append(images) | train_images.append(images) | ||||
train_images = np.concatenate(train_images, axis=0) | train_images = np.concatenate(train_images, axis=0) | ||||
neuron_num = 10 | |||||
segmented_num = 1000 | |||||
# initialize fuzz test with training dataset | # initialize fuzz test with training dataset | ||||
model_coverage_test = ModelCoverageMetrics(model, 10, 1000, train_images) | |||||
model_coverage_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images) | |||||
# fuzz test with original test data | # fuzz test with original test data | ||||
# get test data | # get test data | ||||
@@ -22,7 +22,7 @@ from mindarmour.fuzz_testing import ModelCoverageMetrics | |||||
from mindarmour.utils.logger import LogUtil | from mindarmour.utils.logger import LogUtil | ||||
from examples.common.dataset.data_processing import generate_mnist_dataset | from examples.common.dataset.data_processing import generate_mnist_dataset | ||||
from examples.common.networks.lenet5.lenet5_net import LeNet5 | |||||
from examples.common.networks.lenet5.lenet5_net_for_fuzzing import LeNet5 | |||||
LOGGER = LogUtil.get_instance() | LOGGER = LogUtil.get_instance() | ||||
TAG = 'Neuron coverage test' | TAG = 'Neuron coverage test' | ||||
@@ -46,9 +46,13 @@ def test_lenet_mnist_coverage(): | |||||
images = data[0].astype(np.float32) | images = data[0].astype(np.float32) | ||||
train_images.append(images) | train_images.append(images) | ||||
train_images = np.concatenate(train_images, axis=0) | train_images = np.concatenate(train_images, axis=0) | ||||
neuron_num = 10 | |||||
segmented_num = 1000 | |||||
top_k = 3 | |||||
threshold = 0.1 | |||||
# initialize fuzz test with training dataset | # initialize fuzz test with training dataset | ||||
model_fuzz_test = ModelCoverageMetrics(model, 10, 1000, train_images) | |||||
model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images) | |||||
# fuzz test with original test data | # fuzz test with original test data | ||||
# get test data | # get test data | ||||
@@ -69,6 +73,10 @@ def test_lenet_mnist_coverage(): | |||||
LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc()) | LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc()) | ||||
LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac()) | LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac()) | ||||
model_fuzz_test.calculate_effective_coverage(test_images, top_k, threshold) | |||||
LOGGER.info(TAG, 'NC of this test is : %s', model_fuzz_test.get_nc()) | |||||
LOGGER.info(TAG, 'Effective_NC of this test is : %s', model_fuzz_test.get_effective_nc()) | |||||
# generate adv_data | # generate adv_data | ||||
loss = SoftmaxCrossEntropyWithLogits(sparse=True) | loss = SoftmaxCrossEntropyWithLogits(sparse=True) | ||||
attack = FastGradientSignMethod(net, eps=0.3, loss_fn=loss) | attack = FastGradientSignMethod(net, eps=0.3, loss_fn=loss) | ||||
@@ -78,6 +86,10 @@ def test_lenet_mnist_coverage(): | |||||
LOGGER.info(TAG, 'NBC of this adv data is : %s', model_fuzz_test.get_nbc()) | LOGGER.info(TAG, 'NBC of this adv data is : %s', model_fuzz_test.get_nbc()) | ||||
LOGGER.info(TAG, 'SNAC of this adv data is : %s', model_fuzz_test.get_snac()) | LOGGER.info(TAG, 'SNAC of this adv data is : %s', model_fuzz_test.get_snac()) | ||||
model_fuzz_test.calculate_effective_coverage(adv_data, top_k, threshold) | |||||
LOGGER.info(TAG, 'NC of this test is : %s', model_fuzz_test.get_nc()) | |||||
LOGGER.info(TAG, 'Effective_NC of this test is : %s', model_fuzz_test.get_effective_nc()) | |||||
if __name__ == '__main__': | if __name__ == '__main__': | ||||
# device_target can be "CPU", "GPU" or "Ascend" | # device_target can be "CPU", "GPU" or "Ascend" | ||||
@@ -21,7 +21,7 @@ from mindarmour.fuzz_testing import ModelCoverageMetrics | |||||
from mindarmour.utils.logger import LogUtil | from mindarmour.utils.logger import LogUtil | ||||
from examples.common.dataset.data_processing import generate_mnist_dataset | from examples.common.dataset.data_processing import generate_mnist_dataset | ||||
from examples.common.networks.lenet5.lenet5_net import LeNet5 | |||||
from examples.common.networks.lenet5.lenet5_net_for_fuzzing import LeNet5 | |||||
LOGGER = LogUtil.get_instance() | LOGGER = LogUtil.get_instance() | ||||
TAG = 'Fuzz_test' | TAG = 'Fuzz_test' | ||||
@@ -0,0 +1,99 @@ | |||||
# Copyright 2021 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
""" | |||||
lenet network with summary | |||||
""" | |||||
from mindspore import nn | |||||
from mindspore.common.initializer import TruncatedNormal | |||||
from mindspore.ops import TensorSummary | |||||
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0): | |||||
weight = weight_variable() | |||||
return nn.Conv2d(in_channels, out_channels, | |||||
kernel_size=kernel_size, stride=stride, padding=padding, | |||||
weight_init=weight, has_bias=False, pad_mode="valid") | |||||
def fc_with_initialize(input_channels, out_channels): | |||||
weight = weight_variable() | |||||
bias = weight_variable() | |||||
return nn.Dense(input_channels, out_channels, weight, bias) | |||||
def weight_variable(): | |||||
return TruncatedNormal(0.05) | |||||
class LeNet5(nn.Cell): | |||||
""" | |||||
Lenet network | |||||
""" | |||||
def __init__(self): | |||||
super(LeNet5, self).__init__() | |||||
self.conv1 = conv(1, 6, 5) | |||||
self.conv2 = conv(6, 16, 5) | |||||
self.fc1 = fc_with_initialize(16*5*5, 120) | |||||
self.fc2 = fc_with_initialize(120, 84) | |||||
self.fc3 = fc_with_initialize(84, 10) | |||||
self.relu = nn.ReLU() | |||||
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) | |||||
self.flatten = nn.Flatten() | |||||
self.summary = TensorSummary() | |||||
def construct(self, x): | |||||
""" | |||||
construct the network architecture | |||||
Returns: | |||||
x (tensor): network output | |||||
""" | |||||
self.summary('input', x) | |||||
x = self.conv1(x) | |||||
self.summary('1', x) | |||||
x = self.relu(x) | |||||
self.summary('2', x) | |||||
x = self.max_pool2d(x) | |||||
self.summary('3', x) | |||||
x = self.conv2(x) | |||||
self.summary('4', x) | |||||
x = self.relu(x) | |||||
self.summary('5', x) | |||||
x = self.max_pool2d(x) | |||||
self.summary('6', x) | |||||
x = self.flatten(x) | |||||
self.summary('7', x) | |||||
x = self.fc1(x) | |||||
self.summary('8', x) | |||||
x = self.relu(x) | |||||
self.summary('9', x) | |||||
x = self.fc2(x) | |||||
self.summary('10', x) | |||||
x = self.relu(x) | |||||
self.summary('11', x) | |||||
x = self.fc3(x) | |||||
self.summary('output', x) | |||||
return x |
@@ -15,10 +15,12 @@ | |||||
Model-Test Coverage Metrics. | Model-Test Coverage Metrics. | ||||
""" | """ | ||||
from collections import defaultdict | |||||
import numpy as np | import numpy as np | ||||
from mindspore import Tensor | from mindspore import Tensor | ||||
from mindspore import Model | from mindspore import Model | ||||
from mindspore.train.summary.summary_record import _get_summary_tensor_data | |||||
from mindarmour.utils._check_param import check_model, check_numpy_param, \ | from mindarmour.utils._check_param import check_model, check_numpy_param, \ | ||||
check_int_positive, check_param_multi_types | check_int_positive, check_param_multi_types | ||||
@@ -63,6 +65,9 @@ class ModelCoverageMetrics: | |||||
>>> print('KMNC of this test is : %s', model_fuzz_test.get_kmnc()) | >>> print('KMNC of this test is : %s', model_fuzz_test.get_kmnc()) | ||||
>>> print('NBC of this test is : %s', model_fuzz_test.get_nbc()) | >>> print('NBC of this test is : %s', model_fuzz_test.get_nbc()) | ||||
>>> print('SNAC of this test is : %s', model_fuzz_test.get_snac()) | >>> print('SNAC of this test is : %s', model_fuzz_test.get_snac()) | ||||
>>> model_fuzz_test.calculate_effective_coverage(test_images, top_k, threshold) | |||||
>>> print('NC of this test is : %s', model_fuzz_test.get_nc()) | |||||
>>> print('Effective_NC of this test is : %s', model_fuzz_test.get_effective_nc()) | |||||
""" | """ | ||||
def __init__(self, model, neuron_num, segmented_num, train_dataset): | def __init__(self, model, neuron_num, segmented_num, train_dataset): | ||||
@@ -81,6 +86,24 @@ class ModelCoverageMetrics: | |||||
self._lower_corner_hits = [0]*self._neuron_num | self._lower_corner_hits = [0]*self._neuron_num | ||||
self._upper_corner_hits = [0]*self._neuron_num | self._upper_corner_hits = [0]*self._neuron_num | ||||
self._bounds_get(train_dataset) | self._bounds_get(train_dataset) | ||||
self._model_layer_dict = defaultdict(bool) | |||||
self._effective_model_layer_dict = defaultdict(bool) | |||||
def _set_init_effective_coverage_table(self, dataset): | |||||
""" | |||||
Initialise the coverage table of each neuron in the model. | |||||
Args: | |||||
dataset (numpy.ndarray): Dataset used for initialising the coverage table. | |||||
""" | |||||
self._model.predict(Tensor(dataset[0:1])) | |||||
tensors = _get_summary_tensor_data() | |||||
for name, tensor in tensors.items(): | |||||
if 'input' in name: | |||||
continue | |||||
for num_neuron in range(tensor.shape[1]): | |||||
self._model_layer_dict[(name, num_neuron)] = False | |||||
self._effective_model_layer_dict[(name, num_neuron)] = False | |||||
def _bounds_get(self, train_dataset, batch_size=32): | def _bounds_get(self, train_dataset, batch_size=32): | ||||
""" | """ | ||||
@@ -130,6 +153,27 @@ class ModelCoverageMetrics: | |||||
else: | else: | ||||
self._main_section_hits[i][int(section_indexes[i])] = 1 | self._main_section_hits[i][int(section_indexes[i])] = 1 | ||||
def _coverage_update(self, name, tensor, scaled_mean, scaled_rank, top_k, threshold): | |||||
""" | |||||
Update the coverage matrix of neural coverage and effective neural coverage. | |||||
Args: | |||||
name (string): the name of the tensor. | |||||
tensor (tensor): the tensor in the network. | |||||
scaled_mean (numpy.ndarray): feature map of the tensor. | |||||
scaled_rank (numpy.ndarray): rank of tensor value. | |||||
top_k (int): neuron is covered when its output has the top k largest value in that hidden layer. | |||||
threshold (float): neuron is covered when its output is greater than the threshold. | |||||
""" | |||||
for num_neuron in range(tensor.shape[1]): | |||||
if num_neuron >= (len(scaled_rank) - top_k) and not \ | |||||
self._effective_model_layer_dict[(name, scaled_rank[num_neuron])]: | |||||
self._effective_model_layer_dict[(name, scaled_rank[num_neuron])] = True | |||||
if scaled_mean[num_neuron] > threshold and not \ | |||||
self._model_layer_dict[(name, num_neuron)]: | |||||
self._model_layer_dict[(name, num_neuron)] = True | |||||
def calculate_coverage(self, dataset, bias_coefficient=0, batch_size=32): | def calculate_coverage(self, dataset, bias_coefficient=0, batch_size=32): | ||||
""" | """ | ||||
Calculate the testing adequacy of the given dataset. | Calculate the testing adequacy of the given dataset. | ||||
@@ -143,8 +187,9 @@ class ModelCoverageMetrics: | |||||
Examples: | Examples: | ||||
>>> neuron_num = 10 | >>> neuron_num = 10 | ||||
>>> segmented_num = 1000 | >>> segmented_num = 1000 | ||||
>>> batch_size = 32 | |||||
>>> model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images) | >>> model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images) | ||||
>>> model_fuzz_test.calculate_coverage(test_images) | |||||
>>> model_fuzz_test.calculate_coverage(test_images, top_k, threshold, batch_size) | |||||
""" | """ | ||||
dataset = check_numpy_param('dataset', dataset) | dataset = check_numpy_param('dataset', dataset) | ||||
@@ -157,6 +202,79 @@ class ModelCoverageMetrics: | |||||
for i in range(batches): | for i in range(batches): | ||||
self._sections_hits_count(dataset[i*batch_size: (i + 1)*batch_size], intervals) | self._sections_hits_count(dataset[i*batch_size: (i + 1)*batch_size], intervals) | ||||
def calculate_effective_coverage(self, dataset, top_k=3, threshold=0.1, batch_size=32): | |||||
""" | |||||
Calculate the effective testing adequacy of the given dataset. | |||||
In effective neural coverage, neuron is covered when its output has the top k largest value | |||||
in that hidden layers. In neural coverage, neuron is covered when its output is greater than the | |||||
threshold. Coverage equals the covered neurons divided by the total neurons in the network. | |||||
Args: | |||||
threshold (float): neuron is covered when its output is greater than the threshold. | |||||
top_k (int): neuron is covered when its output has the top k largest value in that hiddern layer. | |||||
dataset (numpy.ndarray): Data for fuzz test. | |||||
Examples: | |||||
>>> neuron_num = 10 | |||||
>>> segmented_num = 1000 | |||||
>>> top_k = 3 | |||||
>>> threshold = 0.1 | |||||
>>> model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images) | |||||
>>> model_fuzz_test.calculate_coverage(test_images) | |||||
>>> model_fuzz_test.calculate_effective_coverage(test_images, top_k, threshold) | |||||
""" | |||||
top_k = check_int_positive('top_k', top_k) | |||||
dataset = check_numpy_param('dataset', dataset) | |||||
batch_size = check_int_positive('batch_size', batch_size) | |||||
batches = dataset.shape[0] // batch_size | |||||
self._set_init_effective_coverage_table(dataset) | |||||
for i in range(batches): | |||||
inputs = dataset[i*batch_size: (i + 1)*batch_size] | |||||
self._model.predict(Tensor(inputs)).asnumpy() | |||||
tensors = _get_summary_tensor_data() | |||||
for name, tensor in tensors.items(): | |||||
if 'input' in name: | |||||
continue | |||||
scaled = tensor.asnumpy()[-1] | |||||
if scaled.ndim >= 3: # | |||||
scaled_mean = np.mean(scaled, axis=(1, 2)) | |||||
scaled_rank = np.argsort(scaled_mean) | |||||
self._coverage_update(name, tensor, scaled_mean, scaled_rank, top_k, threshold) | |||||
else: | |||||
scaled_rank = np.argsort(scaled) | |||||
self._coverage_update(name, tensor, scaled, scaled_rank, top_k, threshold) | |||||
def get_nc(self): | |||||
""" | |||||
Get the metric of 'neuron coverage'. | |||||
Returns: | |||||
float, the metric of 'neuron coverage'. | |||||
Examples: | |||||
>>> model_fuzz_test.get_nc() | |||||
""" | |||||
covered_neurons = len([v for v in self._model_layer_dict.values() if v]) | |||||
total_neurons = len(self._model_layer_dict) | |||||
nc = covered_neurons / float(total_neurons) | |||||
return nc | |||||
def get_effective_nc(self): | |||||
""" | |||||
Get the metric of 'effective neuron coverage'. | |||||
Returns: | |||||
float, the metric of 'the effective neuron coverage'. | |||||
Examples: | |||||
>>> model_fuzz_test.get_effective_nc() | |||||
""" | |||||
covered_neurons = len([v for v in self._effective_model_layer_dict.values() if v]) | |||||
total_neurons = len(self._effective_model_layer_dict) | |||||
effective_nc = covered_neurons / float(total_neurons) | |||||
return effective_nc | |||||
def get_kmnc(self): | def get_kmnc(self): | ||||
""" | """ | ||||
Get the metric of 'k-multisection neuron coverage'. KMNC measures how | Get the metric of 'k-multisection neuron coverage'. KMNC measures how | ||||
@@ -21,6 +21,7 @@ from mindspore import nn | |||||
from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits | from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits | ||||
from mindspore.train import Model | from mindspore.train import Model | ||||
from mindspore import context | from mindspore import context | ||||
from mindspore.ops import TensorSummary | |||||
from mindarmour.adv_robustness.attacks import FastGradientSignMethod | from mindarmour.adv_robustness.attacks import FastGradientSignMethod | ||||
from mindarmour.utils.logger import LogUtil | from mindarmour.utils.logger import LogUtil | ||||
@@ -46,6 +47,7 @@ class Net(Cell): | |||||
""" | """ | ||||
super(Net, self).__init__() | super(Net, self).__init__() | ||||
self._relu = nn.ReLU() | self._relu = nn.ReLU() | ||||
self.summary = TensorSummary() | |||||
def construct(self, inputs): | def construct(self, inputs): | ||||
""" | """ | ||||
@@ -54,7 +56,10 @@ class Net(Cell): | |||||
Args: | Args: | ||||
inputs (Tensor): Input data. | inputs (Tensor): Input data. | ||||
""" | """ | ||||
self.summary('input', inputs) | |||||
out = self._relu(inputs) | out = self._relu(inputs) | ||||
self.summary('1', out) | |||||
return out | return out | ||||
@@ -71,7 +76,10 @@ def test_lenet_mnist_coverage_cpu(): | |||||
# initialize fuzz test with training dataset | # initialize fuzz test with training dataset | ||||
neuron_num = 10 | neuron_num = 10 | ||||
segmented_num = 1000 | segmented_num = 1000 | ||||
top_k = 3 | |||||
threshold = 0.1 | |||||
training_data = (np.random.random((10000, 10))*20).astype(np.float32) | training_data = (np.random.random((10000, 10))*20).astype(np.float32) | ||||
model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, training_data) | model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, training_data) | ||||
# fuzz test with original test data | # fuzz test with original test data | ||||
@@ -83,6 +91,10 @@ def test_lenet_mnist_coverage_cpu(): | |||||
LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc()) | LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc()) | ||||
LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac()) | LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac()) | ||||
model_fuzz_test.calculate_effective_coverage(test_data, top_k, threshold) | |||||
LOGGER.info(TAG, 'NC of this test is : %s', model_fuzz_test.get_nc()) | |||||
LOGGER.info(TAG, 'Effective_NC of this test is : %s', model_fuzz_test.get_effective_nc()) | |||||
# generate adv_data | # generate adv_data | ||||
loss = SoftmaxCrossEntropyWithLogits(sparse=True) | loss = SoftmaxCrossEntropyWithLogits(sparse=True) | ||||
attack = FastGradientSignMethod(net, eps=0.3, loss_fn=loss) | attack = FastGradientSignMethod(net, eps=0.3, loss_fn=loss) | ||||
@@ -92,6 +104,9 @@ def test_lenet_mnist_coverage_cpu(): | |||||
LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc()) | LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc()) | ||||
LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac()) | LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac()) | ||||
model_fuzz_test.calculate_effective_coverage(adv_data, top_k, threshold) | |||||
LOGGER.info(TAG, 'NC of this test is : %s', model_fuzz_test.get_nc()) | |||||
LOGGER.info(TAG, 'Effective_NC of this test is : %s', model_fuzz_test.get_effective_nc()) | |||||
@pytest.mark.level0 | @pytest.mark.level0 | ||||
@pytest.mark.platform_arm_ascend_training | @pytest.mark.platform_arm_ascend_training | ||||
@@ -107,9 +122,10 @@ def test_lenet_mnist_coverage_ascend(): | |||||
# initialize fuzz test with training dataset | # initialize fuzz test with training dataset | ||||
neuron_num = 10 | neuron_num = 10 | ||||
segmented_num = 1000 | segmented_num = 1000 | ||||
top_k = 3 | |||||
threshold = 0.1 | |||||
training_data = (np.random.random((10000, 10))*20).astype(np.float32) | training_data = (np.random.random((10000, 10))*20).astype(np.float32) | ||||
model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, training_data,) | |||||
model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, training_data) | |||||
# fuzz test with original test data | # fuzz test with original test data | ||||
# get test data | # get test data | ||||
@@ -121,6 +137,10 @@ def test_lenet_mnist_coverage_ascend(): | |||||
LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc()) | LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc()) | ||||
LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac()) | LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac()) | ||||
model_fuzz_test.calculate_effective_coverage(test_data, top_k, threshold) | |||||
LOGGER.info(TAG, 'NC of this test is : %s', model_fuzz_test.get_nc()) | |||||
LOGGER.info(TAG, 'Effective_NC of this test is : %s', model_fuzz_test.get_effective_nc()) | |||||
# generate adv_data | # generate adv_data | ||||
attack = FastGradientSignMethod(net, eps=0.3, loss_fn=nn.SoftmaxCrossEntropyWithLogits(sparse=False)) | attack = FastGradientSignMethod(net, eps=0.3, loss_fn=nn.SoftmaxCrossEntropyWithLogits(sparse=False)) | ||||
adv_data = attack.batch_generate(test_data, test_labels, batch_size=32) | adv_data = attack.batch_generate(test_data, test_labels, batch_size=32) | ||||
@@ -128,3 +148,7 @@ def test_lenet_mnist_coverage_ascend(): | |||||
LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc()) | LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc()) | ||||
LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc()) | LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc()) | ||||
LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac()) | LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac()) | ||||
model_fuzz_test.calculate_effective_coverage(adv_data, top_k, threshold) | |||||
LOGGER.info(TAG, 'NC of this test is : %s', model_fuzz_test.get_nc()) | |||||
LOGGER.info(TAG, 'Effective_NC of this test is : %s', model_fuzz_test.get_effective_nc()) |