Merge pull request !12 from jxlang910/mastertags/v0.2.0-alpha
@@ -0,0 +1,89 @@ | |||
# Copyright 2019 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import sys | |||
import numpy as np | |||
from mindspore import Model | |||
from mindspore import context | |||
from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
from mindspore.nn import SoftmaxCrossEntropyWithLogits | |||
from mindarmour.attacks.gradient_method import FastGradientSignMethod | |||
from mindarmour.utils.logger import LogUtil | |||
from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics | |||
from lenet5_net import LeNet5 | |||
sys.path.append("..") | |||
from data_processing import generate_mnist_dataset | |||
# Shared logging setup for this example: TAG is passed as the first argument
# of every LOGGER.info call below, and the logger is switched to INFO level.
TAG = 'Neuron coverage test'
LOGGER = LogUtil.get_instance()
LOGGER.set_level('INFO')
def test_lenet_mnist_coverage():
    """
    Demonstrate neuron-coverage measurement on a trained LeNet5 MNIST model.

    The neurons' output boundaries are profiled on the training images, then
    KMNC / NBC / SNAC are logged twice: once after feeding the original test
    images and once more after feeding FGSM adversarial examples generated
    from those test images.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")

    # restore the trained weights into a fresh LeNet5 and wrap it in a Model
    net = LeNet5()
    load_param_into_net(
        net,
        load_checkpoint('./trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'))
    model = Model(net)

    batch_size = 32

    # training images: only the first column of each batch is needed
    train_ds = generate_mnist_dataset("./MNIST_unzip/train", batch_size,
                                      sparse=True)
    train_images = np.concatenate(
        [batch[0].astype(np.float32)
         for batch in train_ds.create_tuple_iterator()],
        axis=0)

    # the training data determines the neurons' output boundaries
    model_fuzz_test = ModelCoverageMetrics(model, 10000, 10, train_images)

    # test images (column 0) and their labels (column 1)
    test_ds = generate_mnist_dataset("./MNIST_unzip/test", batch_size,
                                     sparse=True)
    image_batches = []
    label_batches = []
    for batch in test_ds.create_tuple_iterator():
        image_batches.append(batch[0].astype(np.float32))
        label_batches.append(batch[1])
    test_images = np.concatenate(image_batches, axis=0)
    test_labels = np.concatenate(label_batches, axis=0)

    # coverage reached by the original test data
    model_fuzz_test.test_adequacy_coverage_calculate(test_images)
    LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc())
    LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc())
    LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac())

    # coverage after additionally feeding FGSM adversarial examples
    attack = FastGradientSignMethod(
        net,
        eps=0.3,
        loss_fn=SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True))
    adv_data = attack.batch_generate(test_images, test_labels, batch_size=32)
    model_fuzz_test.test_adequacy_coverage_calculate(adv_data,
                                                     bias_coefficient=0.5)
    LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc())
    LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc())
    LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac())


if __name__ == '__main__':
    test_lenet_mnist_coverage()
@@ -0,0 +1,3 @@ | |||
from .model_coverage_metrics import ModelCoverageMetrics | |||
__all__ = ['ModelCoverageMetrics'] |
@@ -0,0 +1,167 @@ | |||
# Copyright 2019 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
""" | |||
Model-Test Coverage Metrics. | |||
""" | |||
import numpy as np | |||
from mindspore import Tensor | |||
from mindspore import Model | |||
from mindarmour.utils._check_param import check_model, check_numpy_param, \ | |||
check_int_positive | |||
class ModelCoverageMetrics:
    """
    Evaluate the testing adequacy of a model fuzz test.

    Reference: `DeepGauge: Multi-Granularity Testing Criteria for Deep
    Learning Systems <https://arxiv.org/abs/1803.07519>`_

    Args:
        model (Model): The pre-trained model which is waiting for testing.
        k (int): The number of segmented sections of neurons' output intervals.
        n (int): The number of testing neurons. It must equal the number of
            columns of the model's output.
        train_dataset (numpy.ndarray): Training dataset used to determine
            the neurons' output boundaries.

    Examples:
        >>> model_fuzz_test = ModelCoverageMetrics(model, 10000, 10, train_images)
        >>> model_fuzz_test.test_adequacy_coverage_calculate(test_images)
        >>> model_fuzz_test.get_kmnc()
    """

    def __init__(self, model, k, n, train_dataset):
        self._model = check_model('model', model, Model)
        # k and n size the hit matrices and appear as divisors in the
        # metrics, so they are validated the same way batch_size is.
        self._k = check_int_positive('k', k)
        self._n = check_int_positive('n', n)
        train_dataset = check_numpy_param('train_dataset', train_dataset)
        self._lower_bounds = np.full(self._n, np.inf)
        self._upper_bounds = np.full(self._n, -np.inf)
        # NOTE: despite its name, _var holds the standard deviation (np.std)
        # of each neuron's output over the training data.
        self._var = np.zeros(self._n)
        # hit flags: [neuron, section] for the main range, [neuron] for the
        # regions below / above the profiled bounds
        self._main_section_hits = np.zeros((self._n, self._k), dtype=bool)
        self._lower_corner_hits = np.zeros(self._n, dtype=bool)
        self._upper_corner_hits = np.zeros(self._n, dtype=bool)
        self._bounds_get(train_dataset)

    def _bounds_get(self, train_dataset, batch_size=32):
        """
        Update the lower and upper boundaries of neurons' outputs.

        Args:
            train_dataset (numpy.ndarray): Training dataset used to
                determine the neurons' output boundaries.
            batch_size (int): The number of samples in a predict batch.
                Default: 32.

        Raises:
            ValueError: If train_dataset holds no sample, or the model
                output does not have exactly `n` columns.
        """
        batch_size = check_int_positive('batch_size', batch_size)
        output_mat = []
        # Step through the data so that the last, possibly smaller, batch is
        # profiled too instead of being silently dropped.
        for start in range(0, train_dataset.shape[0], batch_size):
            inputs = train_dataset[start: start + batch_size]
            output_mat.append(self._model.predict(Tensor(inputs)).asnumpy())
        if not output_mat:
            raise ValueError('train_dataset must contain at least one sample.')
        outputs = np.concatenate(output_mat, axis=0)
        if outputs.ndim != 2 or outputs.shape[1] != self._n:
            raise ValueError(
                'The model output must have shape (samples, {}), but got '
                '{}.'.format(self._n, outputs.shape))
        self._lower_bounds = np.min(outputs, axis=0).astype(np.float64)
        self._upper_bounds = np.max(outputs, axis=0).astype(np.float64)
        self._var = np.std(outputs, axis=0)

    def _hits_update(self, batch_output, intervals, lower_bounds,
                     upper_bounds):
        """
        Mark the sections / corner regions reached by a batch of outputs.

        Args:
            batch_output (numpy.ndarray): Neurons' outputs, one row per
                sample and one column per neuron.
            intervals (numpy.ndarray): Width of one section per neuron.
            lower_bounds (numpy.ndarray): Lower boundary per neuron.
            upper_bounds (numpy.ndarray): Upper boundary per neuron.
        """
        outputs = np.asarray(batch_output, dtype=np.float64)
        lower = np.asarray(lower_bounds, dtype=np.float64)
        upper = np.asarray(upper_bounds, dtype=np.float64)
        widths = np.asarray(intervals, dtype=np.float64)

        # Corner regions are strictly outside [lower, upper]; an output that
        # equals a bound still belongs to the main range.
        below = outputs < lower
        above = outputs > upper
        self._lower_corner_hits = np.logical_or(self._lower_corner_hits,
                                                below.any(axis=0))
        self._upper_corner_hits = np.logical_or(self._upper_corner_hits,
                                                above.any(axis=0))

        # NaN outputs fail both comparisons and are therefore ignored.
        rows, cols = np.nonzero((outputs >= lower) & (outputs <= upper))
        if rows.size == 0:
            return
        # A neuron whose bounds coincide has a zero-width section; divide by
        # 1 instead so that its only in-range value lands in section 0.
        safe_widths = np.where(widths > 0, widths, 1.0)
        sections = np.floor_divide(outputs[rows, cols] - lower[cols],
                                   safe_widths[cols])
        # Clipping folds output == upper (index k) and float round-off into
        # the valid section range [0, k - 1].
        sections = np.clip(sections, 0, self._k - 1).astype(np.int64)
        self._main_section_hits[cols, sections] = True

    def _sections_hits_count(self, dataset, intervals, lower_bounds=None,
                             upper_bounds=None):
        """
        Update the coverage matrix of neurons' output subsections.

        Args:
            dataset (numpy.ndarray): Testing data.
            intervals (list[float]): Segmentation intervals of neurons'
                outputs.
            lower_bounds (numpy.ndarray): Lower boundaries to compare
                against. Default: None, which uses the training boundaries.
            upper_bounds (numpy.ndarray): Upper boundaries to compare
                against. Default: None, which uses the training boundaries.
        """
        dataset = check_numpy_param('dataset', dataset)
        if lower_bounds is None:
            lower_bounds = self._lower_bounds
        if upper_bounds is None:
            upper_bounds = self._upper_bounds
        batch_output = self._model.predict(Tensor(dataset)).asnumpy()
        self._hits_update(batch_output, intervals, lower_bounds, upper_bounds)

    def test_adequacy_coverage_calculate(self, dataset, bias_coefficient=0,
                                         batch_size=32):
        """
        Calculate the testing adequacy of the given dataset.

        Hits accumulate over successive calls. The boundaries used for one
        call are the training boundaries widened by
        `bias_coefficient * std`; the stored training boundaries are left
        untouched, so the widening does not add up over repeated calls.

        Args:
            dataset (numpy.ndarray): Data for fuzz test.
            bias_coefficient (float): The coefficient used for changing the
                neurons' output boundaries. Default: 0.
            batch_size (int): The number of samples in a predict batch.
                Default: 32.

        Examples:
            >>> model_fuzz_test = ModelCoverageMetrics(model, 10000, 10, train_images)
            >>> model_fuzz_test.test_adequacy_coverage_calculate(test_images)
        """
        dataset = check_numpy_param('dataset', dataset)
        batch_size = check_int_positive('batch_size', batch_size)
        margin = bias_coefficient*np.asarray(self._var, dtype=np.float64)
        lower_bounds = self._lower_bounds - margin
        upper_bounds = self._upper_bounds + margin
        intervals = (upper_bounds - lower_bounds) / self._k
        # range(..., batch_size) also covers the trailing partial batch
        for start in range(0, dataset.shape[0], batch_size):
            self._sections_hits_count(dataset[start: start + batch_size],
                                      intervals, lower_bounds, upper_bounds)

    def get_kmnc(self):
        """
        Get the metric of 'k-multisection neuron coverage'.

        Returns:
            float, the metric of 'k-multisection neuron coverage'.

        Examples:
            >>> model_fuzz_test.get_kmnc()
        """
        kmnc = np.sum(self._main_section_hits) / (self._n*self._k)
        return kmnc

    def get_nbc(self):
        """
        Get the metric of 'neuron boundary coverage'.

        Returns:
            float, the metric of 'neuron boundary coverage'.

        Examples:
            >>> model_fuzz_test.get_nbc()
        """
        nbc = (np.sum(self._lower_corner_hits) + np.sum(
            self._upper_corner_hits)) / (2*self._n)
        return nbc

    def get_snac(self):
        """
        Get the metric of 'strong neuron activation coverage'.

        Returns:
            float, the metric of 'strong neuron activation coverage'.

        Examples:
            >>> model_fuzz_test.get_snac()
        """
        snac = np.sum(self._upper_corner_hits) / self._n
        return snac
@@ -0,0 +1,128 @@ | |||
# Copyright 2019 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
""" | |||
Model-fuzz coverage test. | |||
""" | |||
import numpy as np | |||
import pytest | |||
from mindspore.train import Model | |||
import mindspore.nn as nn | |||
from mindspore.nn import Cell | |||
from mindspore import context | |||
from mindspore.nn import SoftmaxCrossEntropyWithLogits | |||
from mindarmour.attacks.gradient_method import FastGradientSignMethod | |||
from mindarmour.utils.logger import LogUtil | |||
from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics | |||
# Shared logging setup for the coverage tests: TAG is passed as the first
# argument of every LOGGER.info call below, and the level is set to INFO.
TAG = 'Neuron coverage test'
LOGGER = LogUtil.get_instance()
LOGGER.set_level('INFO')
# for user | |||
class Net(Cell):
    """
    Minimal target network for the coverage tests: a single ReLU layer.

    Examples:
        >>> net = Net()
    """

    def __init__(self):
        """
        Create the ReLU layer used by `construct`.
        """
        super().__init__()
        self._relu = nn.ReLU()

    def construct(self, inputs):
        """
        Apply ReLU to the input.

        Args:
            inputs (Tensor): Input data.

        Returns:
            The result of applying the ReLU layer to `inputs`.
        """
        return self._relu(inputs)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_lenet_mnist_coverage_cpu():
    """
    Coverage metrics on CPU: values stay in [0, 1] and never decrease when
    more data (here FGSM adversarial examples) is fed to the same instance.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    # local, seeded generator: reproducible data without touching the global
    # numpy random state
    rng = np.random.RandomState(0)

    # load network
    net = Net()
    model = Model(net)

    # initialize fuzz test with training dataset
    training_data = (rng.random_sample((10000, 10))*20).astype(np.float32)
    model_fuzz_test = ModelCoverageMetrics(model, 10000, 10, training_data)

    # fuzz test with original test data
    test_data = (rng.random_sample((2000, 10))*20).astype(np.float32)
    test_labels = rng.randint(0, 10, 2000).astype(np.int32)
    model_fuzz_test.test_adequacy_coverage_calculate(test_data)
    clean_metrics = (model_fuzz_test.get_kmnc(),
                     model_fuzz_test.get_nbc(),
                     model_fuzz_test.get_snac())
    LOGGER.info(TAG, 'KMNC of this test is : %s', clean_metrics[0])
    LOGGER.info(TAG, 'NBC of this test is : %s', clean_metrics[1])
    LOGGER.info(TAG, 'SNAC of this test is : %s', clean_metrics[2])
    for value in clean_metrics:
        assert 0 <= value <= 1

    # generate adv_data
    loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
    attack = FastGradientSignMethod(net, eps=0.3, loss_fn=loss)
    adv_data = attack.batch_generate(test_data, test_labels, batch_size=32)
    model_fuzz_test.test_adequacy_coverage_calculate(adv_data,
                                                     bias_coefficient=0.5)
    adv_metrics = (model_fuzz_test.get_kmnc(),
                   model_fuzz_test.get_nbc(),
                   model_fuzz_test.get_snac())
    LOGGER.info(TAG, 'KMNC of this test is : %s', adv_metrics[0])
    LOGGER.info(TAG, 'NBC of this test is : %s', adv_metrics[1])
    LOGGER.info(TAG, 'SNAC of this test is : %s', adv_metrics[2])
    # hits are only ever added, so every metric is bounded and non-decreasing
    for before, after in zip(clean_metrics, adv_metrics):
        assert before <= after <= 1
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_lenet_mnist_coverage_ascend():
    """
    Coverage metrics on Ascend: values stay in [0, 1] and never decrease when
    more data (here FGSM adversarial examples) is fed to the same instance.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    # local, seeded generator: reproducible data without touching the global
    # numpy random state
    rng = np.random.RandomState(0)

    # load network
    net = Net()
    model = Model(net)

    # initialize fuzz test with training dataset
    training_data = (rng.random_sample((10000, 10))*20).astype(np.float32)
    model_fuzz_test = ModelCoverageMetrics(model, 10000, 10, training_data)

    # fuzz test with original test data; labels are one-hot encoded because
    # the attack below is built without an explicit sparse loss function
    test_data = (rng.random_sample((2000, 10))*20).astype(np.float32)
    test_labels = rng.randint(0, 10, 2000)
    test_labels = (np.eye(10)[test_labels]).astype(np.float32)
    model_fuzz_test.test_adequacy_coverage_calculate(test_data)
    clean_metrics = (model_fuzz_test.get_kmnc(),
                     model_fuzz_test.get_nbc(),
                     model_fuzz_test.get_snac())
    LOGGER.info(TAG, 'KMNC of this test is : %s', clean_metrics[0])
    LOGGER.info(TAG, 'NBC of this test is : %s', clean_metrics[1])
    LOGGER.info(TAG, 'SNAC of this test is : %s', clean_metrics[2])
    for value in clean_metrics:
        assert 0 <= value <= 1

    # generate adv_data
    attack = FastGradientSignMethod(net, eps=0.3)
    adv_data = attack.batch_generate(test_data, test_labels, batch_size=32)
    model_fuzz_test.test_adequacy_coverage_calculate(adv_data,
                                                     bias_coefficient=0.5)
    adv_metrics = (model_fuzz_test.get_kmnc(),
                   model_fuzz_test.get_nbc(),
                   model_fuzz_test.get_snac())
    LOGGER.info(TAG, 'KMNC of this test is : %s', adv_metrics[0])
    LOGGER.info(TAG, 'NBC of this test is : %s', adv_metrics[1])
    LOGGER.info(TAG, 'SNAC of this test is : %s', adv_metrics[2])
    # hits are only ever added, so every metric is bounded and non-decreasing
    for before, after in zip(clean_metrics, adv_metrics):
        assert before <= after <= 1