@@ -1,8 +1,8 @@ | |||||
""" | """ | ||||
This module includes various metrics to fuzzing the test of DNN. | This module includes various metrics to fuzzing the test of DNN. | ||||
""" | """ | ||||
from .fuzzing import Fuzzing | |||||
from .fuzzing import Fuzzer | |||||
from .model_coverage_metrics import ModelCoverageMetrics | from .model_coverage_metrics import ModelCoverageMetrics | ||||
__all__ = ['Fuzzing', | |||||
__all__ = ['Fuzzer', | |||||
'ModelCoverageMetrics'] | 'ModelCoverageMetrics'] |
@@ -23,11 +23,11 @@ from mindspore import Tensor | |||||
from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics | from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics | ||||
from mindarmour.utils._check_param import check_model, check_numpy_param, \ | from mindarmour.utils._check_param import check_model, check_numpy_param, \ | ||||
check_int_positive | check_int_positive | ||||
from mindarmour.utils.image_transform import Contrast, Brightness, Blur, Noise, \ | |||||
from mindarmour.fuzzing.image_transform import Contrast, Brightness, Blur, Noise, \ | |||||
Translate, Scale, Shear, Rotate | Translate, Scale, Shear, Rotate | ||||
class Fuzzing: | |||||
class Fuzzer: | |||||
""" | """ | ||||
Fuzzing test framework for deep neural networks. | Fuzzing test framework for deep neural networks. | ||||
@@ -84,7 +84,7 @@ class Fuzzing: | |||||
[]) | []) | ||||
transform = strages[trans_strage]( | transform = strages[trans_strage]( | ||||
self._image_value_expand(seed), self.mode) | self._image_value_expand(seed), self.mode) | ||||
transform.random_param() | |||||
transform.set_params(auto_param=True) | |||||
mutate_test = transform.transform() | mutate_test = transform.transform() | ||||
mutate_test = np.expand_dims( | mutate_test = np.expand_dims( | ||||
self._image_value_compress(mutate_test), 0) | self._image_value_compress(mutate_test), 0) | ||||
@@ -138,7 +138,7 @@ class Fuzzing: | |||||
result = result.asnumpy() | result = result.asnumpy() | ||||
for index in range(len(mutate_tests)): | for index in range(len(mutate_tests)): | ||||
mutate = np.expand_dims(mutate_tests[index], 0) | mutate = np.expand_dims(mutate_tests[index], 0) | ||||
self.coverage_metrics.test_adequacy_coverage_calculate( | |||||
self.coverage_metrics.model_coverage_test( | |||||
mutate.astype(np.float32), batch_size=1) | mutate.astype(np.float32), batch_size=1) | ||||
if coverage_metric == "KMNC": | if coverage_metric == "KMNC": | ||||
coverages.append(self.coverage_metrics.get_kmnc()) | coverages.append(self.coverage_metrics.get_kmnc()) | ||||
@@ -0,0 +1,569 @@ | |||||
# Copyright 2019 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
""" | |||||
Image transform | |||||
""" | |||||
import numpy as np | |||||
from PIL import Image, ImageEnhance, ImageFilter | |||||
from mindspore.dataset.transforms.vision.py_transforms_util import is_numpy, \ | |||||
to_pil, hwc_to_chw | |||||
from mindarmour.utils._check_param import check_param_multi_types | |||||
from mindarmour.utils.logger import LogUtil | |||||
LOGGER = LogUtil.get_instance() | |||||
TAG = 'ModelCoverageMetrics' | |||||
def chw_to_hwc(img): | |||||
""" | |||||
Transpose the input image; shape (C, H, W) to shape (H, W, C). | |||||
Args: | |||||
img (numpy.ndarray): Image to be converted. | |||||
Returns: | |||||
img (numpy.ndarray), Converted image. | |||||
""" | |||||
if is_numpy(img): | |||||
return img.transpose(1, 2, 0).copy() | |||||
raise TypeError('img should be Numpy array. Got {}'.format(type(img))) | |||||
def is_hwc(img): | |||||
""" | |||||
Check if the input image is shape (H, W, C). | |||||
Args: | |||||
img (numpy.ndarray): Image to be checked. | |||||
Returns: | |||||
Bool, True if input is shape (H, W, C). | |||||
""" | |||||
if is_numpy(img): | |||||
img_shape = np.shape(img) | |||||
if img_shape[2] == 3 and img_shape[1] > 3 and img_shape[0] > 3: | |||||
return True | |||||
return False | |||||
raise TypeError('img should be Numpy array. Got {}'.format(type(img))) | |||||
def is_chw(img): | |||||
""" | |||||
Check if the input image is shape (H, W, C). | |||||
Args: | |||||
img (numpy.ndarray): Image to be checked. | |||||
Returns: | |||||
Bool, True if input is shape (H, W, C). | |||||
""" | |||||
if is_numpy(img): | |||||
img_shape = np.shape(img) | |||||
if img_shape[0] == 3 and img_shape[1] > 3 and img_shape[2] > 3: | |||||
return True | |||||
return False | |||||
raise TypeError('img should be Numpy array. Got {}'.format(type(img))) | |||||
def is_rgb(img): | |||||
""" | |||||
Check if the input image is RGB. | |||||
Args: | |||||
img (numpy.ndarray): Image to be checked. | |||||
Returns: | |||||
Bool, True if input is RGB. | |||||
""" | |||||
if is_numpy(img): | |||||
if len(np.shape(img)) == 3: | |||||
return True | |||||
return False | |||||
raise TypeError('img should be Numpy array. Got {}'.format(type(img))) | |||||
def is_normalized(img): | |||||
""" | |||||
Check if the input image is normalized between 0 to 1. | |||||
Args: | |||||
img (numpy.ndarray): Image to be checked. | |||||
Returns: | |||||
Bool, True if input is normalized between 0 to 1. | |||||
""" | |||||
if is_numpy(img): | |||||
minimal = np.min(img) | |||||
maximun = np.max(img) | |||||
if minimal >= 0 and maximun <= 1: | |||||
return True | |||||
return False | |||||
raise TypeError('img should be Numpy array. Got {}'.format(type(img))) | |||||
class ImageTransform: | |||||
""" | |||||
The abstract base class for all image transform classes. | |||||
""" | |||||
def __init__(self): | |||||
pass | |||||
def _check(self, image): | |||||
""" Check image format. If input image is RGB and its shape | |||||
is (C, H, W), it will be transposed to (H, W, C). If the value | |||||
of the image is not normalized , it will be normalized between 0 to 1.""" | |||||
rgb = is_rgb(image) | |||||
chw = False | |||||
normalized = is_normalized(image) | |||||
if rgb: | |||||
chw = is_chw(image) | |||||
if chw: | |||||
image = chw_to_hwc(image) | |||||
else: | |||||
image = image | |||||
else: | |||||
image = image | |||||
if normalized: | |||||
image = np.uint8(image*255) | |||||
return rgb, chw, normalized, image | |||||
def _original_format(self, image, chw, normalized): | |||||
""" Return transformed image with original format. """ | |||||
if not is_numpy(image): | |||||
image = np.array(image) | |||||
if chw: | |||||
image = hwc_to_chw(image) | |||||
if normalized: | |||||
image = image / 255 | |||||
return image | |||||
def transform(self, image): | |||||
pass | |||||
class Contrast(ImageTransform): | |||||
""" | |||||
Contrast of an image. | |||||
Args: | |||||
factor ([float, int]): Control the contrast of an image. If 1.0 gives the | |||||
original image. If 0 gives a gray image. Default: 1. | |||||
""" | |||||
def __init__(self, factor=1): | |||||
super(Contrast, self).__init__() | |||||
self.set_params(factor) | |||||
def set_params(self, factor=1, auto_param=False): | |||||
""" | |||||
Set contrast parameters. | |||||
Args: | |||||
factor ([float, int]): Control the contrast of an image. If 1.0 gives | |||||
the original image. If 0 gives a gray image. Default: 1. | |||||
auto_param (bool): True if auto generate parameters. Default: False. | |||||
""" | |||||
if auto_param: | |||||
self.factor = np.random.uniform(-5, 5) | |||||
else: | |||||
self.factor = check_param_multi_types('factor', factor, [int, float]) | |||||
def transform(self, image): | |||||
""" | |||||
Transform the image. | |||||
Args: | |||||
image (numpy.ndarray): Original image to be transformed. | |||||
Returns: | |||||
numpy.ndarray, transformed image. | |||||
""" | |||||
_, chw, normalized, image = self._check(image) | |||||
image = to_pil(image) | |||||
img_contrast = ImageEnhance.Contrast(image) | |||||
trans_image = img_contrast.enhance(self.factor) | |||||
trans_image = self._original_format(trans_image, chw, normalized) | |||||
return trans_image | |||||
class Brightness(ImageTransform): | |||||
""" | |||||
Brightness of an image. | |||||
Args: | |||||
factor ([float, int]): Control the brightness of an image. If 1.0 gives | |||||
the original image. If 0 gives a black image. Default: 1. | |||||
""" | |||||
def __init__(self, factor=1): | |||||
super(Brightness, self).__init__() | |||||
self.set_params(factor) | |||||
def set_params(self, factor=1, auto_param=False): | |||||
""" | |||||
Set brightness parameters. | |||||
Args: | |||||
factor ([float, int]): Control the brightness of an image. If 1 | |||||
gives the original image. If 0 gives a black image. Default: 1. | |||||
auto_param (bool): True if auto generate parameters. Default: False. | |||||
""" | |||||
if auto_param: | |||||
self.factor = np.random.uniform(0, 5) | |||||
else: | |||||
self.factor = check_param_multi_types('factor', factor, [int, float]) | |||||
def transform(self, image): | |||||
""" | |||||
Transform the image. | |||||
Args: | |||||
image (numpy.ndarray): Original image to be transformed. | |||||
Returns: | |||||
numpy.ndarray, transformed image. | |||||
""" | |||||
_, chw, normalized, image = self._check(image) | |||||
image = to_pil(image) | |||||
img_contrast = ImageEnhance.Brightness(image) | |||||
trans_image = img_contrast.enhance(self.factor) | |||||
trans_image = self._original_format(trans_image, chw, normalized) | |||||
return trans_image | |||||
class Blur(ImageTransform): | |||||
""" | |||||
Blurs the image using Gaussian blur filter. | |||||
Args: | |||||
radius([float, int]): Blur radius, 0 means no blur. Default: 0. | |||||
""" | |||||
def __init__(self, radius=0): | |||||
super(Blur, self).__init__() | |||||
self.set_params(radius) | |||||
def set_params(self, radius=0, auto_param=False): | |||||
""" | |||||
Set blur parameters. | |||||
Args: | |||||
radius ([float, int]): Blur radius, 0 means no blur. Default: 0. | |||||
auto_param (bool): True if auto generate parameters. Default: False. | |||||
""" | |||||
if auto_param: | |||||
self.radius = np.random.uniform(-1.5, 1.5) | |||||
else: | |||||
self.radius = check_param_multi_types('radius', radius, [int, float]) | |||||
def transform(self, image): | |||||
""" | |||||
Transform the image. | |||||
Args: | |||||
image (numpy.ndarray): Original image to be transformed. | |||||
Returns: | |||||
numpy.ndarray, transformed image. | |||||
""" | |||||
_, chw, normalized, image = self._check(image) | |||||
image = to_pil(image) | |||||
trans_image = image.filter(ImageFilter.GaussianBlur(radius=self.radius)) | |||||
trans_image = self._original_format(trans_image, chw, normalized) | |||||
return trans_image | |||||
class Noise(ImageTransform): | |||||
""" | |||||
Add noise of an image. | |||||
Args: | |||||
factor (float): 1 - factor is the ratio of pixels to add noise. | |||||
If 0 gives the original image. Default 0. | |||||
""" | |||||
def __init__(self, factor=0): | |||||
super(Noise, self).__init__() | |||||
self.set_params(factor) | |||||
def set_params(self, factor=0, auto_param=False): | |||||
""" | |||||
Set noise parameters. | |||||
Args: | |||||
factor ([float, int]): 1 - factor is the ratio of pixels to add noise. | |||||
If 0 gives the original image. Default 0. | |||||
auto_param (bool): True if auto generate parameters. Default: False. | |||||
""" | |||||
if auto_param: | |||||
self.factor = np.random.uniform(0.7, 1) | |||||
else: | |||||
self.factor = check_param_multi_types('factor', factor, [int, float]) | |||||
def transform(self, image): | |||||
""" | |||||
Transform the image. | |||||
Args: | |||||
image (numpy.ndarray): Original image to be transformed. | |||||
Returns: | |||||
numpy.ndarray, transformed image. | |||||
""" | |||||
_, chw, normalized, image = self._check(image) | |||||
noise = np.random.uniform(low=-1, high=1, size=np.shape(image)) | |||||
trans_image = np.copy(image) | |||||
trans_image[noise < -self.factor] = 0 | |||||
trans_image[noise > self.factor] = 1 | |||||
trans_image = self._original_format(trans_image, chw, normalized) | |||||
return trans_image | |||||
class Translate(ImageTransform): | |||||
""" | |||||
Translate an image. | |||||
Args: | |||||
x_bias ([int, float): X-direction translation, x=x+x_bias. Default: 0. | |||||
y_bias ([int, float): Y-direction translation, y=y+y_bias. Default: 0. | |||||
""" | |||||
def __init__(self, x_bias=0, y_bias=0): | |||||
super(Translate, self).__init__() | |||||
self.set_params(x_bias, y_bias) | |||||
def set_params(self, x_bias=0, y_bias=0, auto_param=False): | |||||
""" | |||||
Set translate parameters. | |||||
Args: | |||||
x_bias ([float, int]): X-direction translation, x=x+x_bias. Default: 0. | |||||
y_bias ([float, int]): Y-direction translation, y=y+y_bias. Default: 0. | |||||
auto_param (bool): True if auto generate parameters. Default: False. | |||||
""" | |||||
self.auto_param = auto_param | |||||
if auto_param: | |||||
self.x_bias = np.random.uniform(-0.3, 0.3) | |||||
self.y_bias = np.random.uniform(-0.3, 0.3) | |||||
else: | |||||
self.x_bias = check_param_multi_types('x_bias', x_bias, | |||||
[int, float]) | |||||
self.y_bias = check_param_multi_types('y_bias', y_bias, | |||||
[int, float]) | |||||
def transform(self, image): | |||||
""" | |||||
Transform the image. | |||||
Args: | |||||
image(numpy.ndarray): Original image to be transformed. | |||||
Returns: | |||||
numpy.ndarray, transformed image. | |||||
""" | |||||
_, chw, normalized, image = self._check(image) | |||||
img = to_pil(image) | |||||
if self.auto_param: | |||||
image_shape = np.shape(image) | |||||
self.x_bias = image_shape[0]*self.x_bias | |||||
self.y_bias = image_shape[1]*self.y_bias | |||||
trans_image = img.transform(img.size, Image.AFFINE, | |||||
(1, 0, self.x_bias, 0, 1, self.y_bias)) | |||||
trans_image = self._original_format(trans_image, chw, normalized) | |||||
return trans_image | |||||
class Scale(ImageTransform): | |||||
""" | |||||
Scale an image in the middle. | |||||
Args: | |||||
factor_x ([float, int]): Rescale in X-direction, x=factor_x*x. Default: 1. | |||||
factor_y ([float, int]): Rescale in Y-direction, y=factor_y*y. Default: 1. | |||||
""" | |||||
def __init__(self, factor_x=1, factor_y=1): | |||||
super(Scale, self).__init__() | |||||
self.set_params(factor_x, factor_y) | |||||
def set_params(self, factor_x=1, factor_y=1, auto_param=False): | |||||
""" | |||||
Set scale parameters. | |||||
Args: | |||||
factor_x ([float, int]): Rescale in X-direction, x=factor_x*x. | |||||
Default: 1. | |||||
factor_y ([float, int]): Rescale in Y-direction, y=factor_y*y. | |||||
Default: 1. | |||||
auto_param (bool): True if auto generate parameters. Default: False. | |||||
""" | |||||
if auto_param: | |||||
self.factor_x = np.random.uniform(0.7, 3) | |||||
self.factor_y = np.random.uniform(0.7, 3) | |||||
else: | |||||
self.factor_x = check_param_multi_types('factor_x', factor_x, | |||||
[int, float]) | |||||
self.factor_y = check_param_multi_types('factor_y', factor_y, | |||||
[int, float]) | |||||
def transform(self, image): | |||||
""" | |||||
Transform the image. | |||||
Args: | |||||
image(numpy.ndarray): Original image to be transformed. | |||||
Returns: | |||||
numpy.ndarray, transformed image. | |||||
""" | |||||
rgb, chw, normalized, image = self._check(image) | |||||
if rgb: | |||||
h, w, _ = np.shape(image) | |||||
else: | |||||
h, w = np.shape(image) | |||||
move_x_centor = w / 2*(1 - self.factor_x) | |||||
move_y_centor = h / 2*(1 - self.factor_y) | |||||
img = to_pil(image) | |||||
trans_image = img.transform(img.size, Image.AFFINE, | |||||
(self.factor_x, 0, move_x_centor, | |||||
0, self.factor_y, move_y_centor)) | |||||
trans_image = self._original_format(trans_image, chw, normalized) | |||||
return trans_image | |||||
class Shear(ImageTransform): | |||||
""" | |||||
Shear an image, for each pixel (x, y) in the sheared image, the new value is | |||||
taken from a position (x+factor_x*y, factor_y*x+y) in the origin image. Then | |||||
the sheared image will be rescaled to fit original size. | |||||
Args: | |||||
factor_x ([float, int]): Shear factor of horizontal direction. Default: 0. | |||||
factor_y ([float, int]): Shear factor of vertical direction. Default: 0. | |||||
""" | |||||
def __init__(self, factor_x=0, factor_y=0): | |||||
super(Shear, self).__init__() | |||||
self.set_params(factor_x, factor_y) | |||||
def set_params(self, factor_x=0, factor_y=0, auto_param=False): | |||||
""" | |||||
Set shear parameters. | |||||
Args: | |||||
factor_x ([float, int]): Shear factor of horizontal direction. | |||||
Default: 0. | |||||
factor_y ([float, int]): Shear factor of vertical direction. | |||||
Default: 0. | |||||
auto_param (bool): True if auto generate parameters. Default: False. | |||||
""" | |||||
if factor_x != 0 and factor_y != 0: | |||||
msg = 'factor_x and factor_y can not be both more than 0.' | |||||
LOGGER.error(TAG, msg) | |||||
raise ValueError(msg) | |||||
if auto_param: | |||||
if np.random.uniform(-1, 1) > 0: | |||||
self.factor_x = np.random.uniform(-2, 2) | |||||
self.factor_y = 0 | |||||
else: | |||||
self.factor_x = 0 | |||||
self.factor_y = np.random.uniform(-2, 2) | |||||
else: | |||||
self.factor_x = check_param_multi_types('factor', factor_x, | |||||
[int, float]) | |||||
self.factor_y = check_param_multi_types('factor', factor_y, | |||||
[int, float]) | |||||
def transform(self, image): | |||||
""" | |||||
Transform the image. | |||||
Args: | |||||
image(numpy.ndarray): Original image to be transformed. | |||||
Returns: | |||||
numpy.ndarray, transformed image. | |||||
""" | |||||
rgb, chw, normalized, image = self._check(image) | |||||
img = to_pil(image) | |||||
if rgb: | |||||
h, w, _ = np.shape(image) | |||||
else: | |||||
h, w = np.shape(image) | |||||
if self.factor_x != 0: | |||||
boarder_x = [0, -w, -self.factor_x*h, -w - self.factor_x*h] | |||||
min_x = min(boarder_x) | |||||
max_x = max(boarder_x) | |||||
scale = (max_x - min_x) / w | |||||
move_x_cen = (w - scale*w - scale*h*self.factor_x) / 2 | |||||
move_y_cen = h*(1 - scale) / 2 | |||||
else: | |||||
boarder_y = [0, -h, -self.factor_y*w, -h - self.factor_y*w] | |||||
min_y = min(boarder_y) | |||||
max_y = max(boarder_y) | |||||
scale = (max_y - min_y) / h | |||||
move_y_cen = (h - scale*h - scale*w*self.factor_y) / 2 | |||||
move_x_cen = w*(1 - scale) / 2 | |||||
trans_image = img.transform(img.size, Image.AFFINE, | |||||
(scale, scale*self.factor_x, move_x_cen, | |||||
scale*self.factor_y, scale, move_y_cen)) | |||||
trans_image = self._original_format(trans_image, chw, normalized) | |||||
return trans_image | |||||
class Rotate(ImageTransform): | |||||
""" | |||||
Rotate an image of degrees counter clockwise around its center. | |||||
Args: | |||||
angle([float, int]): Degrees counter clockwise. Default: 0. | |||||
""" | |||||
def __init__(self, angle=0): | |||||
super(Rotate, self).__init__() | |||||
self.set_params(angle) | |||||
def set_params(self, angle=0, auto_param=False): | |||||
""" | |||||
Set rotate parameters. | |||||
Args: | |||||
angle([float, int]): Degrees counter clockwise. Default: 0. | |||||
auto_param (bool): True if auto generate parameters. Default: False. | |||||
""" | |||||
if auto_param: | |||||
self.angle = np.random.uniform(0, 360) | |||||
else: | |||||
self.angle = check_param_multi_types('angle', angle, [int, float]) | |||||
def transform(self, image): | |||||
""" | |||||
Transform the image. | |||||
Args: | |||||
image(numpy.ndarray): Original image to be transformed. | |||||
Returns: | |||||
numpy.ndarray, transformed image. | |||||
""" | |||||
_, chw, normalized, image = self._check(image) | |||||
img = to_pil(image) | |||||
trans_image = img.rotate(self.angle, expand=True) | |||||
trans_image = self._original_format(trans_image, chw, normalized) | |||||
return trans_image |
@@ -133,8 +133,7 @@ class ModelCoverageMetrics: | |||||
else: | else: | ||||
self._main_section_hits[i][int(section_indexes[i])] = 1 | self._main_section_hits[i][int(section_indexes[i])] = 1 | ||||
def test_adequacy_coverage_calculate(self, dataset, bias_coefficient=0, | |||||
batch_size=32): | |||||
def calculate_coverage(self, dataset, bias_coefficient=0, batch_size=32): | |||||
""" | """ | ||||
Calculate the testing adequacy of the given dataset. | Calculate the testing adequacy of the given dataset. | ||||
@@ -147,7 +146,7 @@ class ModelCoverageMetrics: | |||||
Examples: | Examples: | ||||
>>> model_fuzz_test = ModelCoverageMetrics(model, 10000, 10, train_images) | >>> model_fuzz_test = ModelCoverageMetrics(model, 10000, 10, train_images) | ||||
>>> model_fuzz_test.test_adequacy_coverage_calculate(test_images) | |||||
>>> model_fuzz_test.calculate_coverage(test_images) | |||||
""" | """ | ||||
dataset = check_numpy_param('dataset', dataset) | dataset = check_numpy_param('dataset', dataset) | ||||
batch_size = check_int_positive('batch_size', batch_size) | batch_size = check_int_positive('batch_size', batch_size) | ||||
@@ -1,267 +0,0 @@ | |||||
# Copyright 2019 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
""" | |||||
Image transform | |||||
""" | |||||
import numpy as np | |||||
from PIL import Image, ImageEnhance, ImageFilter | |||||
import random | |||||
from mindarmour.utils._check_param import check_numpy_param | |||||
class ImageTransform: | |||||
""" | |||||
The abstract base class for all image transform classes. | |||||
""" | |||||
def __init__(self): | |||||
pass | |||||
def random_param(self): | |||||
pass | |||||
def transform(self): | |||||
pass | |||||
class Contrast(ImageTransform): | |||||
""" | |||||
Contrast of an image. | |||||
Args: | |||||
image (numpy.ndarray): Original image to be transformed. | |||||
mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'], | |||||
'L' means grey image. | |||||
""" | |||||
def __init__(self, image, mode): | |||||
super(Contrast, self).__init__() | |||||
self.image = check_numpy_param('image', image) | |||||
self.mode = mode | |||||
def random_param(self): | |||||
""" Random generate parameters. """ | |||||
self.factor = random.uniform(-5, 5) | |||||
def transform(self): | |||||
img = Image.fromarray(np.uint8(self.image*255), self.mode) | |||||
img_contrast = ImageEnhance.Contrast(img) | |||||
trans_image = img_contrast.enhance(self.factor) | |||||
trans_image = np.array(trans_image)/255 | |||||
return trans_image | |||||
class Brightness(ImageTransform): | |||||
""" | |||||
Brightness of an image. | |||||
Args: | |||||
image (numpy.ndarray): Original image to be transformed. | |||||
mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'], | |||||
'L' means grey image. | |||||
""" | |||||
def __init__(self, image, mode): | |||||
super(Brightness, self).__init__() | |||||
self.image = check_numpy_param('image', image) | |||||
self.mode = mode | |||||
def random_param(self): | |||||
""" Random generate parameters. """ | |||||
self.factor = random.uniform(0, 5) | |||||
def transform(self): | |||||
img = Image.fromarray(np.uint8(self.image*255), self.mode) | |||||
img_contrast = ImageEnhance.Brightness(img) | |||||
trans_image = img_contrast.enhance(self.factor) | |||||
trans_image = np.array(trans_image)/255 | |||||
return trans_image | |||||
class Blur(ImageTransform): | |||||
""" | |||||
GaussianBlur of an image. | |||||
Args: | |||||
image (numpy.ndarray): Original image to be transformed. | |||||
mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'], | |||||
'L' means grey image. | |||||
""" | |||||
def __init__(self, image, mode): | |||||
super(Blur, self).__init__() | |||||
self.image = check_numpy_param('image', image) | |||||
self.mode = mode | |||||
def random_param(self): | |||||
""" Random generate parameters. """ | |||||
self.radius = random.uniform(-1.5, 1.5) | |||||
def transform(self): | |||||
""" Transform the image. """ | |||||
img = Image.fromarray(np.uint8(self.image*255), self.mode) | |||||
trans_image = img.filter(ImageFilter.GaussianBlur(radius=self.radius)) | |||||
trans_image = np.array(trans_image)/255 | |||||
return trans_image | |||||
class Noise(ImageTransform): | |||||
""" | |||||
Add noise of an image. | |||||
Args: | |||||
image (numpy.ndarray): Original image to be transformed. | |||||
mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'], | |||||
'L' means grey image. | |||||
""" | |||||
def __init__(self, image, mode): | |||||
super(Noise, self).__init__() | |||||
self.image = check_numpy_param('image', image) | |||||
self.mode = mode | |||||
def random_param(self): | |||||
""" random generate parameters """ | |||||
self.factor = random.uniform(0.7, 1) | |||||
def transform(self): | |||||
""" Random generate parameters. """ | |||||
noise = np.random.uniform(low=-1, high=1, size=self.image.shape) | |||||
trans_image = np.copy(self.image) | |||||
trans_image[noise < -self.factor] = 0 | |||||
trans_image[noise > self.factor] = 1 | |||||
trans_image = np.array(trans_image) | |||||
return trans_image | |||||
class Translate(ImageTransform): | |||||
""" | |||||
Translate an image. | |||||
Args: | |||||
image (numpy.ndarray): Original image to be transformed. | |||||
mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'], | |||||
'L' means grey image. | |||||
""" | |||||
def __init__(self, image, mode): | |||||
super(Translate, self).__init__() | |||||
self.image = check_numpy_param('image', image) | |||||
self.mode = mode | |||||
def random_param(self): | |||||
""" Random generate parameters. """ | |||||
image_shape = np.shape(self.image) | |||||
self.x_bias = random.uniform(-image_shape[0]/3, image_shape[0]/3) | |||||
self.y_bias = random.uniform(-image_shape[1]/3, image_shape[1]/3) | |||||
def transform(self): | |||||
""" Transform the image. """ | |||||
img = Image.fromarray(np.uint8(self.image*255), self.mode) | |||||
trans_image = img.transform(img.size, Image.AFFINE, | |||||
(1, 0, self.x_bias, 0, 1, self.y_bias)) | |||||
trans_image = np.array(trans_image)/255 | |||||
return trans_image | |||||
class Scale(ImageTransform): | |||||
""" | |||||
Scale an image. | |||||
Args: | |||||
image(numpy.ndarray): Original image to be transformed. | |||||
mode(str): Mode used in PIL, here mode must be in ['L', 'RGB'], | |||||
'L' means grey image. | |||||
""" | |||||
def __init__(self, image, mode): | |||||
super(Scale, self).__init__() | |||||
self.image = check_numpy_param('image', image) | |||||
self.mode = mode | |||||
def random_param(self): | |||||
""" Random generate parameters. """ | |||||
self.factor_x = random.uniform(0.7, 2) | |||||
self.factor_y = random.uniform(0.7, 2) | |||||
def transform(self): | |||||
""" Transform the image. """ | |||||
img = Image.fromarray(np.uint8(self.image*255), self.mode) | |||||
trans_image = img.transform(img.size, Image.AFFINE, | |||||
(self.factor_x, 0, 0, 0, self.factor_y, 0)) | |||||
trans_image = np.array(trans_image)/255 | |||||
return trans_image | |||||
class Shear(ImageTransform): | |||||
""" | |||||
Shear an image. | |||||
Args: | |||||
image (numpy.ndarray): Original image to be transformed. | |||||
mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'], | |||||
'L' means grey image. | |||||
""" | |||||
def __init__(self, image, mode): | |||||
super(Shear, self).__init__() | |||||
self.image = check_numpy_param('image', image) | |||||
self.mode = mode | |||||
def random_param(self): | |||||
""" Random generate parameters. """ | |||||
self.factor = random.uniform(0, 1) | |||||
def transform(self): | |||||
""" Transform the image. """ | |||||
img = Image.fromarray(np.uint8(self.image*255), self.mode) | |||||
if np.random.random() > 0.5: | |||||
level = -self.factor | |||||
else: | |||||
level = self.factor | |||||
if np.random.random() > 0.5: | |||||
trans_image = img.transform(img.size, Image.AFFINE, | |||||
(1, level, 0, 0, 1, 0)) | |||||
else: | |||||
trans_image = img.transform(img.size, Image.AFFINE, | |||||
(1, 0, 0, level, 1, 0)) | |||||
trans_image = np.array(trans_image, dtype=np.float)/255 | |||||
return trans_image | |||||
class Rotate(ImageTransform): | |||||
""" | |||||
Rotate an image. | |||||
Args: | |||||
image (numpy.ndarray): Original image to be transformed. | |||||
mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'], | |||||
'L' means grey image. | |||||
""" | |||||
def __init__(self, image, mode): | |||||
super(Rotate, self).__init__() | |||||
self.image = check_numpy_param('image', image) | |||||
self.mode = mode | |||||
def random_param(self): | |||||
""" Random generate parameters. """ | |||||
self.angle = random.uniform(0, 360) | |||||
def transform(self): | |||||
""" Transform the image. """ | |||||
img = Image.fromarray(np.uint8(self.image*255), self.mode) | |||||
trans_image = img.rotate(self.angle) | |||||
trans_image = np.array(trans_image)/255 | |||||
return trans_image |
@@ -77,7 +77,7 @@ def test_lenet_mnist_coverage_cpu(): | |||||
# get test data | # get test data | ||||
test_data = (np.random.random((2000, 10))*20).astype(np.float32) | test_data = (np.random.random((2000, 10))*20).astype(np.float32) | ||||
test_labels = np.random.randint(0, 10, 2000).astype(np.int32) | test_labels = np.random.randint(0, 10, 2000).astype(np.int32) | ||||
model_fuzz_test.test_adequacy_coverage_calculate(test_data) | |||||
model_fuzz_test.calculate_coverage(test_data) | |||||
LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc()) | LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc()) | ||||
LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc()) | LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc()) | ||||
LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac()) | LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac()) | ||||
@@ -86,8 +86,7 @@ def test_lenet_mnist_coverage_cpu(): | |||||
loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) | loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) | ||||
attack = FastGradientSignMethod(net, eps=0.3, loss_fn=loss) | attack = FastGradientSignMethod(net, eps=0.3, loss_fn=loss) | ||||
adv_data = attack.batch_generate(test_data, test_labels, batch_size=32) | adv_data = attack.batch_generate(test_data, test_labels, batch_size=32) | ||||
model_fuzz_test.test_adequacy_coverage_calculate(adv_data, | |||||
bias_coefficient=0.5) | |||||
model_fuzz_test.calculate_coverage(adv_data, bias_coefficient=0.5) | |||||
LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc()) | LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc()) | ||||
LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc()) | LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc()) | ||||
LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac()) | LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac()) | ||||
@@ -113,7 +112,7 @@ def test_lenet_mnist_coverage_ascend(): | |||||
test_data = (np.random.random((2000, 10))*20).astype(np.float32) | test_data = (np.random.random((2000, 10))*20).astype(np.float32) | ||||
test_labels = np.random.randint(0, 10, 2000) | test_labels = np.random.randint(0, 10, 2000) | ||||
test_labels = (np.eye(10)[test_labels]).astype(np.float32) | test_labels = (np.eye(10)[test_labels]).astype(np.float32) | ||||
model_fuzz_test.test_adequacy_coverage_calculate(test_data) | |||||
model_fuzz_test.calculate_coverage(test_data) | |||||
LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc()) | LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc()) | ||||
LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc()) | LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc()) | ||||
LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac()) | LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac()) | ||||
@@ -121,8 +120,7 @@ def test_lenet_mnist_coverage_ascend(): | |||||
# generate adv_data | # generate adv_data | ||||
attack = FastGradientSignMethod(net, eps=0.3) | attack = FastGradientSignMethod(net, eps=0.3) | ||||
adv_data = attack.batch_generate(test_data, test_labels, batch_size=32) | adv_data = attack.batch_generate(test_data, test_labels, batch_size=32) | ||||
model_fuzz_test.test_adequacy_coverage_calculate(adv_data, | |||||
bias_coefficient=0.5) | |||||
model_fuzz_test.calculate_coverage(adv_data, bias_coefficient=0.5) | |||||
LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc()) | LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc()) | ||||
LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc()) | LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc()) | ||||
LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac()) | LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac()) |
@@ -1,160 +0,0 @@ | |||||
# Copyright 2019 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
""" | |||||
Model-fuzz coverage test. | |||||
""" | |||||
import numpy as np | |||||
import pytest | |||||
from mindspore import context | |||||
from mindspore import nn | |||||
from mindspore.common.initializer import TruncatedNormal | |||||
from mindspore.ops import operations as P | |||||
from mindspore.train import Model | |||||
from mindarmour.fuzzing.fuzzing import Fuzzing | |||||
from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics | |||||
from mindarmour.utils.logger import LogUtil | |||||
LOGGER = LogUtil.get_instance() | |||||
TAG = 'Fuzzing test' | |||||
LOGGER.set_level('INFO') | |||||
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0): | |||||
weight = weight_variable() | |||||
return nn.Conv2d(in_channels, out_channels, | |||||
kernel_size=kernel_size, stride=stride, padding=padding, | |||||
weight_init=weight, has_bias=False, pad_mode="valid") | |||||
def fc_with_initialize(input_channels, out_channels): | |||||
weight = weight_variable() | |||||
bias = weight_variable() | |||||
return nn.Dense(input_channels, out_channels, weight, bias) | |||||
def weight_variable(): | |||||
return TruncatedNormal(0.02) | |||||
class Net(nn.Cell): | |||||
""" | |||||
Lenet network | |||||
""" | |||||
def __init__(self): | |||||
super(Net, self).__init__() | |||||
self.conv1 = conv(1, 6, 5) | |||||
self.conv2 = conv(6, 16, 5) | |||||
self.fc1 = fc_with_initialize(16*5*5, 120) | |||||
self.fc2 = fc_with_initialize(120, 84) | |||||
self.fc3 = fc_with_initialize(84, 10) | |||||
self.relu = nn.ReLU() | |||||
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) | |||||
self.reshape = P.Reshape() | |||||
def construct(self, x): | |||||
x = self.conv1(x) | |||||
x = self.relu(x) | |||||
x = self.max_pool2d(x) | |||||
x = self.conv2(x) | |||||
x = self.relu(x) | |||||
x = self.max_pool2d(x) | |||||
x = self.reshape(x, (-1, 16*5*5)) | |||||
x = self.fc1(x) | |||||
x = self.relu(x) | |||||
x = self.fc2(x) | |||||
x = self.relu(x) | |||||
x = self.fc3(x) | |||||
return x | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_x86_ascend_training | |||||
@pytest.mark.platform_arm_ascend_training | |||||
@pytest.mark.env_onecard | |||||
@pytest.mark.component_mindarmour | |||||
def test_fuzzing_ascend(): | |||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
# load network | |||||
net = Net() | |||||
model = Model(net) | |||||
batch_size = 8 | |||||
num_classe = 10 | |||||
# initialize fuzz test with training dataset | |||||
training_data = np.random.rand(32, 1, 32, 32).astype(np.float32) | |||||
model_coverage_test = ModelCoverageMetrics(model, 1000, 10, training_data) | |||||
# fuzz test with original test data | |||||
# get test data | |||||
test_data = np.random.rand(batch_size, 1, 32, 32).astype(np.float32) | |||||
test_labels = np.random.randint(num_classe, size=batch_size).astype(np.int32) | |||||
test_labels = (np.eye(num_classe)[test_labels]).astype(np.float32) | |||||
initial_seeds = [] | |||||
for img, label in zip(test_data, test_labels): | |||||
initial_seeds.append([img, label, 0]) | |||||
model_coverage_test.test_adequacy_coverage_calculate( | |||||
np.array(test_data).astype(np.float32)) | |||||
LOGGER.info(TAG, 'KMNC of this test is : %s', | |||||
model_coverage_test.get_kmnc()) | |||||
model_fuzz_test = Fuzzing(initial_seeds, model, training_data, 5, | |||||
max_seed_num=10) | |||||
failed_tests = model_fuzz_test.fuzzing() | |||||
if failed_tests: | |||||
model_coverage_test.test_adequacy_coverage_calculate(np.array(failed_tests).astype(np.float32)) | |||||
LOGGER.info(TAG, 'KMNC of this test is : %s', model_coverage_test.get_kmnc()) | |||||
else: | |||||
LOGGER.info(TAG, 'Fuzzing test identifies none failed test') | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_x86_cpu | |||||
@pytest.mark.env_onecard | |||||
@pytest.mark.component_mindarmour | |||||
def test_fuzzing_CPU(): | |||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||||
# load network | |||||
net = Net() | |||||
model = Model(net) | |||||
batch_size = 8 | |||||
num_classe = 10 | |||||
# initialize fuzz test with training dataset | |||||
training_data = np.random.rand(32, 1, 32, 32).astype(np.float32) | |||||
model_coverage_test = ModelCoverageMetrics(model, 1000, 10, training_data) | |||||
# fuzz test with original test data | |||||
# get test data | |||||
test_data = np.random.rand(batch_size, 1, 32, 32).astype(np.float32) | |||||
test_labels = np.random.randint(num_classe, size=batch_size).astype(np.int32) | |||||
test_labels = (np.eye(num_classe)[test_labels]).astype(np.float32) | |||||
initial_seeds = [] | |||||
for img, label in zip(test_data, test_labels): | |||||
initial_seeds.append([img, label, 0]) | |||||
model_coverage_test.test_adequacy_coverage_calculate( | |||||
np.array(test_data).astype(np.float32)) | |||||
LOGGER.info(TAG, 'KMNC of this test is : %s', | |||||
model_coverage_test.get_kmnc()) | |||||
model_fuzz_test = Fuzzing(initial_seeds, model, training_data, 5, | |||||
max_seed_num=10) | |||||
failed_tests = model_fuzz_test.fuzzing() | |||||
if failed_tests: | |||||
model_coverage_test.test_adequacy_coverage_calculate(np.array(failed_tests).astype(np.float32)) | |||||
LOGGER.info(TAG, 'KMNC of this test is : %s', model_coverage_test.get_kmnc()) | |||||
else: | |||||
LOGGER.info(TAG, 'Fuzzing test identifies none failed test') |
@@ -18,7 +18,7 @@ import numpy as np | |||||
import pytest | import pytest | ||||
from mindarmour.utils.logger import LogUtil | from mindarmour.utils.logger import LogUtil | ||||
from mindarmour.utils.image_transform import Contrast, Brightness, Blur, Noise, \ | |||||
from mindarmour.fuzzing.image_transform import Contrast, Brightness, Blur, Noise, \ | |||||
Translate, Scale, Shear, Rotate | Translate, Scale, Shear, Rotate | ||||
LOGGER = LogUtil.get_instance() | LOGGER = LogUtil.get_instance() | ||||
@@ -31,11 +31,10 @@ LOGGER.set_level('INFO') | |||||
@pytest.mark.env_onecard | @pytest.mark.env_onecard | ||||
@pytest.mark.component_mindarmour | @pytest.mark.component_mindarmour | ||||
def test_contrast(): | def test_contrast(): | ||||
image = (np.random.rand(32, 32)*255).astype(np.float32) | |||||
mode = 'L' | |||||
trans = Contrast(image, mode) | |||||
trans.random_param() | |||||
_ = trans.transform() | |||||
image = (np.random.rand(32, 32)).astype(np.float32) | |||||
trans = Contrast() | |||||
trans.set_params(auto_param=True) | |||||
_ = trans.transform(image) | |||||
@pytest.mark.level0 | @pytest.mark.level0 | ||||
@@ -43,11 +42,10 @@ def test_contrast(): | |||||
@pytest.mark.env_onecard | @pytest.mark.env_onecard | ||||
@pytest.mark.component_mindarmour | @pytest.mark.component_mindarmour | ||||
def test_brightness(): | def test_brightness(): | ||||
image = (np.random.rand(32, 32)*255).astype(np.float32) | |||||
mode = 'L' | |||||
trans = Brightness(image, mode) | |||||
trans.random_param() | |||||
_ = trans.transform() | |||||
image = (np.random.rand(32, 32)).astype(np.float32) | |||||
trans = Brightness() | |||||
trans.set_params(auto_param=True) | |||||
_ = trans.transform(image) | |||||
@pytest.mark.level0 | @pytest.mark.level0 | ||||
@@ -57,11 +55,10 @@ def test_brightness(): | |||||
@pytest.mark.env_onecard | @pytest.mark.env_onecard | ||||
@pytest.mark.component_mindarmour | @pytest.mark.component_mindarmour | ||||
def test_blur(): | def test_blur(): | ||||
image = (np.random.rand(32, 32)*255).astype(np.float32) | |||||
mode = 'L' | |||||
trans = Blur(image, mode) | |||||
trans.random_param() | |||||
_ = trans.transform() | |||||
image = (np.random.rand(32, 32)).astype(np.float32) | |||||
trans = Blur() | |||||
trans.set_params(auto_param=True) | |||||
_ = trans.transform(image) | |||||
@pytest.mark.level0 | @pytest.mark.level0 | ||||
@@ -71,11 +68,10 @@ def test_blur(): | |||||
@pytest.mark.env_onecard | @pytest.mark.env_onecard | ||||
@pytest.mark.component_mindarmour | @pytest.mark.component_mindarmour | ||||
def test_noise(): | def test_noise(): | ||||
image = (np.random.rand(32, 32)*255).astype(np.float32) | |||||
mode = 'L' | |||||
trans = Noise(image, mode) | |||||
trans.random_param() | |||||
_ = trans.transform() | |||||
image = (np.random.rand(32, 32)).astype(np.float32) | |||||
trans = Noise() | |||||
trans.set_params(auto_param=True) | |||||
_ = trans.transform(image) | |||||
@pytest.mark.level0 | @pytest.mark.level0 | ||||
@@ -85,11 +81,10 @@ def test_noise(): | |||||
@pytest.mark.env_onecard | @pytest.mark.env_onecard | ||||
@pytest.mark.component_mindarmour | @pytest.mark.component_mindarmour | ||||
def test_translate(): | def test_translate(): | ||||
image = (np.random.rand(32, 32)*255).astype(np.float32) | |||||
mode = 'L' | |||||
trans = Translate(image, mode) | |||||
trans.random_param() | |||||
_ = trans.transform() | |||||
image = (np.random.rand(32, 32)).astype(np.float32) | |||||
trans = Translate() | |||||
trans.set_params(auto_param=True) | |||||
_ = trans.transform(image) | |||||
@pytest.mark.level0 | @pytest.mark.level0 | ||||
@@ -99,11 +94,10 @@ def test_translate(): | |||||
@pytest.mark.env_onecard | @pytest.mark.env_onecard | ||||
@pytest.mark.component_mindarmour | @pytest.mark.component_mindarmour | ||||
def test_shear(): | def test_shear(): | ||||
image = (np.random.rand(32, 32)*255).astype(np.float32) | |||||
mode = 'L' | |||||
trans = Shear(image, mode) | |||||
trans.random_param() | |||||
_ = trans.transform() | |||||
image = (np.random.rand(32, 32)).astype(np.float32) | |||||
trans = Shear() | |||||
trans.set_params(auto_param=True) | |||||
_ = trans.transform(image) | |||||
@pytest.mark.level0 | @pytest.mark.level0 | ||||
@@ -113,11 +107,10 @@ def test_shear(): | |||||
@pytest.mark.env_onecard | @pytest.mark.env_onecard | ||||
@pytest.mark.component_mindarmour | @pytest.mark.component_mindarmour | ||||
def test_scale(): | def test_scale(): | ||||
image = (np.random.rand(32, 32)*255).astype(np.float32) | |||||
mode = 'L' | |||||
trans = Scale(image, mode) | |||||
trans.random_param() | |||||
_ = trans.transform() | |||||
image = (np.random.rand(32, 32)).astype(np.float32) | |||||
trans = Scale() | |||||
trans.set_params(auto_param=True) | |||||
_ = trans.transform(image) | |||||
@pytest.mark.level0 | @pytest.mark.level0 | ||||
@@ -127,8 +120,7 @@ def test_scale(): | |||||
@pytest.mark.env_onecard | @pytest.mark.env_onecard | ||||
@pytest.mark.component_mindarmour | @pytest.mark.component_mindarmour | ||||
def test_rotate(): | def test_rotate(): | ||||
image = (np.random.rand(32, 32)*255).astype(np.float32) | |||||
mode = 'L' | |||||
trans = Rotate(image, mode) | |||||
trans.random_param() | |||||
_ = trans.transform() | |||||
image = (np.random.rand(32, 32)).astype(np.float32) | |||||
trans = Rotate() | |||||
trans.set_params(auto_param=True) | |||||
_ = trans.transform(image) |