@@ -1,8 +1,8 @@ | |||
""" | |||
This module includes various metrics to fuzzing the test of DNN. | |||
""" | |||
from .fuzzing import Fuzzing | |||
from .fuzzing import Fuzzer | |||
from .model_coverage_metrics import ModelCoverageMetrics | |||
__all__ = ['Fuzzing', | |||
__all__ = ['Fuzzer', | |||
'ModelCoverageMetrics'] |
@@ -23,11 +23,11 @@ from mindspore import Tensor | |||
from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics | |||
from mindarmour.utils._check_param import check_model, check_numpy_param, \ | |||
check_int_positive | |||
from mindarmour.utils.image_transform import Contrast, Brightness, Blur, Noise, \ | |||
from mindarmour.fuzzing.image_transform import Contrast, Brightness, Blur, Noise, \ | |||
Translate, Scale, Shear, Rotate | |||
class Fuzzing: | |||
class Fuzzer: | |||
""" | |||
Fuzzing test framework for deep neural networks. | |||
@@ -84,7 +84,7 @@ class Fuzzing: | |||
[]) | |||
transform = strages[trans_strage]( | |||
self._image_value_expand(seed), self.mode) | |||
transform.random_param() | |||
transform.set_params(auto_param=True) | |||
mutate_test = transform.transform() | |||
mutate_test = np.expand_dims( | |||
self._image_value_compress(mutate_test), 0) | |||
@@ -138,7 +138,7 @@ class Fuzzing: | |||
result = result.asnumpy() | |||
for index in range(len(mutate_tests)): | |||
mutate = np.expand_dims(mutate_tests[index], 0) | |||
self.coverage_metrics.test_adequacy_coverage_calculate( | |||
self.coverage_metrics.model_coverage_test( | |||
mutate.astype(np.float32), batch_size=1) | |||
if coverage_metric == "KMNC": | |||
coverages.append(self.coverage_metrics.get_kmnc()) | |||
@@ -0,0 +1,569 @@ | |||
# Copyright 2019 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
""" | |||
Image transform | |||
""" | |||
import numpy as np | |||
from PIL import Image, ImageEnhance, ImageFilter | |||
from mindspore.dataset.transforms.vision.py_transforms_util import is_numpy, \ | |||
to_pil, hwc_to_chw | |||
from mindarmour.utils._check_param import check_param_multi_types | |||
from mindarmour.utils.logger import LogUtil | |||
LOGGER = LogUtil.get_instance() | |||
TAG = 'ModelCoverageMetrics' | |||
def chw_to_hwc(img): | |||
""" | |||
Transpose the input image; shape (C, H, W) to shape (H, W, C). | |||
Args: | |||
img (numpy.ndarray): Image to be converted. | |||
Returns: | |||
img (numpy.ndarray), Converted image. | |||
""" | |||
if is_numpy(img): | |||
return img.transpose(1, 2, 0).copy() | |||
raise TypeError('img should be Numpy array. Got {}'.format(type(img))) | |||
def is_hwc(img): | |||
""" | |||
Check if the input image is shape (H, W, C). | |||
Args: | |||
img (numpy.ndarray): Image to be checked. | |||
Returns: | |||
Bool, True if input is shape (H, W, C). | |||
""" | |||
if is_numpy(img): | |||
img_shape = np.shape(img) | |||
if img_shape[2] == 3 and img_shape[1] > 3 and img_shape[0] > 3: | |||
return True | |||
return False | |||
raise TypeError('img should be Numpy array. Got {}'.format(type(img))) | |||
def is_chw(img): | |||
""" | |||
Check if the input image is shape (H, W, C). | |||
Args: | |||
img (numpy.ndarray): Image to be checked. | |||
Returns: | |||
Bool, True if input is shape (H, W, C). | |||
""" | |||
if is_numpy(img): | |||
img_shape = np.shape(img) | |||
if img_shape[0] == 3 and img_shape[1] > 3 and img_shape[2] > 3: | |||
return True | |||
return False | |||
raise TypeError('img should be Numpy array. Got {}'.format(type(img))) | |||
def is_rgb(img): | |||
""" | |||
Check if the input image is RGB. | |||
Args: | |||
img (numpy.ndarray): Image to be checked. | |||
Returns: | |||
Bool, True if input is RGB. | |||
""" | |||
if is_numpy(img): | |||
if len(np.shape(img)) == 3: | |||
return True | |||
return False | |||
raise TypeError('img should be Numpy array. Got {}'.format(type(img))) | |||
def is_normalized(img): | |||
""" | |||
Check if the input image is normalized between 0 to 1. | |||
Args: | |||
img (numpy.ndarray): Image to be checked. | |||
Returns: | |||
Bool, True if input is normalized between 0 to 1. | |||
""" | |||
if is_numpy(img): | |||
minimal = np.min(img) | |||
maximun = np.max(img) | |||
if minimal >= 0 and maximun <= 1: | |||
return True | |||
return False | |||
raise TypeError('img should be Numpy array. Got {}'.format(type(img))) | |||
class ImageTransform: | |||
""" | |||
The abstract base class for all image transform classes. | |||
""" | |||
def __init__(self): | |||
pass | |||
def _check(self, image): | |||
""" Check image format. If input image is RGB and its shape | |||
is (C, H, W), it will be transposed to (H, W, C). If the value | |||
of the image is not normalized , it will be normalized between 0 to 1.""" | |||
rgb = is_rgb(image) | |||
chw = False | |||
normalized = is_normalized(image) | |||
if rgb: | |||
chw = is_chw(image) | |||
if chw: | |||
image = chw_to_hwc(image) | |||
else: | |||
image = image | |||
else: | |||
image = image | |||
if normalized: | |||
image = np.uint8(image*255) | |||
return rgb, chw, normalized, image | |||
def _original_format(self, image, chw, normalized): | |||
""" Return transformed image with original format. """ | |||
if not is_numpy(image): | |||
image = np.array(image) | |||
if chw: | |||
image = hwc_to_chw(image) | |||
if normalized: | |||
image = image / 255 | |||
return image | |||
def transform(self, image): | |||
pass | |||
class Contrast(ImageTransform): | |||
""" | |||
Contrast of an image. | |||
Args: | |||
factor ([float, int]): Control the contrast of an image. If 1.0 gives the | |||
original image. If 0 gives a gray image. Default: 1. | |||
""" | |||
def __init__(self, factor=1): | |||
super(Contrast, self).__init__() | |||
self.set_params(factor) | |||
def set_params(self, factor=1, auto_param=False): | |||
""" | |||
Set contrast parameters. | |||
Args: | |||
factor ([float, int]): Control the contrast of an image. If 1.0 gives | |||
the original image. If 0 gives a gray image. Default: 1. | |||
auto_param (bool): True if auto generate parameters. Default: False. | |||
""" | |||
if auto_param: | |||
self.factor = np.random.uniform(-5, 5) | |||
else: | |||
self.factor = check_param_multi_types('factor', factor, [int, float]) | |||
def transform(self, image): | |||
""" | |||
Transform the image. | |||
Args: | |||
image (numpy.ndarray): Original image to be transformed. | |||
Returns: | |||
numpy.ndarray, transformed image. | |||
""" | |||
_, chw, normalized, image = self._check(image) | |||
image = to_pil(image) | |||
img_contrast = ImageEnhance.Contrast(image) | |||
trans_image = img_contrast.enhance(self.factor) | |||
trans_image = self._original_format(trans_image, chw, normalized) | |||
return trans_image | |||
class Brightness(ImageTransform): | |||
""" | |||
Brightness of an image. | |||
Args: | |||
factor ([float, int]): Control the brightness of an image. If 1.0 gives | |||
the original image. If 0 gives a black image. Default: 1. | |||
""" | |||
def __init__(self, factor=1): | |||
super(Brightness, self).__init__() | |||
self.set_params(factor) | |||
def set_params(self, factor=1, auto_param=False): | |||
""" | |||
Set brightness parameters. | |||
Args: | |||
factor ([float, int]): Control the brightness of an image. If 1 | |||
gives the original image. If 0 gives a black image. Default: 1. | |||
auto_param (bool): True if auto generate parameters. Default: False. | |||
""" | |||
if auto_param: | |||
self.factor = np.random.uniform(0, 5) | |||
else: | |||
self.factor = check_param_multi_types('factor', factor, [int, float]) | |||
def transform(self, image): | |||
""" | |||
Transform the image. | |||
Args: | |||
image (numpy.ndarray): Original image to be transformed. | |||
Returns: | |||
numpy.ndarray, transformed image. | |||
""" | |||
_, chw, normalized, image = self._check(image) | |||
image = to_pil(image) | |||
img_contrast = ImageEnhance.Brightness(image) | |||
trans_image = img_contrast.enhance(self.factor) | |||
trans_image = self._original_format(trans_image, chw, normalized) | |||
return trans_image | |||
class Blur(ImageTransform): | |||
""" | |||
Blurs the image using Gaussian blur filter. | |||
Args: | |||
radius([float, int]): Blur radius, 0 means no blur. Default: 0. | |||
""" | |||
def __init__(self, radius=0): | |||
super(Blur, self).__init__() | |||
self.set_params(radius) | |||
def set_params(self, radius=0, auto_param=False): | |||
""" | |||
Set blur parameters. | |||
Args: | |||
radius ([float, int]): Blur radius, 0 means no blur. Default: 0. | |||
auto_param (bool): True if auto generate parameters. Default: False. | |||
""" | |||
if auto_param: | |||
self.radius = np.random.uniform(-1.5, 1.5) | |||
else: | |||
self.radius = check_param_multi_types('radius', radius, [int, float]) | |||
def transform(self, image): | |||
""" | |||
Transform the image. | |||
Args: | |||
image (numpy.ndarray): Original image to be transformed. | |||
Returns: | |||
numpy.ndarray, transformed image. | |||
""" | |||
_, chw, normalized, image = self._check(image) | |||
image = to_pil(image) | |||
trans_image = image.filter(ImageFilter.GaussianBlur(radius=self.radius)) | |||
trans_image = self._original_format(trans_image, chw, normalized) | |||
return trans_image | |||
class Noise(ImageTransform): | |||
""" | |||
Add noise of an image. | |||
Args: | |||
factor (float): 1 - factor is the ratio of pixels to add noise. | |||
If 0 gives the original image. Default 0. | |||
""" | |||
def __init__(self, factor=0): | |||
super(Noise, self).__init__() | |||
self.set_params(factor) | |||
def set_params(self, factor=0, auto_param=False): | |||
""" | |||
Set noise parameters. | |||
Args: | |||
factor ([float, int]): 1 - factor is the ratio of pixels to add noise. | |||
If 0 gives the original image. Default 0. | |||
auto_param (bool): True if auto generate parameters. Default: False. | |||
""" | |||
if auto_param: | |||
self.factor = np.random.uniform(0.7, 1) | |||
else: | |||
self.factor = check_param_multi_types('factor', factor, [int, float]) | |||
def transform(self, image): | |||
""" | |||
Transform the image. | |||
Args: | |||
image (numpy.ndarray): Original image to be transformed. | |||
Returns: | |||
numpy.ndarray, transformed image. | |||
""" | |||
_, chw, normalized, image = self._check(image) | |||
noise = np.random.uniform(low=-1, high=1, size=np.shape(image)) | |||
trans_image = np.copy(image) | |||
trans_image[noise < -self.factor] = 0 | |||
trans_image[noise > self.factor] = 1 | |||
trans_image = self._original_format(trans_image, chw, normalized) | |||
return trans_image | |||
class Translate(ImageTransform): | |||
""" | |||
Translate an image. | |||
Args: | |||
x_bias ([int, float): X-direction translation, x=x+x_bias. Default: 0. | |||
y_bias ([int, float): Y-direction translation, y=y+y_bias. Default: 0. | |||
""" | |||
def __init__(self, x_bias=0, y_bias=0): | |||
super(Translate, self).__init__() | |||
self.set_params(x_bias, y_bias) | |||
def set_params(self, x_bias=0, y_bias=0, auto_param=False): | |||
""" | |||
Set translate parameters. | |||
Args: | |||
x_bias ([float, int]): X-direction translation, x=x+x_bias. Default: 0. | |||
y_bias ([float, int]): Y-direction translation, y=y+y_bias. Default: 0. | |||
auto_param (bool): True if auto generate parameters. Default: False. | |||
""" | |||
self.auto_param = auto_param | |||
if auto_param: | |||
self.x_bias = np.random.uniform(-0.3, 0.3) | |||
self.y_bias = np.random.uniform(-0.3, 0.3) | |||
else: | |||
self.x_bias = check_param_multi_types('x_bias', x_bias, | |||
[int, float]) | |||
self.y_bias = check_param_multi_types('y_bias', y_bias, | |||
[int, float]) | |||
def transform(self, image): | |||
""" | |||
Transform the image. | |||
Args: | |||
image(numpy.ndarray): Original image to be transformed. | |||
Returns: | |||
numpy.ndarray, transformed image. | |||
""" | |||
_, chw, normalized, image = self._check(image) | |||
img = to_pil(image) | |||
if self.auto_param: | |||
image_shape = np.shape(image) | |||
self.x_bias = image_shape[0]*self.x_bias | |||
self.y_bias = image_shape[1]*self.y_bias | |||
trans_image = img.transform(img.size, Image.AFFINE, | |||
(1, 0, self.x_bias, 0, 1, self.y_bias)) | |||
trans_image = self._original_format(trans_image, chw, normalized) | |||
return trans_image | |||
class Scale(ImageTransform): | |||
""" | |||
Scale an image in the middle. | |||
Args: | |||
factor_x ([float, int]): Rescale in X-direction, x=factor_x*x. Default: 1. | |||
factor_y ([float, int]): Rescale in Y-direction, y=factor_y*y. Default: 1. | |||
""" | |||
def __init__(self, factor_x=1, factor_y=1): | |||
super(Scale, self).__init__() | |||
self.set_params(factor_x, factor_y) | |||
def set_params(self, factor_x=1, factor_y=1, auto_param=False): | |||
""" | |||
Set scale parameters. | |||
Args: | |||
factor_x ([float, int]): Rescale in X-direction, x=factor_x*x. | |||
Default: 1. | |||
factor_y ([float, int]): Rescale in Y-direction, y=factor_y*y. | |||
Default: 1. | |||
auto_param (bool): True if auto generate parameters. Default: False. | |||
""" | |||
if auto_param: | |||
self.factor_x = np.random.uniform(0.7, 3) | |||
self.factor_y = np.random.uniform(0.7, 3) | |||
else: | |||
self.factor_x = check_param_multi_types('factor_x', factor_x, | |||
[int, float]) | |||
self.factor_y = check_param_multi_types('factor_y', factor_y, | |||
[int, float]) | |||
def transform(self, image): | |||
""" | |||
Transform the image. | |||
Args: | |||
image(numpy.ndarray): Original image to be transformed. | |||
Returns: | |||
numpy.ndarray, transformed image. | |||
""" | |||
rgb, chw, normalized, image = self._check(image) | |||
if rgb: | |||
h, w, _ = np.shape(image) | |||
else: | |||
h, w = np.shape(image) | |||
move_x_centor = w / 2*(1 - self.factor_x) | |||
move_y_centor = h / 2*(1 - self.factor_y) | |||
img = to_pil(image) | |||
trans_image = img.transform(img.size, Image.AFFINE, | |||
(self.factor_x, 0, move_x_centor, | |||
0, self.factor_y, move_y_centor)) | |||
trans_image = self._original_format(trans_image, chw, normalized) | |||
return trans_image | |||
class Shear(ImageTransform): | |||
""" | |||
Shear an image, for each pixel (x, y) in the sheared image, the new value is | |||
taken from a position (x+factor_x*y, factor_y*x+y) in the origin image. Then | |||
the sheared image will be rescaled to fit original size. | |||
Args: | |||
factor_x ([float, int]): Shear factor of horizontal direction. Default: 0. | |||
factor_y ([float, int]): Shear factor of vertical direction. Default: 0. | |||
""" | |||
def __init__(self, factor_x=0, factor_y=0): | |||
super(Shear, self).__init__() | |||
self.set_params(factor_x, factor_y) | |||
def set_params(self, factor_x=0, factor_y=0, auto_param=False): | |||
""" | |||
Set shear parameters. | |||
Args: | |||
factor_x ([float, int]): Shear factor of horizontal direction. | |||
Default: 0. | |||
factor_y ([float, int]): Shear factor of vertical direction. | |||
Default: 0. | |||
auto_param (bool): True if auto generate parameters. Default: False. | |||
""" | |||
if factor_x != 0 and factor_y != 0: | |||
msg = 'factor_x and factor_y can not be both more than 0.' | |||
LOGGER.error(TAG, msg) | |||
raise ValueError(msg) | |||
if auto_param: | |||
if np.random.uniform(-1, 1) > 0: | |||
self.factor_x = np.random.uniform(-2, 2) | |||
self.factor_y = 0 | |||
else: | |||
self.factor_x = 0 | |||
self.factor_y = np.random.uniform(-2, 2) | |||
else: | |||
self.factor_x = check_param_multi_types('factor', factor_x, | |||
[int, float]) | |||
self.factor_y = check_param_multi_types('factor', factor_y, | |||
[int, float]) | |||
def transform(self, image): | |||
""" | |||
Transform the image. | |||
Args: | |||
image(numpy.ndarray): Original image to be transformed. | |||
Returns: | |||
numpy.ndarray, transformed image. | |||
""" | |||
rgb, chw, normalized, image = self._check(image) | |||
img = to_pil(image) | |||
if rgb: | |||
h, w, _ = np.shape(image) | |||
else: | |||
h, w = np.shape(image) | |||
if self.factor_x != 0: | |||
boarder_x = [0, -w, -self.factor_x*h, -w - self.factor_x*h] | |||
min_x = min(boarder_x) | |||
max_x = max(boarder_x) | |||
scale = (max_x - min_x) / w | |||
move_x_cen = (w - scale*w - scale*h*self.factor_x) / 2 | |||
move_y_cen = h*(1 - scale) / 2 | |||
else: | |||
boarder_y = [0, -h, -self.factor_y*w, -h - self.factor_y*w] | |||
min_y = min(boarder_y) | |||
max_y = max(boarder_y) | |||
scale = (max_y - min_y) / h | |||
move_y_cen = (h - scale*h - scale*w*self.factor_y) / 2 | |||
move_x_cen = w*(1 - scale) / 2 | |||
trans_image = img.transform(img.size, Image.AFFINE, | |||
(scale, scale*self.factor_x, move_x_cen, | |||
scale*self.factor_y, scale, move_y_cen)) | |||
trans_image = self._original_format(trans_image, chw, normalized) | |||
return trans_image | |||
class Rotate(ImageTransform): | |||
""" | |||
Rotate an image of degrees counter clockwise around its center. | |||
Args: | |||
angle([float, int]): Degrees counter clockwise. Default: 0. | |||
""" | |||
def __init__(self, angle=0): | |||
super(Rotate, self).__init__() | |||
self.set_params(angle) | |||
def set_params(self, angle=0, auto_param=False): | |||
""" | |||
Set rotate parameters. | |||
Args: | |||
angle([float, int]): Degrees counter clockwise. Default: 0. | |||
auto_param (bool): True if auto generate parameters. Default: False. | |||
""" | |||
if auto_param: | |||
self.angle = np.random.uniform(0, 360) | |||
else: | |||
self.angle = check_param_multi_types('angle', angle, [int, float]) | |||
def transform(self, image): | |||
""" | |||
Transform the image. | |||
Args: | |||
image(numpy.ndarray): Original image to be transformed. | |||
Returns: | |||
numpy.ndarray, transformed image. | |||
""" | |||
_, chw, normalized, image = self._check(image) | |||
img = to_pil(image) | |||
trans_image = img.rotate(self.angle, expand=True) | |||
trans_image = self._original_format(trans_image, chw, normalized) | |||
return trans_image |
@@ -133,8 +133,7 @@ class ModelCoverageMetrics: | |||
else: | |||
self._main_section_hits[i][int(section_indexes[i])] = 1 | |||
def test_adequacy_coverage_calculate(self, dataset, bias_coefficient=0, | |||
batch_size=32): | |||
def calculate_coverage(self, dataset, bias_coefficient=0, batch_size=32): | |||
""" | |||
Calculate the testing adequacy of the given dataset. | |||
@@ -147,7 +146,7 @@ class ModelCoverageMetrics: | |||
Examples: | |||
>>> model_fuzz_test = ModelCoverageMetrics(model, 10000, 10, train_images) | |||
>>> model_fuzz_test.test_adequacy_coverage_calculate(test_images) | |||
>>> model_fuzz_test.calculate_coverage(test_images) | |||
""" | |||
dataset = check_numpy_param('dataset', dataset) | |||
batch_size = check_int_positive('batch_size', batch_size) | |||
@@ -1,267 +0,0 @@ | |||
# Copyright 2019 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
""" | |||
Image transform | |||
""" | |||
import numpy as np | |||
from PIL import Image, ImageEnhance, ImageFilter | |||
import random | |||
from mindarmour.utils._check_param import check_numpy_param | |||
class ImageTransform: | |||
""" | |||
The abstract base class for all image transform classes. | |||
""" | |||
def __init__(self): | |||
pass | |||
def random_param(self): | |||
pass | |||
def transform(self): | |||
pass | |||
class Contrast(ImageTransform): | |||
""" | |||
Contrast of an image. | |||
Args: | |||
image (numpy.ndarray): Original image to be transformed. | |||
mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'], | |||
'L' means grey image. | |||
""" | |||
def __init__(self, image, mode): | |||
super(Contrast, self).__init__() | |||
self.image = check_numpy_param('image', image) | |||
self.mode = mode | |||
def random_param(self): | |||
""" Random generate parameters. """ | |||
self.factor = random.uniform(-5, 5) | |||
def transform(self): | |||
img = Image.fromarray(np.uint8(self.image*255), self.mode) | |||
img_contrast = ImageEnhance.Contrast(img) | |||
trans_image = img_contrast.enhance(self.factor) | |||
trans_image = np.array(trans_image)/255 | |||
return trans_image | |||
class Brightness(ImageTransform): | |||
""" | |||
Brightness of an image. | |||
Args: | |||
image (numpy.ndarray): Original image to be transformed. | |||
mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'], | |||
'L' means grey image. | |||
""" | |||
def __init__(self, image, mode): | |||
super(Brightness, self).__init__() | |||
self.image = check_numpy_param('image', image) | |||
self.mode = mode | |||
def random_param(self): | |||
""" Random generate parameters. """ | |||
self.factor = random.uniform(0, 5) | |||
def transform(self): | |||
img = Image.fromarray(np.uint8(self.image*255), self.mode) | |||
img_contrast = ImageEnhance.Brightness(img) | |||
trans_image = img_contrast.enhance(self.factor) | |||
trans_image = np.array(trans_image)/255 | |||
return trans_image | |||
class Blur(ImageTransform): | |||
""" | |||
GaussianBlur of an image. | |||
Args: | |||
image (numpy.ndarray): Original image to be transformed. | |||
mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'], | |||
'L' means grey image. | |||
""" | |||
def __init__(self, image, mode): | |||
super(Blur, self).__init__() | |||
self.image = check_numpy_param('image', image) | |||
self.mode = mode | |||
def random_param(self): | |||
""" Random generate parameters. """ | |||
self.radius = random.uniform(-1.5, 1.5) | |||
def transform(self): | |||
""" Transform the image. """ | |||
img = Image.fromarray(np.uint8(self.image*255), self.mode) | |||
trans_image = img.filter(ImageFilter.GaussianBlur(radius=self.radius)) | |||
trans_image = np.array(trans_image)/255 | |||
return trans_image | |||
class Noise(ImageTransform): | |||
""" | |||
Add noise of an image. | |||
Args: | |||
image (numpy.ndarray): Original image to be transformed. | |||
mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'], | |||
'L' means grey image. | |||
""" | |||
def __init__(self, image, mode): | |||
super(Noise, self).__init__() | |||
self.image = check_numpy_param('image', image) | |||
self.mode = mode | |||
def random_param(self): | |||
""" random generate parameters """ | |||
self.factor = random.uniform(0.7, 1) | |||
def transform(self): | |||
""" Random generate parameters. """ | |||
noise = np.random.uniform(low=-1, high=1, size=self.image.shape) | |||
trans_image = np.copy(self.image) | |||
trans_image[noise < -self.factor] = 0 | |||
trans_image[noise > self.factor] = 1 | |||
trans_image = np.array(trans_image) | |||
return trans_image | |||
class Translate(ImageTransform): | |||
""" | |||
Translate an image. | |||
Args: | |||
image (numpy.ndarray): Original image to be transformed. | |||
mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'], | |||
'L' means grey image. | |||
""" | |||
def __init__(self, image, mode): | |||
super(Translate, self).__init__() | |||
self.image = check_numpy_param('image', image) | |||
self.mode = mode | |||
def random_param(self): | |||
""" Random generate parameters. """ | |||
image_shape = np.shape(self.image) | |||
self.x_bias = random.uniform(-image_shape[0]/3, image_shape[0]/3) | |||
self.y_bias = random.uniform(-image_shape[1]/3, image_shape[1]/3) | |||
def transform(self): | |||
""" Transform the image. """ | |||
img = Image.fromarray(np.uint8(self.image*255), self.mode) | |||
trans_image = img.transform(img.size, Image.AFFINE, | |||
(1, 0, self.x_bias, 0, 1, self.y_bias)) | |||
trans_image = np.array(trans_image)/255 | |||
return trans_image | |||
class Scale(ImageTransform): | |||
""" | |||
Scale an image. | |||
Args: | |||
image(numpy.ndarray): Original image to be transformed. | |||
mode(str): Mode used in PIL, here mode must be in ['L', 'RGB'], | |||
'L' means grey image. | |||
""" | |||
def __init__(self, image, mode): | |||
super(Scale, self).__init__() | |||
self.image = check_numpy_param('image', image) | |||
self.mode = mode | |||
def random_param(self): | |||
""" Random generate parameters. """ | |||
self.factor_x = random.uniform(0.7, 2) | |||
self.factor_y = random.uniform(0.7, 2) | |||
def transform(self): | |||
""" Transform the image. """ | |||
img = Image.fromarray(np.uint8(self.image*255), self.mode) | |||
trans_image = img.transform(img.size, Image.AFFINE, | |||
(self.factor_x, 0, 0, 0, self.factor_y, 0)) | |||
trans_image = np.array(trans_image)/255 | |||
return trans_image | |||
class Shear(ImageTransform): | |||
""" | |||
Shear an image. | |||
Args: | |||
image (numpy.ndarray): Original image to be transformed. | |||
mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'], | |||
'L' means grey image. | |||
""" | |||
def __init__(self, image, mode): | |||
super(Shear, self).__init__() | |||
self.image = check_numpy_param('image', image) | |||
self.mode = mode | |||
def random_param(self): | |||
""" Random generate parameters. """ | |||
self.factor = random.uniform(0, 1) | |||
def transform(self): | |||
""" Transform the image. """ | |||
img = Image.fromarray(np.uint8(self.image*255), self.mode) | |||
if np.random.random() > 0.5: | |||
level = -self.factor | |||
else: | |||
level = self.factor | |||
if np.random.random() > 0.5: | |||
trans_image = img.transform(img.size, Image.AFFINE, | |||
(1, level, 0, 0, 1, 0)) | |||
else: | |||
trans_image = img.transform(img.size, Image.AFFINE, | |||
(1, 0, 0, level, 1, 0)) | |||
trans_image = np.array(trans_image, dtype=np.float)/255 | |||
return trans_image | |||
class Rotate(ImageTransform): | |||
""" | |||
Rotate an image. | |||
Args: | |||
image (numpy.ndarray): Original image to be transformed. | |||
mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'], | |||
'L' means grey image. | |||
""" | |||
def __init__(self, image, mode): | |||
super(Rotate, self).__init__() | |||
self.image = check_numpy_param('image', image) | |||
self.mode = mode | |||
def random_param(self): | |||
""" Random generate parameters. """ | |||
self.angle = random.uniform(0, 360) | |||
def transform(self): | |||
""" Transform the image. """ | |||
img = Image.fromarray(np.uint8(self.image*255), self.mode) | |||
trans_image = img.rotate(self.angle) | |||
trans_image = np.array(trans_image)/255 | |||
return trans_image |
@@ -77,7 +77,7 @@ def test_lenet_mnist_coverage_cpu(): | |||
# get test data | |||
test_data = (np.random.random((2000, 10))*20).astype(np.float32) | |||
test_labels = np.random.randint(0, 10, 2000).astype(np.int32) | |||
model_fuzz_test.test_adequacy_coverage_calculate(test_data) | |||
model_fuzz_test.calculate_coverage(test_data) | |||
LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc()) | |||
LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc()) | |||
LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac()) | |||
@@ -86,8 +86,7 @@ def test_lenet_mnist_coverage_cpu(): | |||
loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) | |||
attack = FastGradientSignMethod(net, eps=0.3, loss_fn=loss) | |||
adv_data = attack.batch_generate(test_data, test_labels, batch_size=32) | |||
model_fuzz_test.test_adequacy_coverage_calculate(adv_data, | |||
bias_coefficient=0.5) | |||
model_fuzz_test.calculate_coverage(adv_data, bias_coefficient=0.5) | |||
LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc()) | |||
LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc()) | |||
LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac()) | |||
@@ -113,7 +112,7 @@ def test_lenet_mnist_coverage_ascend(): | |||
test_data = (np.random.random((2000, 10))*20).astype(np.float32) | |||
test_labels = np.random.randint(0, 10, 2000) | |||
test_labels = (np.eye(10)[test_labels]).astype(np.float32) | |||
model_fuzz_test.test_adequacy_coverage_calculate(test_data) | |||
model_fuzz_test.calculate_coverage(test_data) | |||
LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc()) | |||
LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc()) | |||
LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac()) | |||
@@ -121,8 +120,7 @@ def test_lenet_mnist_coverage_ascend(): | |||
# generate adv_data | |||
attack = FastGradientSignMethod(net, eps=0.3) | |||
adv_data = attack.batch_generate(test_data, test_labels, batch_size=32) | |||
model_fuzz_test.test_adequacy_coverage_calculate(adv_data, | |||
bias_coefficient=0.5) | |||
model_fuzz_test.calculate_coverage(adv_data, bias_coefficient=0.5) | |||
LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc()) | |||
LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc()) | |||
LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac()) |
@@ -1,160 +0,0 @@ | |||
# Copyright 2019 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
""" | |||
Model-fuzz coverage test. | |||
""" | |||
import numpy as np | |||
import pytest | |||
from mindspore import context | |||
from mindspore import nn | |||
from mindspore.common.initializer import TruncatedNormal | |||
from mindspore.ops import operations as P | |||
from mindspore.train import Model | |||
from mindarmour.fuzzing.fuzzing import Fuzzing | |||
from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics | |||
from mindarmour.utils.logger import LogUtil | |||
LOGGER = LogUtil.get_instance() | |||
TAG = 'Fuzzing test' | |||
LOGGER.set_level('INFO') | |||
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0): | |||
weight = weight_variable() | |||
return nn.Conv2d(in_channels, out_channels, | |||
kernel_size=kernel_size, stride=stride, padding=padding, | |||
weight_init=weight, has_bias=False, pad_mode="valid") | |||
def fc_with_initialize(input_channels, out_channels): | |||
weight = weight_variable() | |||
bias = weight_variable() | |||
return nn.Dense(input_channels, out_channels, weight, bias) | |||
def weight_variable(): | |||
return TruncatedNormal(0.02) | |||
class Net(nn.Cell): | |||
""" | |||
Lenet network | |||
""" | |||
def __init__(self): | |||
super(Net, self).__init__() | |||
self.conv1 = conv(1, 6, 5) | |||
self.conv2 = conv(6, 16, 5) | |||
self.fc1 = fc_with_initialize(16*5*5, 120) | |||
self.fc2 = fc_with_initialize(120, 84) | |||
self.fc3 = fc_with_initialize(84, 10) | |||
self.relu = nn.ReLU() | |||
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) | |||
self.reshape = P.Reshape() | |||
def construct(self, x): | |||
x = self.conv1(x) | |||
x = self.relu(x) | |||
x = self.max_pool2d(x) | |||
x = self.conv2(x) | |||
x = self.relu(x) | |||
x = self.max_pool2d(x) | |||
x = self.reshape(x, (-1, 16*5*5)) | |||
x = self.fc1(x) | |||
x = self.relu(x) | |||
x = self.fc2(x) | |||
x = self.relu(x) | |||
x = self.fc3(x) | |||
return x | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.env_onecard | |||
@pytest.mark.component_mindarmour | |||
def test_fuzzing_ascend(): | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
# load network | |||
net = Net() | |||
model = Model(net) | |||
batch_size = 8 | |||
num_classe = 10 | |||
# initialize fuzz test with training dataset | |||
training_data = np.random.rand(32, 1, 32, 32).astype(np.float32) | |||
model_coverage_test = ModelCoverageMetrics(model, 1000, 10, training_data) | |||
# fuzz test with original test data | |||
# get test data | |||
test_data = np.random.rand(batch_size, 1, 32, 32).astype(np.float32) | |||
test_labels = np.random.randint(num_classe, size=batch_size).astype(np.int32) | |||
test_labels = (np.eye(num_classe)[test_labels]).astype(np.float32) | |||
initial_seeds = [] | |||
for img, label in zip(test_data, test_labels): | |||
initial_seeds.append([img, label, 0]) | |||
model_coverage_test.test_adequacy_coverage_calculate( | |||
np.array(test_data).astype(np.float32)) | |||
LOGGER.info(TAG, 'KMNC of this test is : %s', | |||
model_coverage_test.get_kmnc()) | |||
model_fuzz_test = Fuzzing(initial_seeds, model, training_data, 5, | |||
max_seed_num=10) | |||
failed_tests = model_fuzz_test.fuzzing() | |||
if failed_tests: | |||
model_coverage_test.test_adequacy_coverage_calculate(np.array(failed_tests).astype(np.float32)) | |||
LOGGER.info(TAG, 'KMNC of this test is : %s', model_coverage_test.get_kmnc()) | |||
else: | |||
LOGGER.info(TAG, 'Fuzzing test identifies none failed test') | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_x86_cpu | |||
@pytest.mark.env_onecard | |||
@pytest.mark.component_mindarmour | |||
def test_fuzzing_CPU(): | |||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||
# load network | |||
net = Net() | |||
model = Model(net) | |||
batch_size = 8 | |||
num_classe = 10 | |||
# initialize fuzz test with training dataset | |||
training_data = np.random.rand(32, 1, 32, 32).astype(np.float32) | |||
model_coverage_test = ModelCoverageMetrics(model, 1000, 10, training_data) | |||
# fuzz test with original test data | |||
# get test data | |||
test_data = np.random.rand(batch_size, 1, 32, 32).astype(np.float32) | |||
test_labels = np.random.randint(num_classe, size=batch_size).astype(np.int32) | |||
test_labels = (np.eye(num_classe)[test_labels]).astype(np.float32) | |||
initial_seeds = [] | |||
for img, label in zip(test_data, test_labels): | |||
initial_seeds.append([img, label, 0]) | |||
model_coverage_test.test_adequacy_coverage_calculate( | |||
np.array(test_data).astype(np.float32)) | |||
LOGGER.info(TAG, 'KMNC of this test is : %s', | |||
model_coverage_test.get_kmnc()) | |||
model_fuzz_test = Fuzzing(initial_seeds, model, training_data, 5, | |||
max_seed_num=10) | |||
failed_tests = model_fuzz_test.fuzzing() | |||
if failed_tests: | |||
model_coverage_test.test_adequacy_coverage_calculate(np.array(failed_tests).astype(np.float32)) | |||
LOGGER.info(TAG, 'KMNC of this test is : %s', model_coverage_test.get_kmnc()) | |||
else: | |||
LOGGER.info(TAG, 'Fuzzing test identifies none failed test') |
@@ -18,7 +18,7 @@ import numpy as np | |||
import pytest | |||
from mindarmour.utils.logger import LogUtil | |||
from mindarmour.utils.image_transform import Contrast, Brightness, Blur, Noise, \ | |||
from mindarmour.fuzzing.image_transform import Contrast, Brightness, Blur, Noise, \ | |||
Translate, Scale, Shear, Rotate | |||
LOGGER = LogUtil.get_instance() | |||
@@ -31,11 +31,10 @@ LOGGER.set_level('INFO') | |||
@pytest.mark.env_onecard | |||
@pytest.mark.component_mindarmour | |||
def test_contrast(): | |||
image = (np.random.rand(32, 32)*255).astype(np.float32) | |||
mode = 'L' | |||
trans = Contrast(image, mode) | |||
trans.random_param() | |||
_ = trans.transform() | |||
image = (np.random.rand(32, 32)).astype(np.float32) | |||
trans = Contrast() | |||
trans.set_params(auto_param=True) | |||
_ = trans.transform(image) | |||
@pytest.mark.level0 | |||
@@ -43,11 +42,10 @@ def test_contrast(): | |||
@pytest.mark.env_onecard | |||
@pytest.mark.component_mindarmour | |||
def test_brightness(): | |||
image = (np.random.rand(32, 32)*255).astype(np.float32) | |||
mode = 'L' | |||
trans = Brightness(image, mode) | |||
trans.random_param() | |||
_ = trans.transform() | |||
image = (np.random.rand(32, 32)).astype(np.float32) | |||
trans = Brightness() | |||
trans.set_params(auto_param=True) | |||
_ = trans.transform(image) | |||
@pytest.mark.level0 | |||
@@ -57,11 +55,10 @@ def test_brightness(): | |||
@pytest.mark.env_onecard | |||
@pytest.mark.component_mindarmour | |||
def test_blur(): | |||
image = (np.random.rand(32, 32)*255).astype(np.float32) | |||
mode = 'L' | |||
trans = Blur(image, mode) | |||
trans.random_param() | |||
_ = trans.transform() | |||
image = (np.random.rand(32, 32)).astype(np.float32) | |||
trans = Blur() | |||
trans.set_params(auto_param=True) | |||
_ = trans.transform(image) | |||
@pytest.mark.level0 | |||
@@ -71,11 +68,10 @@ def test_blur(): | |||
@pytest.mark.env_onecard | |||
@pytest.mark.component_mindarmour | |||
def test_noise(): | |||
image = (np.random.rand(32, 32)*255).astype(np.float32) | |||
mode = 'L' | |||
trans = Noise(image, mode) | |||
trans.random_param() | |||
_ = trans.transform() | |||
image = (np.random.rand(32, 32)).astype(np.float32) | |||
trans = Noise() | |||
trans.set_params(auto_param=True) | |||
_ = trans.transform(image) | |||
@pytest.mark.level0 | |||
@@ -85,11 +81,10 @@ def test_noise(): | |||
@pytest.mark.env_onecard | |||
@pytest.mark.component_mindarmour | |||
def test_translate(): | |||
image = (np.random.rand(32, 32)*255).astype(np.float32) | |||
mode = 'L' | |||
trans = Translate(image, mode) | |||
trans.random_param() | |||
_ = trans.transform() | |||
image = (np.random.rand(32, 32)).astype(np.float32) | |||
trans = Translate() | |||
trans.set_params(auto_param=True) | |||
_ = trans.transform(image) | |||
@pytest.mark.level0 | |||
@@ -99,11 +94,10 @@ def test_translate(): | |||
@pytest.mark.env_onecard | |||
@pytest.mark.component_mindarmour | |||
def test_shear(): | |||
image = (np.random.rand(32, 32)*255).astype(np.float32) | |||
mode = 'L' | |||
trans = Shear(image, mode) | |||
trans.random_param() | |||
_ = trans.transform() | |||
image = (np.random.rand(32, 32)).astype(np.float32) | |||
trans = Shear() | |||
trans.set_params(auto_param=True) | |||
_ = trans.transform(image) | |||
@pytest.mark.level0 | |||
@@ -113,11 +107,10 @@ def test_shear(): | |||
@pytest.mark.env_onecard | |||
@pytest.mark.component_mindarmour | |||
def test_scale(): | |||
image = (np.random.rand(32, 32)*255).astype(np.float32) | |||
mode = 'L' | |||
trans = Scale(image, mode) | |||
trans.random_param() | |||
_ = trans.transform() | |||
image = (np.random.rand(32, 32)).astype(np.float32) | |||
trans = Scale() | |||
trans.set_params(auto_param=True) | |||
_ = trans.transform(image) | |||
@pytest.mark.level0 | |||
@@ -127,8 +120,7 @@ def test_scale(): | |||
@pytest.mark.env_onecard | |||
@pytest.mark.component_mindarmour | |||
def test_rotate(): | |||
image = (np.random.rand(32, 32)*255).astype(np.float32) | |||
mode = 'L' | |||
trans = Rotate(image, mode) | |||
trans.random_param() | |||
_ = trans.transform() | |||
image = (np.random.rand(32, 32)).astype(np.float32) | |||
trans = Rotate() | |||
trans.set_params(auto_param=True) | |||
_ = trans.transform(image) |