Merge pull request !12 from jxlang910/mastertags/v0.2.0-alpha
| @@ -0,0 +1,89 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| import sys | |||||
| import numpy as np | |||||
| from mindspore import Model | |||||
| from mindspore import context | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||||
| from mindspore.nn import SoftmaxCrossEntropyWithLogits | |||||
| from mindarmour.attacks.gradient_method import FastGradientSignMethod | |||||
| from mindarmour.utils.logger import LogUtil | |||||
| from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics | |||||
| from lenet5_net import LeNet5 | |||||
| sys.path.append("..") | |||||
| from data_processing import generate_mnist_dataset | |||||
| LOGGER = LogUtil.get_instance() | |||||
| TAG = 'Neuron coverage test' | |||||
| LOGGER.set_level('INFO') | |||||
def test_lenet_mnist_coverage():
    """
    Report neuron-coverage metrics of a trained LeNet5 on MNIST.

    The coverage tracker is initialised with the training images, then fed
    the clean test images and finally FGSM adversarial examples generated
    from those test images. KMNC, NBC and SNAC are logged after each stage.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")

    # restore the trained network from its checkpoint
    net = LeNet5()
    param_dict = load_checkpoint(
        './trained_ckpt_file/checkpoint_lenet-10_1875.ckpt')
    load_param_into_net(net, param_dict)
    model = Model(net)

    # gather every training image into one float32 array
    batch_size = 32
    train_ds = generate_mnist_dataset("./MNIST_unzip/train", batch_size,
                                      sparse=True)
    train_images = np.concatenate(
        [batch[0].astype(np.float32)
         for batch in train_ds.create_tuple_iterator()],
        axis=0)

    # the training data fixes the neurons' output boundaries
    model_fuzz_test = ModelCoverageMetrics(model, 10000, 10, train_images)

    def log_metrics():
        # the three getters are queried in the same order after each stage
        LOGGER.info(TAG, 'KMNC of this test is : %s',
                    model_fuzz_test.get_kmnc())
        LOGGER.info(TAG, 'NBC of this test is : %s',
                    model_fuzz_test.get_nbc())
        LOGGER.info(TAG, 'SNAC of this test is : %s',
                    model_fuzz_test.get_snac())

    # gather the test images together with their labels
    test_ds = generate_mnist_dataset("./MNIST_unzip/test", batch_size,
                                     sparse=True)
    image_batches = []
    label_batches = []
    for batch in test_ds.create_tuple_iterator():
        image_batches.append(batch[0].astype(np.float32))
        label_batches.append(batch[1])
    test_images = np.concatenate(image_batches, axis=0)
    test_labels = np.concatenate(label_batches, axis=0)

    # stage 1: coverage reached by the original test data
    model_fuzz_test.test_adequacy_coverage_calculate(test_images)
    log_metrics()

    # stage 2: coverage after adding FGSM adversarial examples
    loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
    attack = FastGradientSignMethod(net, eps=0.3, loss_fn=loss)
    adv_data = attack.batch_generate(test_images, test_labels, batch_size=32)
    model_fuzz_test.test_adequacy_coverage_calculate(adv_data,
                                                     bias_coefficient=0.5)
    log_metrics()
| if __name__ == '__main__': | |||||
| test_lenet_mnist_coverage() | |||||
| @@ -0,0 +1,3 @@ | |||||
| from .model_coverage_metrics import ModelCoverageMetrics | |||||
| __all__ = ['ModelCoverageMetrics'] | |||||
| @@ -0,0 +1,167 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """ | |||||
| Model-Test Coverage Metrics. | |||||
| """ | |||||
| import numpy as np | |||||
| from mindspore import Tensor | |||||
| from mindspore import Model | |||||
| from mindarmour.utils._check_param import check_model, check_numpy_param, \ | |||||
| check_int_positive | |||||
class ModelCoverageMetrics:
    """
    Evaluate the testing adequacy of a model fuzz test.

    Reference: `DeepGauge: Multi-Granularity Testing Criteria for Deep
    Learning Systems <https://arxiv.org/abs/1803.07519>`_

    Args:
        model (Model): The pre-trained model to be tested.
        k (int): The number of sections each neuron's output interval is
            split into. Must be a positive integer.
        n (int): The number of neurons under test. Must be a positive
            integer equal to the width (second dimension) of the model
            output.
        train_dataset (numpy.ndarray): Training dataset used to determine
            the neurons' output boundaries.

    Raises:
        ValueError: If `train_dataset` contains no samples, or if the model
            output is not a 2-D array with `n` columns.
    """

    def __init__(self, model, k, n, train_dataset):
        self._model = check_model('model', model, Model)
        self._k = check_int_positive('k', k)
        self._n = check_int_positive('n', n)
        train_dataset = check_numpy_param('train_dataset', train_dataset)
        # Running min / max / std of every neuron over the training data.
        self._lower_bounds = np.full(self._n, np.inf)
        self._upper_bounds = np.full(self._n, -np.inf)
        self._var = np.zeros(self._n)
        # 0/1 flags: [neuron, section] for the main range, [neuron] for the
        # regions below the lower bound and above the upper bound.
        self._main_section_hits = np.zeros((self._n, self._k), dtype=np.int32)
        self._lower_corner_hits = np.zeros(self._n, dtype=np.int32)
        self._upper_corner_hits = np.zeros(self._n, dtype=np.int32)
        self._bounds_get(train_dataset)

    def _predict(self, inputs):
        """
        Run the model on one batch and validate the output layout.

        Args:
            inputs (numpy.ndarray): One batch of input samples.

        Returns:
            numpy.ndarray, model output of shape (batch, n).

        Raises:
            ValueError: If the output is not 2-D with `n` columns.
        """
        output = self._model.predict(Tensor(inputs)).asnumpy()
        if output.ndim != 2 or output.shape[1] != self._n:
            msg = 'model output must have shape (batch, {}), but got {}' \
                .format(self._n, output.shape)
            raise ValueError(msg)
        return output

    def _bounds_get(self, train_dataset, batch_size=32):
        """
        Update the lower and upper boundaries of neurons' outputs.

        Args:
            train_dataset (numpy.ndarray): Training dataset used to
                determine the neurons' output boundaries.
            batch_size (int): The number of samples in a predict batch.
                Default: 32.

        Raises:
            ValueError: If `train_dataset` contains no samples.
        """
        batch_size = check_int_positive('batch_size', batch_size)
        outputs = []
        # Stepping by batch_size also visits the trailing partial batch,
        # so no training sample is silently ignored.
        for start in range(0, train_dataset.shape[0], batch_size):
            output = self._predict(train_dataset[start: start + batch_size])
            outputs.append(output)
            self._lower_bounds = np.minimum(self._lower_bounds,
                                            output.min(axis=0))
            self._upper_bounds = np.maximum(self._upper_bounds,
                                            output.max(axis=0))
        if not outputs:
            raise ValueError('train_dataset must contain at least one '
                             'sample.')
        self._var = np.std(np.concatenate(outputs, axis=0), axis=0)

    def _update_hits(self, batch_output, intervals):
        """
        Record which sections and corner regions a batch of outputs hits.

        Outputs strictly below the lower bound mark the lower corner,
        outputs strictly above the upper bound mark the upper corner, and
        outputs inside the closed range [lower, upper] mark one of the `k`
        main sections. NaN outputs match none of these and are ignored.

        Args:
            batch_output (numpy.ndarray): Neuron outputs, shape (batch, n).
            intervals (numpy.ndarray): Width of one section per neuron.
        """
        lower = np.asarray(self._lower_bounds, dtype=np.float64)
        upper = np.asarray(self._upper_bounds, dtype=np.float64)
        intervals = np.asarray(intervals, dtype=np.float64)

        below = batch_output < lower
        above = batch_output > upper
        self._lower_corner_hits[below.any(axis=0)] = 1
        self._upper_corner_hits[above.any(axis=0)] = 1

        inside = (batch_output >= lower) & (batch_output <= upper)
        rows, cols = np.nonzero(inside)
        if rows.size == 0:
            return
        # A zero-width range (constant neuron) would divide by zero; such
        # outputs are assigned to section 0.
        safe_intervals = np.where(intervals > 0, intervals, 1.0)
        offsets = batch_output[rows, cols] - lower[cols]
        sections = np.floor_divide(offsets,
                                   safe_intervals[cols]).astype(np.int64)
        # An output equal to the upper bound (or pushed there by rounding)
        # yields index k; it belongs to the last main section.
        np.clip(sections, 0, self._k - 1, out=sections)
        self._main_section_hits[cols, sections] = 1

    def _sections_hits_count(self, dataset, intervals):
        """
        Update the coverage matrix of neurons' output subsections.

        Args:
            dataset (numpy.ndarray): Testing data.
            intervals (list[float]): Segmentation intervals of neurons'
                outputs.
        """
        dataset = check_numpy_param('dataset', dataset)
        self._update_hits(self._predict(dataset), intervals)

    def test_adequacy_coverage_calculate(self, dataset, bias_coefficient=0,
                                         batch_size=32):
        """
        Calculate the testing adequacy of the given dataset.

        Hits accumulate across calls. The boundaries are widened in place
        by `bias_coefficient` times the per-neuron standard deviation, so
        the widening also accumulates across calls.

        Args:
            dataset (numpy.ndarray): Data for fuzz test.
            bias_coefficient (float): The coefficient used for changing the
                neurons' output boundaries. Default: 0.
            batch_size (int): The number of samples in a predict batch.
                Default: 32.

        Examples:
            >>> model_fuzz_test = ModelCoverageMetrics(model, 10000, 10, train_images)
            >>> model_fuzz_test.test_adequacy_coverage_calculate(test_images)
        """
        dataset = check_numpy_param('dataset', dataset)
        batch_size = check_int_positive('batch_size', batch_size)
        self._lower_bounds = self._lower_bounds - bias_coefficient*self._var
        self._upper_bounds = self._upper_bounds + bias_coefficient*self._var
        intervals = (self._upper_bounds - self._lower_bounds) / self._k
        # Stepping by batch_size keeps the trailing partial batch.
        for start in range(0, dataset.shape[0], batch_size):
            self._sections_hits_count(dataset[start: start + batch_size],
                                      intervals)

    def get_kmnc(self):
        """
        Get the metric of 'k-multisection neuron coverage'.

        Returns:
            float, the metric of 'k-multisection neuron coverage'.

        Examples:
            >>> model_fuzz_test.get_kmnc()
        """
        kmnc = np.sum(self._main_section_hits) / (self._n*self._k)
        return kmnc

    def get_nbc(self):
        """
        Get the metric of 'neuron boundary coverage'.

        Returns:
            float, the metric of 'neuron boundary coverage'.

        Examples:
            >>> model_fuzz_test.get_nbc()
        """
        nbc = (np.sum(self._lower_corner_hits) + np.sum(
            self._upper_corner_hits)) / (2*self._n)
        return nbc

    def get_snac(self):
        """
        Get the metric of 'strong neuron activation coverage'.

        Returns:
            float, the metric of 'strong neuron activation coverage'.

        Examples:
            >>> model_fuzz_test.get_snac()
        """
        snac = np.sum(self._upper_corner_hits) / self._n
        return snac
| @@ -0,0 +1,128 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """ | |||||
| Model-fuzz coverage test. | |||||
| """ | |||||
| import numpy as np | |||||
| import pytest | |||||
| from mindspore.train import Model | |||||
| import mindspore.nn as nn | |||||
| from mindspore.nn import Cell | |||||
| from mindspore import context | |||||
| from mindspore.nn import SoftmaxCrossEntropyWithLogits | |||||
| from mindarmour.attacks.gradient_method import FastGradientSignMethod | |||||
| from mindarmour.utils.logger import LogUtil | |||||
| from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics | |||||
| LOGGER = LogUtil.get_instance() | |||||
| TAG = 'Neuron coverage test' | |||||
| LOGGER.set_level('INFO') | |||||
| # for user | |||||
class Net(Cell):
    """
    Minimal target network: a single ReLU applied to the input.

    Examples:
        >>> net = Net()
    """

    def __init__(self):
        """
        Create the ReLU layer used by the network.
        """
        super().__init__()
        self._relu = nn.ReLU()

    def construct(self, inputs):
        """
        Apply ReLU to the input and return the result.

        Args:
            inputs (Tensor): Input data.
        """
        return self._relu(inputs)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_lenet_mnist_coverage_cpu():
    """
    Run the coverage metrics on CPU with a ReLU network and random data.

    Coverage is logged after feeding random test data and again after
    feeding FGSM adversarial examples built from that data with sparse
    integer labels.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")

    # target network wrapped as a model
    net = Net()
    model = Model(net)

    # random "training" data fixes the neurons' output boundaries
    training_data = (np.random.random((10000, 10))*20).astype(np.float32)
    model_fuzz_test = ModelCoverageMetrics(model, 10000, 10, training_data)

    def log_metrics():
        # the three getters are queried in the same order after each stage
        LOGGER.info(TAG, 'KMNC of this test is : %s',
                    model_fuzz_test.get_kmnc())
        LOGGER.info(TAG, 'NBC of this test is : %s',
                    model_fuzz_test.get_nbc())
        LOGGER.info(TAG, 'SNAC of this test is : %s',
                    model_fuzz_test.get_snac())

    # stage 1: coverage reached by random test data
    test_data = (np.random.random((2000, 10))*20).astype(np.float32)
    test_labels = np.random.randint(0, 10, 2000).astype(np.int32)
    model_fuzz_test.test_adequacy_coverage_calculate(test_data)
    log_metrics()

    # stage 2: coverage after adding FGSM adversarial examples
    loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
    attack = FastGradientSignMethod(net, eps=0.3, loss_fn=loss)
    adv_data = attack.batch_generate(test_data, test_labels, batch_size=32)
    model_fuzz_test.test_adequacy_coverage_calculate(adv_data,
                                                     bias_coefficient=0.5)
    log_metrics()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_lenet_mnist_coverage_ascend():
    """
    Run the coverage metrics on Ascend with a ReLU network and random data.

    Coverage is logged after feeding random test data and again after
    feeding FGSM adversarial examples built from that data with one-hot
    float labels and the attack's default loss.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

    # target network wrapped as a model
    net = Net()
    model = Model(net)

    # random "training" data fixes the neurons' output boundaries
    training_data = (np.random.random((10000, 10))*20).astype(np.float32)
    model_fuzz_test = ModelCoverageMetrics(model, 10000, 10, training_data)

    def log_metrics():
        # the three getters are queried in the same order after each stage
        LOGGER.info(TAG, 'KMNC of this test is : %s',
                    model_fuzz_test.get_kmnc())
        LOGGER.info(TAG, 'NBC of this test is : %s',
                    model_fuzz_test.get_nbc())
        LOGGER.info(TAG, 'SNAC of this test is : %s',
                    model_fuzz_test.get_snac())

    # stage 1: coverage reached by random test data
    test_data = (np.random.random((2000, 10))*20).astype(np.float32)
    class_ids = np.random.randint(0, 10, 2000)
    # one-hot encode the class ids by picking rows of the identity matrix
    test_labels = (np.eye(10)[class_ids]).astype(np.float32)
    model_fuzz_test.test_adequacy_coverage_calculate(test_data)
    log_metrics()

    # stage 2: coverage after adding FGSM adversarial examples
    attack = FastGradientSignMethod(net, eps=0.3)
    adv_data = attack.batch_generate(test_data, test_labels, batch_size=32)
    model_fuzz_test.test_adequacy_coverage_calculate(adv_data,
                                                     bias_coefficient=0.5)
    log_metrics()