@@ -19,7 +19,6 @@ example/mnist_demo/model/ | |||||
example/cifar_demo/model/ | example/cifar_demo/model/ | ||||
example/dog_cat_demo/model/ | example/dog_cat_demo/model/ | ||||
mindarmour.egg-info/ | mindarmour.egg-info/ | ||||
*model/ | |||||
*MNIST/ | *MNIST/ | ||||
*out.data/ | *out.data/ | ||||
*defensed_model/ | *defensed_model/ | ||||
@@ -27,7 +27,7 @@ from mindspore.nn.optim.momentum import Momentum | |||||
from mindarmour.adv_robustness.defenses import AdversarialDefense | from mindarmour.adv_robustness.defenses import AdversarialDefense | ||||
from mindarmour.fuzz_testing import Fuzzer | from mindarmour.fuzz_testing import Fuzzer | ||||
from mindarmour.fuzz_testing import ModelCoverageMetrics | |||||
from mindarmour.fuzz_testing import KMultisectionNeuronCoverage | |||||
from mindarmour.utils.logger import LogUtil | from mindarmour.utils.logger import LogUtil | ||||
from examples.common.dataset.data_processing import generate_mnist_dataset | from examples.common.dataset.data_processing import generate_mnist_dataset | ||||
@@ -38,33 +38,66 @@ TAG = 'Fuzz_testing and enhance model' | |||||
LOGGER.set_level('INFO') | LOGGER.set_level('INFO') | ||||
def split_dataset(image, label, proportion): | |||||
""" | |||||
Split the generated fuzz data into train and test set. | |||||
""" | |||||
indices = np.arange(len(image)) | |||||
random.shuffle(indices) | |||||
train_length = int(len(image) * proportion) | |||||
train_image = [image[i] for i in indices[:train_length]] | |||||
train_label = [label[i] for i in indices[:train_length]] | |||||
test_image = [image[i] for i in indices[:train_length]] | |||||
test_label = [label[i] for i in indices[:train_length]] | |||||
return train_image, train_label, test_image, test_label | |||||
def example_lenet_mnist_fuzzing(): | def example_lenet_mnist_fuzzing(): | ||||
""" | """ | ||||
An example of fuzz testing and then enhance the non-robustness model. | An example of fuzz testing and then enhance the non-robustness model. | ||||
""" | """ | ||||
# upload trained network | # upload trained network | ||||
ckpt_path = '../common/networks/lenet5/trained_ckpt_file/lenet_m1-10_1250.ckpt' | |||||
ckpt_path = '../common/networks/lenet5/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' | |||||
net = LeNet5() | net = LeNet5() | ||||
load_dict = load_checkpoint(ckpt_path) | load_dict = load_checkpoint(ckpt_path) | ||||
load_param_into_net(net, load_dict) | load_param_into_net(net, load_dict) | ||||
model = Model(net) | model = Model(net) | ||||
mutate_config = [{'method': 'Blur', | |||||
'params': {'auto_param': [True]}}, | |||||
{'method': 'Contrast', | |||||
'params': {'auto_param': [True]}}, | |||||
{'method': 'Translate', | |||||
'params': {'auto_param': [True]}}, | |||||
{'method': 'Brightness', | |||||
'params': {'auto_param': [True]}}, | |||||
{'method': 'Noise', | |||||
'params': {'auto_param': [True]}}, | |||||
{'method': 'Scale', | |||||
'params': {'auto_param': [True]}}, | |||||
{'method': 'Shear', | |||||
'params': {'auto_param': [True]}}, | |||||
{'method': 'FGSM', | |||||
'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1]}} | |||||
] | |||||
mutate_config = [ | |||||
{'method': 'GaussianBlur', | |||||
'params': {'ksize': [1, 2, 3, 5], 'auto_param': [True, False]}}, | |||||
{'method': 'MotionBlur', | |||||
'params': {'degree': [1, 2, 5], 'angle': [45, 10, 100, 140, 210, 270, 300], 'auto_param': [True]}}, | |||||
{'method': 'GradientBlur', | |||||
'params': {'point': [[10, 10]], 'auto_param': [True]}}, | |||||
{'method': 'UniformNoise', | |||||
'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}}, | |||||
{'method': 'GaussianNoise', | |||||
'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}}, | |||||
{'method': 'SaltAndPepperNoise', | |||||
'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}}, | |||||
{'method': 'NaturalNoise', | |||||
'params': {'ratio': [0.1], 'k_x_range': [(1, 3), (1, 5)], 'k_y_range': [(1, 5)], 'auto_param': [False, True]}}, | |||||
{'method': 'Contrast', | |||||
'params': {'alpha': [0.5, 1, 1.5], 'beta': [-10, 0, 10], 'auto_param': [False, True]}}, | |||||
{'method': 'GradientLuminance', | |||||
'params': {'color_start': [(0, 0, 0)], 'color_end': [(255, 255, 255)], 'start_point': [(10, 10)], | |||||
'scope': [0.5], 'pattern': ['light'], 'bright_rate': [0.3], 'mode': ['circle'], | |||||
'auto_param': [False, True]}}, | |||||
{'method': 'Translate', | |||||
'params': {'x_bias': [0, 0.05, -0.05], 'y_bias': [0, -0.05, 0.05], 'auto_param': [False, True]}}, | |||||
{'method': 'Scale', | |||||
'params': {'factor_x': [1, 0.9], 'factor_y': [1, 0.9], 'auto_param': [False, True]}}, | |||||
{'method': 'Shear', | |||||
'params': {'factor': [0.2, 0.1], 'direction': ['horizontal', 'vertical'], 'auto_param': [False, True]}}, | |||||
{'method': 'Rotate', | |||||
'params': {'angle': [20, 90], 'auto_param': [False, True]}}, | |||||
{'method': 'Perspective', | |||||
'params': {'ori_pos': [[[0, 0], [0, 800], [800, 0], [800, 800]]], | |||||
'dst_pos': [[[50, 0], [0, 800], [780, 0], [800, 800]]], 'auto_param': [False, True]}}, | |||||
{'method': 'Curve', | |||||
'params': {'curves': [5], 'depth': [2], 'mode': ['vertical'], 'auto_param': [False, True]}}, | |||||
{'method': 'FGSM', | |||||
'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1], 'bounds': [(0, 1)]}}] | |||||
# get training data | # get training data | ||||
data_list = "../common/dataset/MNIST/train" | data_list = "../common/dataset/MNIST/train" | ||||
@@ -75,49 +108,36 @@ def example_lenet_mnist_fuzzing(): | |||||
images = data[0].astype(np.float32) | images = data[0].astype(np.float32) | ||||
train_images.append(images) | train_images.append(images) | ||||
train_images = np.concatenate(train_images, axis=0) | train_images = np.concatenate(train_images, axis=0) | ||||
neuron_num = 10 | |||||
segmented_num = 1000 | |||||
# initialize fuzz test with training dataset | |||||
model_coverage_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images) | |||||
segmented_num = 100 | |||||
# fuzz test with original test data | # fuzz test with original test data | ||||
# get test data | |||||
data_list = "../common/dataset/MNIST/test" | data_list = "../common/dataset/MNIST/test" | ||||
batch_size = 32 | |||||
init_samples = 5000 | |||||
max_iters = 50000 | |||||
batch_size = batch_size | |||||
init_samples = 50 | |||||
max_iters = 500 | |||||
mutate_num_per_seed = 10 | mutate_num_per_seed = 10 | ||||
ds = generate_mnist_dataset(data_list, batch_size, num_samples=init_samples, | |||||
sparse=False) | |||||
ds = generate_mnist_dataset(data_list, batch_size=batch_size, num_samples=init_samples, sparse=False) | |||||
test_images = [] | test_images = [] | ||||
test_labels = [] | test_labels = [] | ||||
for data in ds.create_tuple_iterator(output_numpy=True): | for data in ds.create_tuple_iterator(output_numpy=True): | ||||
images = data[0].astype(np.float32) | |||||
labels = data[1] | |||||
test_images.append(images) | |||||
test_labels.append(labels) | |||||
test_images.append(data[0].astype(np.float32)) | |||||
test_labels.append(data[1]) | |||||
test_images = np.concatenate(test_images, axis=0) | test_images = np.concatenate(test_images, axis=0) | ||||
test_labels = np.concatenate(test_labels, axis=0) | test_labels = np.concatenate(test_labels, axis=0) | ||||
initial_seeds = [] | |||||
coverage = KMultisectionNeuronCoverage(model, train_images, segmented_num=segmented_num, incremental=True) | |||||
kmnc = coverage.get_metrics(test_images[:100]) | |||||
print('kmnc: ', kmnc) | |||||
# make initial seeds | # make initial seeds | ||||
initial_seeds = [] | |||||
for img, label in zip(test_images, test_labels): | for img, label in zip(test_images, test_labels): | ||||
initial_seeds.append([img, label]) | initial_seeds.append([img, label]) | ||||
model_coverage_test.calculate_coverage( | |||||
np.array(test_images[:100]).astype(np.float32)) | |||||
LOGGER.info(TAG, 'KMNC of test dataset before fuzzing is : %s', | |||||
model_coverage_test.get_kmnc()) | |||||
LOGGER.info(TAG, 'NBC of test dataset before fuzzing is : %s', | |||||
model_coverage_test.get_nbc()) | |||||
LOGGER.info(TAG, 'SNAC of test dataset before fuzzing is : %s', | |||||
model_coverage_test.get_snac()) | |||||
model_fuzz_test = Fuzzer(model, train_images, 10, 1000) | |||||
model_fuzz_test = Fuzzer(model) | |||||
gen_samples, gt, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, | gen_samples, gt, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, | ||||
initial_seeds, | |||||
eval_metrics='auto', | |||||
initial_seeds, coverage, | |||||
evaluate=True, | |||||
max_iters=max_iters, | max_iters=max_iters, | ||||
mutate_num_per_seed=mutate_num_per_seed) | mutate_num_per_seed=mutate_num_per_seed) | ||||
@@ -125,24 +145,10 @@ def example_lenet_mnist_fuzzing(): | |||||
for key in metrics: | for key in metrics: | ||||
LOGGER.info(TAG, key + ': %s', metrics[key]) | LOGGER.info(TAG, key + ': %s', metrics[key]) | ||||
def split_dataset(image, label, proportion): | |||||
""" | |||||
Split the generated fuzz data into train and test set. | |||||
""" | |||||
indices = np.arange(len(image)) | |||||
random.shuffle(indices) | |||||
train_length = int(len(image) * proportion) | |||||
train_image = [image[i] for i in indices[:train_length]] | |||||
train_label = [label[i] for i in indices[:train_length]] | |||||
test_image = [image[i] for i in indices[:train_length]] | |||||
test_label = [label[i] for i in indices[:train_length]] | |||||
return train_image, train_label, test_image, test_label | |||||
train_image, train_label, test_image, test_label = split_dataset( | |||||
gen_samples, gt, 0.7) | |||||
train_image, train_label, test_image, test_label = split_dataset(gen_samples, gt, 0.7) | |||||
# load model B and test it on the test set | # load model B and test it on the test set | ||||
ckpt_path = '../common/networks/lenet5/trained_ckpt_file/lenet_m2-10_1250.ckpt' | |||||
ckpt_path = '../common/networks/lenet5/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' | |||||
net = LeNet5() | net = LeNet5() | ||||
load_dict = load_checkpoint(ckpt_path) | load_dict = load_checkpoint(ckpt_path) | ||||
load_param_into_net(net, load_dict) | load_param_into_net(net, load_dict) | ||||
@@ -154,12 +160,11 @@ def example_lenet_mnist_fuzzing(): | |||||
# enhense model robustness | # enhense model robustness | ||||
lr = 0.001 | lr = 0.001 | ||||
momentum = 0.9 | momentum = 0.9 | ||||
loss_fn = SoftmaxCrossEntropyWithLogits(Sparse=True) | |||||
loss_fn = SoftmaxCrossEntropyWithLogits(sparse=True) | |||||
optimizer = Momentum(net.trainable_params(), lr, momentum) | optimizer = Momentum(net.trainable_params(), lr, momentum) | ||||
adv_defense = AdversarialDefense(net, loss_fn, optimizer) | adv_defense = AdversarialDefense(net, loss_fn, optimizer) | ||||
adv_defense.batch_defense(np.array(train_image).astype(np.float32), | |||||
np.argmax(train_label, axis=1).astype(np.int32)) | |||||
adv_defense.batch_defense(np.array(train_image).astype(np.float32), np.argmax(train_label, axis=1).astype(np.int32)) | |||||
preds_en = net(Tensor(test_image, dtype=mindspore.float32)).asnumpy() | preds_en = net(Tensor(test_image, dtype=mindspore.float32)).asnumpy() | ||||
acc_en = np.sum(np.argmax(preds_en, axis=1) == np.argmax(test_label, axis=1)) / len(test_label) | acc_en = np.sum(np.argmax(preds_en, axis=1) == np.argmax(test_label, axis=1)) / len(test_label) | ||||
print('Accuracy of enhensed model on test set is ', acc_en) | print('Accuracy of enhensed model on test set is ', acc_en) | ||||
@@ -167,5 +172,5 @@ def example_lenet_mnist_fuzzing(): | |||||
if __name__ == '__main__': | if __name__ == '__main__': | ||||
# device_target can be "CPU", "GPU" or "Ascend" | # device_target can be "CPU", "GPU" or "Ascend" | ||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||||
example_lenet_mnist_fuzzing() | example_lenet_mnist_fuzzing() |
@@ -35,24 +35,50 @@ def test_lenet_mnist_fuzzing(): | |||||
load_dict = load_checkpoint(ckpt_path) | load_dict = load_checkpoint(ckpt_path) | ||||
load_param_into_net(net, load_dict) | load_param_into_net(net, load_dict) | ||||
model = Model(net) | model = Model(net) | ||||
mutate_config = [{'method': 'Blur', | |||||
'params': {'radius': [0.1, 0.2, 0.3], | |||||
'auto_param': [True, False]}}, | |||||
{'method': 'Contrast', | |||||
'params': {'auto_param': [True]}}, | |||||
{'method': 'Translate', | |||||
'params': {'auto_param': [True]}}, | |||||
{'method': 'Brightness', | |||||
'params': {'auto_param': [True]}}, | |||||
{'method': 'Noise', | |||||
'params': {'auto_param': [True]}}, | |||||
{'method': 'Scale', | |||||
'params': {'auto_param': [True]}}, | |||||
{'method': 'Shear', | |||||
'params': {'auto_param': [True]}}, | |||||
{'method': 'FGSM', | |||||
'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1], 'bounds': [(0, 1)]}} | |||||
] | |||||
mutate_config = [ | |||||
{'method': 'GaussianBlur', | |||||
'params': {'ksize': [1, 2, 3, 5], | |||||
'auto_param': [True, False]}}, | |||||
{'method': 'MotionBlur', | |||||
'params': {'degree': [1, 2, 5], 'angle': [45, 10, 100, 140, 210, 270, 300], 'auto_param': [True]}}, | |||||
{'method': 'GradientBlur', | |||||
'params': {'point': [[10, 10]], 'auto_param': [True]}}, | |||||
{'method': 'UniformNoise', | |||||
'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}}, | |||||
{'method': 'GaussianNoise', | |||||
'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}}, | |||||
{'method': 'SaltAndPepperNoise', | |||||
'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}}, | |||||
{'method': 'NaturalNoise', | |||||
'params': {'ratio': [0.1, 0.2, 0.3], 'k_x_range': [(1, 3), (1, 5)], 'k_y_range': [(1, 5)], | |||||
'auto_param': [False, True]}}, | |||||
{'method': 'Contrast', | |||||
'params': {'alpha': [0.5, 1, 1.5], 'beta': [-10, 0, 10], 'auto_param': [False, True]}}, | |||||
{'method': 'GradientLuminance', | |||||
'params': {'color_start': [(0, 0, 0)], 'color_end': [(255, 255, 255)], 'start_point': [(10, 10)], | |||||
'scope': [0.5], 'pattern': ['light'], 'bright_rate': [0.3], 'mode': ['circle'], | |||||
'auto_param': [False, True]}}, | |||||
{'method': 'Translate', | |||||
'params': {'x_bias': [0, 0.05, -0.05], 'y_bias': [0, -0.05, 0.05], 'auto_param': [False, True]}}, | |||||
{'method': 'Scale', | |||||
'params': {'factor_x': [1, 0.9], 'factor_y': [1, 0.9], 'auto_param': [False, True]}}, | |||||
{'method': 'Shear', | |||||
'params': {'factor': [0.2, 0.1], 'direction': ['horizontal', 'vertical'], 'auto_param': [False, True]}}, | |||||
{'method': 'Rotate', | |||||
'params': {'angle': [20, 90], 'auto_param': [False, True]}}, | |||||
{'method': 'Perspective', | |||||
'params': {'ori_pos': [[[0, 0], [0, 800], [800, 0], [800, 800]]], | |||||
'dst_pos': [[[50, 0], [0, 800], [780, 0], [800, 800]]], 'auto_param': [False, True]}}, | |||||
{'method': 'Curve', | |||||
'params': {'curves': [5], 'depth': [2], 'mode': ['vertical'], 'auto_param': [False, True]}}, | |||||
{'method': 'FGSM', | |||||
'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1], 'bounds': [(0, 1)]}}, | |||||
{'method': 'PGD', | |||||
'params': {'eps': [0.1, 0.2, 0.4], 'eps_iter': [0.05, 0.1], 'nb_iter': [1, 3]}}, | |||||
{'method': 'MDIIM', | |||||
'params': {'eps': [0.1, 0.2, 0.4], 'prob': [0.5, 0.1], | |||||
'norm_level': [1, 2, '1', '2', 'l1', 'l2', 'inf', 'np.inf', 'linf']}} | |||||
] | |||||
# get training data | # get training data | ||||
data_list = "../common/dataset/MNIST/train" | data_list = "../common/dataset/MNIST/train" | ||||
@@ -88,7 +114,10 @@ def test_lenet_mnist_fuzzing(): | |||||
print('KMNC of initial seeds is: ', kmnc) | print('KMNC of initial seeds is: ', kmnc) | ||||
initial_seeds = initial_seeds[:100] | initial_seeds = initial_seeds[:100] | ||||
model_fuzz_test = Fuzzer(model) | model_fuzz_test = Fuzzer(model) | ||||
_, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds, coverage, evaluate=True, max_iters=10, | |||||
_, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, | |||||
initial_seeds, coverage, | |||||
evaluate=True, | |||||
max_iters=10, | |||||
mutate_num_per_seed=20) | mutate_num_per_seed=20) | ||||
if metrics: | if metrics: | ||||
@@ -0,0 +1,176 @@ | |||||
# Copyright 2021 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
"""Example for natural robustness methods.""" | |||||
import numpy as np | |||||
import cv2 | |||||
from mindarmour.natural_robustness import Translate, Curve, Perspective, Scale, Shear, Rotate, SaltAndPepperNoise, \ | |||||
NaturalNoise, GaussianNoise, UniformNoise, MotionBlur, GaussianBlur, GradientBlur, Contrast, GradientLuminance | |||||
def test_perspective(image): | |||||
"""Test perspective.""" | |||||
ori_pos = [[0, 0], [0, 800], [800, 0], [800, 800]] | |||||
dst_pos = [[50, 0], [0, 800], [780, 0], [800, 800]] | |||||
trans = Perspective(ori_pos, dst_pos) | |||||
dst = trans(image) | |||||
cv2.imshow('dst', dst) | |||||
cv2.waitKey() | |||||
def test_uniform_noise(image): | |||||
"""Test uniform noise.""" | |||||
trans = UniformNoise(factor=0.1) | |||||
dst = trans(image) | |||||
cv2.imshow('dst', dst) | |||||
cv2.waitKey() | |||||
def test_gaussian_noise(image): | |||||
"""Test gaussian noise.""" | |||||
trans = GaussianNoise(factor=0.1) | |||||
dst = trans(image) | |||||
cv2.imshow('dst', dst) | |||||
cv2.waitKey() | |||||
def test_contrast(image): | |||||
"""Test contrast.""" | |||||
trans = Contrast(alpha=0.3, beta=0) | |||||
dst = trans(image) | |||||
cv2.imshow('dst', dst) | |||||
cv2.waitKey() | |||||
def test_gaussian_blur(image): | |||||
"""Test gaussian blur.""" | |||||
trans = GaussianBlur(ksize=5) | |||||
dst = trans(image) | |||||
cv2.imshow('dst', dst) | |||||
cv2.waitKey() | |||||
def test_salt_and_pepper_noise(image): | |||||
"""Test salt and pepper noise.""" | |||||
trans = SaltAndPepperNoise(factor=0.01) | |||||
dst = trans(image) | |||||
cv2.imshow('dst', dst) | |||||
cv2.waitKey() | |||||
def test_translate(image): | |||||
"""Test translate.""" | |||||
trans = Translate(x_bias=0.1, y_bias=0.1) | |||||
dst = trans(image) | |||||
cv2.imshow('dst', dst) | |||||
cv2.waitKey() | |||||
def test_scale(image): | |||||
"""Test scale.""" | |||||
trans = Scale(factor_x=0.7, factor_y=0.7) | |||||
dst = trans(image) | |||||
cv2.imshow('dst', dst) | |||||
cv2.waitKey() | |||||
def test_shear(image): | |||||
"""Test shear.""" | |||||
trans = Shear(factor=0.2) | |||||
dst = trans(image) | |||||
cv2.imshow('dst', dst) | |||||
cv2.waitKey() | |||||
def test_rotate(image): | |||||
"""Test rotate.""" | |||||
trans = Rotate(angle=20) | |||||
dst = trans(image) | |||||
cv2.imshow('dst', dst) | |||||
cv2.waitKey() | |||||
def test_curve(image): | |||||
"""Test curve.""" | |||||
trans = Curve(curves=1.5, depth=1.5, mode='horizontal') | |||||
dst = trans(image) | |||||
cv2.imshow('dst', dst) | |||||
cv2.waitKey() | |||||
def test_natural_noise(image): | |||||
"""Test natural noise.""" | |||||
trans = NaturalNoise(ratio=0.0001, k_x_range=(1, 30), k_y_range=(1, 10), auto_param=True) | |||||
dst = trans(image) | |||||
cv2.imshow('dst', dst) | |||||
cv2.waitKey() | |||||
def test_gradient_luminance(image): | |||||
"""Test gradient luminance.""" | |||||
height, width = image.shape[:2] | |||||
point = (height // 4, width // 2) | |||||
start = (255, 255, 255) | |||||
end = (0, 0, 0) | |||||
scope = 0.3 | |||||
bright_rate = 0.4 | |||||
trans = GradientLuminance(start, end, start_point=point, scope=scope, pattern='dark', bright_rate=bright_rate, | |||||
mode='horizontal') | |||||
dst = trans(image) | |||||
cv2.imshow('dst', dst) | |||||
cv2.waitKey() | |||||
def test_motion_blur(image): | |||||
"""Test motion blur.""" | |||||
angle = -10.5 | |||||
i = 3 | |||||
trans = MotionBlur(degree=i, angle=angle) | |||||
dst = trans(image) | |||||
cv2.imshow('dst', dst) | |||||
cv2.waitKey() | |||||
def test_gradient_blur(image): | |||||
"""Test gradient blur.""" | |||||
number = 10 | |||||
h, w = image.shape[:2] | |||||
point = (int(h / 5), int(w / 5)) | |||||
center = False | |||||
trans = GradientBlur(point, number, center) | |||||
dst = trans(image) | |||||
cv2.imshow('dst', dst) | |||||
cv2.waitKey() | |||||
if __name__ == '__main__': | |||||
img = cv2.imread('1.jpeg') | |||||
img = np.array(img) | |||||
test_uniform_noise(img) | |||||
test_gaussian_noise(img) | |||||
test_motion_blur(img) | |||||
test_gradient_blur(img) | |||||
test_gradient_luminance(img) | |||||
test_natural_noise(img) | |||||
test_curve(img) | |||||
test_rotate(img) | |||||
test_shear(img) | |||||
test_scale(img) | |||||
test_translate(img) | |||||
test_salt_and_pepper_noise(img) | |||||
test_gaussian_blur(img) | |||||
test_constract(img) | |||||
test_perspective(img) |
@@ -0,0 +1,206 @@ | |||||
# 自然扰动样本生成serving | |||||
提供自然扰动样本生成在线服务。客户端传入图片和扰动参数,服务端返回扰动后的图片数据。 | |||||
## 环境准备 | |||||
硬件环境:Ascend 910,GPU | |||||
操作系统:Linux-x86_64 | |||||
软件环境: | |||||
1. python 3.7.5或python 3.9.0 | |||||
2. 安装MindSpore 1.5.0可以参考[MindSpore安装页面](https://www.mindspore.cn/install) | |||||
3. 安装MindSpore Serving 1.5.0可以参考[MindSpore Serving 安装页面](https://www.mindspore.cn/serving/docs/zh-CN/r1.5/serving_install.html) | |||||
4. 安装serving分支的MindArmour: | |||||
- 从Gitee下载源码 | |||||
`git clone https://gitee.com/mindspore/mindarmour.git` | |||||
- 编译并安装MindArmour | |||||
`python setup.py install` | |||||
### 文件结构说明 | |||||
```bash | |||||
serving | |||||
├── server | |||||
│ ├── serving_server.py # 启动serving服务脚本 | |||||
│ ├── export_model | |||||
│ │ └── add_model.py # 生成模型文件脚本 | |||||
│ └── perturbation | |||||
│ └── serverable_config.py # 服务端接收客户端数据后的处理脚本 | |||||
└── client | |||||
├── serving_client.py # 启动客户端脚本 | |||||
└── perturb_config.py # 扰动方法配置文件 | |||||
``` | |||||
## 脚本说明及使用 | |||||
### 导出模型 | |||||
在`server/export_model`目录下,使用[add_model.py](https://gitee.com/mindspore/serving/blob/r1.5/example/tensor_add/export_model/add_model.py),构造了一个只有Add算子的tensor加法网络。使用命令 | |||||
```bash | |||||
python add_model.py | |||||
``` | |||||
在`perturbation`模型文件夹下生成`tensor_add.mindir`模型文件。 | |||||
该服务实际上并没有使用到模型,但目前版本的serving需要有一个模型,serving升级后这部分会删除。 | |||||
### 部署Serving推理服务 | |||||
1. #### `servable_config.py`说明。 | |||||
```python | |||||
··· | |||||
# 客户端可以请求的方法,包含4个返回值:"results", "file_names", "file_length", "names_dict" | |||||
@register.register_method(output_names=["results", "file_names", "file_length", "names_dict"]) | |||||
def natural_perturbation(img, perturb_config, methods_number, outputs_number): | |||||
"""method natural_perturbation data flow definition, only preprocessing and call model""" | |||||
res = register.add_stage(perturb, img, perturb_config, methods_number, outputs_number, outputs_count=4) | |||||
return res | |||||
``` | |||||
方法`natural_perturbation`为对外提供服务的接口。 | |||||
**输入:** | |||||
- img:输入为图片,格式为bytes。 | |||||
- perturb_config:扰动配置项,具体配置参考`perturb_config.py`。 | |||||
- methods_number:每次扰动随机从配置项中选择方法的个数。 | |||||
- outputs_number:对于每张图片,生成的扰动图片数量。 | |||||
**输出**res中包含4个参数: | |||||
- results:拼接后的图像bytes; | |||||
- file_names:图像名,格式为`xxx.png`,其中‘xxx’为A-Za-z中随机选择20个字符构成的字符串。 | |||||
- file_length:每张图片的bytes长度。 | |||||
- names_dict: 图片名和图片使用扰动方法构成的字典。格式为: | |||||
```bash | |||||
{ | |||||
picture1.png: [[method1, parameters of method1], [method2, parameters of method2], ...]], | |||||
picture2.png: [[method3, parameters of method3], [method4, parameters of method4], ...]], | |||||
... | |||||
} | |||||
``` | |||||
2. #### 启动server。 | |||||
```python | |||||
··· | |||||
def start(): | |||||
servable_dir = os.path.dirname(os.path.realpath(sys.argv[0])) | |||||
# 服务配置 | |||||
servable_config = server.ServableStartConfig(servable_directory=servable_dir, servable_name="perturbation", device_ids=(0, 1), num_parallel_workers=4) | |||||
# 启动服务 | |||||
server.start_servables(servable_configs=servable_config) | |||||
# 启动启动gRPC服务,用于客户端和服务端之间通信 | |||||
server.start_grpc_server(address="0.0.0.0:5500", max_msg_mb_size=200) # ip和最大的传输数据量,单位MB | |||||
# 启动启动Restful服务,用于客户端和服务端之间通信 | |||||
server.start_restful_server(address="0.0.0.0:5500") | |||||
``` | |||||
gRPC传输性能更好,Restful更适合用于web服务,根据需要选择。 | |||||
执行命令`python serverong_server.py`启动服务。 | |||||
当服务端打印日志`Serving RESTful server start success, listening on 0.0.0.0:5500`时,表示Serving RESTful服务启动成功,推理模型已成功加载。 | |||||
### 客户端进行推理 | |||||
1. 在`perturb_config.py`中设置扰动方法及参数。下面是个例子: | |||||
```python | |||||
PerturbConfig = [{"method": "Contrast", "params": {"alpha": 1.5, "beta": 0}}, | |||||
{"method": "GaussianBlur", "params": {"ksize": 5}}, | |||||
{"method": "SaltAndPepperNoise", "params": {"factor": 0.05}}, | |||||
{"method": "Translate", "params": {"x_bias": 0.1, "y_bias": -0.2}}, | |||||
{"method": "Scale", "params": {"factor_x": 0.7, "factor_y": 0.7}}, | |||||
{"method": "Shear", "params": {"factor": 2, "director": "horizontal"}}, | |||||
{"method": "Rotate", "params": {"angle": 40}}, | |||||
{"method": "MotionBlur", "params": {"degree": 5, "angle": 45}}, | |||||
{"method": "GradientBlur", "params": {"point": [50, 100], "kernel_num": 3, "center": True}}, | |||||
{"method": "GradientLuminance", | |||||
"params": {"color_start": [255, 255, 255], | |||||
"color_end": [0, 0, 0], | |||||
"start_point": [100, 150], "scope": 0.3, | |||||
"bright_rate": 0.3, "pattern": "light", | |||||
"mode": "circle"}}, | |||||
{"method": "Curve", "params": {"curves": 5, "depth": 10, | |||||
"mode": "vertical"}}, | |||||
{"method": "Perspective", | |||||
"params": {"ori_pos": [[0, 0], [0, 800], [800, 0], [800, 800]], | |||||
"dst_pos": [[50, 0], [0, 800], [780, 0], [800, 800]]}}, | |||||
] | |||||
``` | |||||
其中`method`为扰动方法名,`params`为对应方法的参数。可用的扰动方法及对应参数可在`mindarmour/natural_robustness/natural_noise.py`中查询。 | |||||
2. 在`serving_client.py`中写客户端的处理脚本,包含输入输出的处理、服务端的调用,可以参考下面的例子。 | |||||
```python | |||||
··· | |||||
def perturb(perturb_config): | |||||
"""invoke servable perturbation method natural_perturbation""" | |||||
# 请求的服务端ip及端口、请求的服务名、请求的方法名 | |||||
client = Client("10.175.122.87:5500", "perturbation", "natural_perturbation") | |||||
# 输入数据 | |||||
instances = [] | |||||
img_path = '/root/mindarmour/example/adversarial/test_data/1.png' | |||||
result_path = '/root/mindarmour/example/adv/result/' | |||||
methods_number = 2 | |||||
outputs_number = 3 | |||||
img = cv2.imread(img_path) | |||||
img = cv2.imencode('.png', img)[1].tobytes() # 图片传输用bytes格式,不支持numpy.ndarray格式 | |||||
perturb_config = json.dumps(perturb_config) # 配置方法转成json格式 | |||||
instances.append({"img": img, 'perturb_config': perturb_config, "methods_number": methods_number, | |||||
"outputs_number": outputs_number}) # instances中可添加多个输入 | |||||
# 请求服务,返回结果 | |||||
result = client.infer(instances) | |||||
# 对服务请求得到的结果进行处理,将返回的图片字节流存成图片 | |||||
file_names = result[0]['file_names'].split(';') | |||||
length = result[0]['file_length'].tolist() | |||||
before = 0 | |||||
for name, leng in zip(file_names, length): | |||||
res_img = result[0]['results'] | |||||
res_img = res_img[before:before + leng] | |||||
before = before + leng | |||||
print('name: ', name) | |||||
image = Image.open(BytesIO(res_img)) | |||||
image.save(os.path.join(result_path, name)) | |||||
names_dict = result[0]['names_dict'] | |||||
with open('names_dict.json', 'w') as file: | |||||
file.write(names_dict) | |||||
``` | |||||
启动client前,需将服务端的IP地址改成部署server的IP地址,图片路径、结果存储路基替换成用户数据路径。 | |||||
目前serving数据传输支持的数据类型包括:python的int、float、bool、str、bytes,numpy number, numpy array object。 | |||||
输入命令`python serving_client.py`开启客户端,如果对应目录下生成扰动样本图片则说明serving服务正确执行。 | |||||
### 其他 | |||||
在`serving_logs`目录下可以查看运行日志,辅助debug。 |
@@ -0,0 +1,41 @@ | |||||
# Copyright 2021 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
""" | |||||
Configuration of natural robustness methods for server. | |||||
""" | |||||
perturb_configs = [{"method": "Contrast", "params": {"alpha": 1.5, "beta": 0}}, | |||||
{"method": "GaussianBlur", "params": {"ksize": 5}}, | |||||
{"method": "SaltAndPepperNoise", "params": {"factor": 0.05}}, | |||||
{"method": "Translate", "params": {"x_bias": 0.1, "y_bias": -0.2}}, | |||||
{"method": "Scale", "params": {"factor_x": 0.7, "factor_y": 0.7}}, | |||||
{"method": "Shear", "params": {"factor": 2, "direction": "horizontal"}}, | |||||
{"method": "Rotate", "params": {"angle": 40}}, | |||||
{"method": "MotionBlur", "params": {"degree": 5, "angle": 45}}, | |||||
{"method": "GradientBlur", "params": {"point": [50, 100], "kernel_num": 3, "center": True}}, | |||||
{"method": "GradientLuminance", "params": {"color_start": [255, 255, 255], "color_end": [0, 0, 0], | |||||
"start_point": [100, 150], "scope": 0.3, | |||||
"bright_rate": 0.3, "pattern": "light", | |||||
"mode": "circle"}}, | |||||
{"method": "GradientLuminance", "params": {"color_start": [255, 255, 255], | |||||
"color_end": [0, 0, 0], "start_point": [150, 200], | |||||
"scope": 0.3, "pattern": "light", "mode": "horizontal"}}, | |||||
{"method": "GradientLuminance", "params": {"color_start": [255, 255, 255], "color_end": [0, 0, 0], | |||||
"start_point": [150, 200], "scope": 0.3, | |||||
"pattern": "light", "mode": "vertical"}}, | |||||
{"method": "Curve", "params": {"curves": 10, "depth": 10, "mode": "vertical"}}, | |||||
{"method": "Perspective", "params": {"ori_pos": [[0, 0], [0, 800], [800, 0], [800, 800]], | |||||
"dst_pos": [[50, 0], [0, 800], [780, 0], [800, 800]]}}, | |||||
] |
@@ -0,0 +1,61 @@ | |||||
# Copyright 2021 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
"""The client of example add.""" | |||||
import os | |||||
import json | |||||
from io import BytesIO | |||||
import cv2 | |||||
from PIL import Image | |||||
from mindspore_serving.client import Client | |||||
from perturb_config import perturb_configs | |||||
def perturb(perturb_config): | |||||
"""Invoke servable perturbation method natural_perturbation""" | |||||
client = Client("0.0.0.0:5500", "perturbation", "natural_perturbation") | |||||
instances = [] | |||||
img_path = 'test_data/1.png' | |||||
result_path = 'result/' | |||||
if not os.path.exists(result_path): | |||||
os.mkdir(result_path) | |||||
methods_number = 2 | |||||
outputs_number = 10 | |||||
img = cv2.imread(img_path) | |||||
img = cv2.imencode('.png', img)[1].tobytes() | |||||
perturb_config = json.dumps(perturb_config) | |||||
instances.append({"img": img, 'perturb_config': perturb_config, "methods_number": methods_number, | |||||
"outputs_number": outputs_number}) | |||||
result = client.infer(instances) | |||||
file_names = result[0]['file_names'].split(';') | |||||
length = result[0]['file_length'].tolist() | |||||
before = 0 | |||||
for name, leng in zip(file_names, length): | |||||
res_img = result[0]['results'] | |||||
res_img = res_img[before:before + leng] | |||||
before = before + leng | |||||
print('name: ', name) | |||||
image = Image.open(BytesIO(res_img)) | |||||
image.save(os.path.join(result_path, name)) | |||||
names_dict = result[0]['names_dict'] | |||||
with open('names_dict.json', 'w') as file: | |||||
file.write(names_dict) | |||||
if __name__ == '__main__': | |||||
perturb(perturb_configs) |
@@ -0,0 +1,58 @@ | |||||
# Copyright 2021 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
"""add model generator""" | |||||
import os | |||||
from shutil import copyfile | |||||
import numpy as np | |||||
import mindspore.context as context | |||||
import mindspore.nn as nn | |||||
import mindspore.ops as ops | |||||
import mindspore as ms | |||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
class Net(nn.Cell): | |||||
"""Define Net of add""" | |||||
def __init__(self): | |||||
super(Net, self).__init__() | |||||
self.add = ops.Add() | |||||
def construct(self, x_, y_): | |||||
"""construct add net""" | |||||
return self.add(x_, y_) | |||||
def export_net(): | |||||
"""Export add net of 2x2 + 2x2, and copy output model `tensor_add.mindir` to directory ../add/1""" | |||||
x = np.ones([2, 2]).astype(np.float32) | |||||
y = np.ones([2, 2]).astype(np.float32) | |||||
add = Net() | |||||
ms.export(add, ms.Tensor(x), ms.Tensor(y), file_name='tensor_add', file_format='MINDIR') | |||||
dst_dir = '../perturbation/1' | |||||
try: | |||||
os.mkdir(dst_dir) | |||||
except OSError: | |||||
pass | |||||
dst_file = os.path.join(dst_dir, 'tensor_add.mindir') | |||||
copyfile('tensor_add.mindir', dst_file) | |||||
print("copy tensor_add.mindir to " + dst_dir + " success") | |||||
if __name__ == "__main__": | |||||
export_net() |
@@ -0,0 +1,109 @@ | |||||
# Copyright 2021 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
"""perturbation servable config""" | |||||
import json | |||||
import copy | |||||
import random | |||||
from io import BytesIO | |||||
import cv2 | |||||
import numpy as np | |||||
from PIL import Image | |||||
from mindspore_serving.server import register | |||||
from mindarmour.natural_robustness import Contrast, GaussianBlur, SaltAndPepperNoise, Scale, Shear, \ | |||||
Translate, Rotate, MotionBlur, GradientBlur, GradientLuminance, NaturalNoise, Curve, Perspective | |||||
CHARACTERS = [chr(i) for i in range(65, 91)]+[chr(j) for j in range(97, 123)] | |||||
methods_dict = {'Contrast': Contrast, | |||||
'GaussianBlur': GaussianBlur, | |||||
'SaltAndPepperNoise': SaltAndPepperNoise, | |||||
'Translate': Translate, | |||||
'Scale': Scale, | |||||
'Shear': Shear, | |||||
'Rotate': Rotate, | |||||
'MotionBlur': MotionBlur, | |||||
'GradientBlur': GradientBlur, | |||||
'GradientLuminance': GradientLuminance, | |||||
'NaturalNoise': NaturalNoise, | |||||
'Curve': Curve, | |||||
'Perspective': Perspective} | |||||
def check_inputs(img, perturb_config, methods_number, outputs_number): | |||||
"""Check inputs.""" | |||||
if not np.any(img): | |||||
raise ValueError("img cannot be empty.") | |||||
img = Image.open(BytesIO(img)) | |||||
img = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR) | |||||
config = json.loads(perturb_config) | |||||
if not config: | |||||
raise ValueError("perturb_config cannot be empty.") | |||||
for item in config: | |||||
if item['method'] not in methods_dict.keys(): | |||||
raise ValueError("{} is not a valid method.".format(item['method'])) | |||||
methods_number = int(methods_number) | |||||
if methods_number < 1: | |||||
raise ValueError("methods_number must more than 0.") | |||||
outputs_number = int(outputs_number) | |||||
if outputs_number < 1: | |||||
raise ValueError("outputs_number must more than 0.") | |||||
return img, config, methods_number, outputs_number | |||||
def perturb(img, perturb_config, methods_number, outputs_number): | |||||
"""Perturb given image.""" | |||||
img, config, methods_number, outputs_number = check_inputs(img, perturb_config, methods_number, outputs_number) | |||||
res_img_bytes = b'' | |||||
file_names = [] | |||||
file_length = [] | |||||
names_dict = {} | |||||
for _ in range(outputs_number): | |||||
dst = copy.deepcopy(img) | |||||
used_methods = [] | |||||
for _ in range(methods_number): | |||||
item = np.random.choice(config) | |||||
method_name = item['method'] | |||||
method = methods_dict[method_name] | |||||
params = item['params'] | |||||
dst = method(**params)(img) | |||||
method_params = params | |||||
used_methods.append([method_name, method_params]) | |||||
name = ''.join(random.sample(CHARACTERS, 20)) | |||||
name += '.png' | |||||
file_names.append(name) | |||||
names_dict[name] = used_methods | |||||
res_img = cv2.imencode('.png', dst)[1].tobytes() | |||||
res_img_bytes += res_img | |||||
file_length.append(len(res_img)) | |||||
names_dict = json.dumps(names_dict) | |||||
return res_img_bytes, ';'.join(file_names), file_length, names_dict | |||||
model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) | |||||
@register.register_method(output_names=["results", "file_names", "file_length", "names_dict"]) | |||||
def natural_perturbation(img, perturb_config, methods_number, outputs_number): | |||||
"""method natural_perturbation data flow definition, only preprocessing and call model""" | |||||
res = register.add_stage(perturb, img, perturb_config, methods_number, outputs_number, outputs_count=4) | |||||
return res |
@@ -0,0 +1,35 @@ | |||||
# Copyright 2021 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
"""The server of example perturbation""" | |||||
import os | |||||
import sys | |||||
from mindspore_serving import server | |||||
def start(): | |||||
"""Start server.""" | |||||
servable_dir = os.path.dirname(os.path.realpath(sys.argv[0])) | |||||
servable_config = server.ServableStartConfig(servable_directory=servable_dir, servable_name="perturbation", | |||||
device_ids=(0, 1), num_parallel_workers=4) | |||||
server.start_servables(servable_configs=servable_config) | |||||
server.start_grpc_server(address="0.0.0.0:5500", max_msg_mb_size=200) | |||||
# server.start_restful_server(address="0.0.0.0:5500") | |||||
if __name__ == "__main__": | |||||
start() |
@@ -92,7 +92,7 @@ def _projection(values, eps, norm_level): | |||||
return proj_flat.reshape(values.shape) | return proj_flat.reshape(values.shape) | ||||
if norm_level in (2, '2'): | if norm_level in (2, '2'): | ||||
return eps*normalize_value(values, norm_level) | return eps*normalize_value(values, norm_level) | ||||
if norm_level in (np.inf, 'inf'): | |||||
if norm_level in (np.inf, 'inf', 'linf', 'np.inf'): | |||||
return eps*np.sign(values) | return eps*np.sign(values) | ||||
msg = 'Values of `norm_level` different from 1, 2 and `np.inf` are ' \ | msg = 'Values of `norm_level` different from 1, 2 and `np.inf` are ' \ | ||||
'currently not supported.' | 'currently not supported.' | ||||
@@ -277,7 +277,7 @@ class MomentumIterativeMethod(IterativeGradientMethod): | |||||
nb_iter (int): Number of iteration. Default: 5. | nb_iter (int): Number of iteration. Default: 5. | ||||
decay_factor (float): Decay factor in iterations. Default: 1.0. | decay_factor (float): Decay factor in iterations. Default: 1.0. | ||||
norm_level (Union[int, numpy.inf]): Order of the norm. Possible values: | norm_level (Union[int, numpy.inf]): Order of the norm. Possible values: | ||||
np.inf, 1 or 2. Default: 'inf'. | |||||
1, 2, '1', '2', 'l1', 'l2', 'inf', 'np.inf', np.inf and 'linf'. Default: 'inf'. | |||||
loss_fn (Loss): Loss function for optimization. If None, the input network \ | loss_fn (Loss): Loss function for optimization. If None, the input network \ | ||||
is already equipped with loss function. Default: None. | is already equipped with loss function. Default: None. | ||||
@@ -423,7 +423,7 @@ class ProjectedGradientDescent(BasicIterativeMethod): | |||||
attack. Default: False. | attack. Default: False. | ||||
nb_iter (int): Number of iteration. Default: 5. | nb_iter (int): Number of iteration. Default: 5. | ||||
norm_level (Union[int, numpy.inf]): Order of the norm. Possible values: | norm_level (Union[int, numpy.inf]): Order of the norm. Possible values: | ||||
np.inf, 1 or 2. Default: 'inf'. | |||||
1, 2, '1', '2', 'l1', 'l2', 'inf', 'np.inf', np.inf and 'linf'. Default: 'inf'. | |||||
loss_fn (Loss): Loss function for optimization. If None, the input network \ | loss_fn (Loss): Loss function for optimization. If None, the input network \ | ||||
is already equipped with loss function. Default: None. | is already equipped with loss function. Default: None. | ||||
@@ -577,7 +577,7 @@ class MomentumDiverseInputIterativeMethod(MomentumIterativeMethod): | |||||
is_targeted (bool): If True, targeted attack. If False, untargeted | is_targeted (bool): If True, targeted attack. If False, untargeted | ||||
attack. Default: False. | attack. Default: False. | ||||
norm_level (Union[int, numpy.inf]): Order of the norm. Possible values: | norm_level (Union[int, numpy.inf]): Order of the norm. Possible values: | ||||
np.inf, 1 or 2. Default: 'l1'. | |||||
1, 2, '1', '2', 'l1', 'l2', 'inf', 'np.inf', np.inf and 'linf'. Default: 'l1'. | |||||
prob (float): Transformation probability. Default: 0.5. | prob (float): Transformation probability. Default: 0.5. | ||||
loss_fn (Loss): Loss function for optimization. If None, the input network \ | loss_fn (Loss): Loss function for optimization. If None, the input network \ | ||||
is already equipped with loss function. Default: None. | is already equipped with loss function. Default: None. | ||||
@@ -24,10 +24,10 @@ from mindspore import nn | |||||
from mindarmour.utils._check_param import check_model, check_numpy_param, check_param_multi_types, check_norm_level, \ | from mindarmour.utils._check_param import check_model, check_numpy_param, check_param_multi_types, check_norm_level, \ | ||||
check_param_in_range, check_param_type, check_int_positive, check_param_bounds | check_param_in_range, check_param_type, check_int_positive, check_param_bounds | ||||
from mindarmour.utils.logger import LogUtil | from mindarmour.utils.logger import LogUtil | ||||
from ..adv_robustness.attacks import FastGradientSignMethod, \ | |||||
from mindarmour.adv_robustness.attacks import FastGradientSignMethod, \ | |||||
MomentumDiverseInputIterativeMethod, ProjectedGradientDescent | MomentumDiverseInputIterativeMethod, ProjectedGradientDescent | ||||
from .image_transform import Contrast, Brightness, Blur, \ | |||||
Noise, Translate, Scale, Shear, Rotate | |||||
from mindarmour.natural_robustness import GaussianBlur, MotionBlur, GradientBlur, UniformNoise, GaussianNoise, \ | |||||
SaltAndPepperNoise, NaturalNoise, Contrast, GradientLuminance, Translate, Scale, Shear, Rotate, Perspective, Curve | |||||
from .model_coverage_metrics import CoverageMetrics, KMultisectionNeuronCoverage | from .model_coverage_metrics import CoverageMetrics, KMultisectionNeuronCoverage | ||||
LOGGER = LogUtil.get_instance() | LOGGER = LogUtil.get_instance() | ||||
@@ -104,17 +104,79 @@ class Fuzzer: | |||||
target_model (Model): Target fuzz model. | target_model (Model): Target fuzz model. | ||||
Examples: | Examples: | ||||
>>> import numpy as np | |||||
>>> from mindspore import context | |||||
>>> from mindspore import nn | |||||
>>> from mindspore.common.initializer import TruncatedNormal | |||||
>>> from mindspore.ops import operations as P | |||||
>>> from mindspore.train import Model | |||||
>>> from mindspore.ops import TensorSummary | |||||
>>> from mindarmour.fuzz_testing import Fuzzer | |||||
>>> from mindarmour.fuzz_testing import KMultisectionNeuronCoverage | |||||
>>> | |||||
>>> class Net(nn.Cell): | |||||
>>> def __init__(self): | |||||
>>> super(Net, self).__init__() | |||||
>>> self.conv1 = nn.Conv2d(1, 6, 5, padding=0, weight_init=TruncatedNormal(0.02), pad_mode="valid") | |||||
>>> self.conv2 = nn.Conv2d(6, 16, 5, padding=0, weight_init=TruncatedNormal(0.02), pad_mode="valid") | |||||
>>> self.fc1 = nn.Dense(16 * 5 * 5, 120, TruncatedNormal(0.02), TruncatedNormal(0.02)) | |||||
>>> self.fc2 = nn.Dense(120, 84, TruncatedNormal(0.02), TruncatedNormal(0.02)) | |||||
>>> self.fc3 = nn.Dense(84, 10, TruncatedNormal(0.02), TruncatedNormal(0.02)) | |||||
>>> self.relu = nn.ReLU() | |||||
>>> self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) | |||||
>>> self.reshape = P.Reshape() | |||||
>>> self.summary = TensorSummary() | |||||
>>> | |||||
>>> def construct(self, x): | |||||
>>> x = self.conv1(x) | |||||
>>> x = self.relu(x) | |||||
>>> self.summary('conv1', x) | |||||
>>> x = self.max_pool2d(x) | |||||
>>> x = self.conv2(x) | |||||
>>> x = self.relu(x) | |||||
>>> self.summary('conv2', x) | |||||
>>> x = self.max_pool2d(x) | |||||
>>> x = self.reshape(x, (-1, 16 * 5 * 5)) | |||||
>>> x = self.fc1(x) | |||||
>>> x = self.relu(x) | |||||
>>> self.summary('fc1', x) | |||||
>>> x = self.fc2(x) | |||||
>>> x = self.relu(x) | |||||
>>> self.summary('fc2', x) | |||||
>>> x = self.fc3(x) | |||||
>>> self.summary('fc3', x) | |||||
>>> return x | |||||
>>> | |||||
>>> net = Net() | >>> net = Net() | ||||
>>> model = Model(net) | >>> model = Model(net) | ||||
>>> mutate_config = [{'method': 'Blur', | |||||
... 'params': {'auto_param': [True]}}, | |||||
>>> mutate_config = [{'method': 'GaussianBlur', | |||||
... 'params': {'ksize': [1, 2, 3, 5], 'auto_param': [True, False]}}, | |||||
... {'method': 'MotionBlur', | |||||
... 'params': {'degree': [1, 2, 5], 'angle': [45, 10, 100, 140, 210, 270, 300], | |||||
... 'auto_param': [True]}}, | |||||
... {'method': 'UniformNoise', | |||||
... 'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}}, | |||||
... {'method': 'GaussianNoise', | |||||
... 'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}}, | |||||
... {'method': 'Contrast', | ... {'method': 'Contrast', | ||||
... 'params': {'factor': [2]}}, | |||||
... {'method': 'Translate', | |||||
... 'params': {'x_bias': [0.1, 0.2], 'y_bias': [0.2]}}, | |||||
... 'params': {'alpha': [0.5, 1, 1.5], 'beta': [-10, 0, 10], 'auto_param': [False, True]}}, | |||||
... {'method': 'Rotate', | |||||
... 'params': {'angle': [20, 90], 'auto_param': [False, True]}}, | |||||
... {'method': 'FGSM', | ... {'method': 'FGSM', | ||||
... 'params': {'eps': [0.1, 0.2, 0.3], 'alpha': [0.1]}}] | |||||
>>> nc = KMultisectionNeuronCoverage(model, train_images, segmented_num=100) | |||||
... 'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1], 'bounds': [(0, 1)]}}] | |||||
>>> batch_size = 8 | |||||
>>> num_classe = 10 | |||||
>>> train_images = np.random.rand(32, 1, 32, 32).astype(np.float32) | |||||
>>> test_images = np.random.rand(batch_size, 1, 32, 32).astype(np.float32) | |||||
>>> test_labels = np.random.randint(num_classe, size=batch_size).astype(np.int32) | |||||
>>> test_labels = (np.eye(num_classe)[test_labels]).astype(np.float32) | |||||
>>> initial_seeds = [] | |||||
>>> # make initial seeds | |||||
>>> for img, label in zip(test_images, test_labels): | |||||
>>> initial_seeds.append([img, label]) | |||||
>>> initial_seeds = initial_seeds[:10] | |||||
>>> nc = KMultisectionNeuronCoverage(model, train_images, segmented_num=100, incremental=True) | |||||
>>> model_fuzz_test = Fuzzer(model) | >>> model_fuzz_test = Fuzzer(model) | ||||
>>> samples, gt_labels, preds, strategies, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds, | >>> samples, gt_labels, preds, strategies, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds, | ||||
... nc, max_iters=100) | ... nc, max_iters=100) | ||||
@@ -125,18 +187,26 @@ class Fuzzer: | |||||
# Allowed mutate strategies so far. | # Allowed mutate strategies so far. | ||||
self._strategies = {'Contrast': Contrast, | self._strategies = {'Contrast': Contrast, | ||||
'Brightness': Brightness, | |||||
'Blur': Blur, | |||||
'Noise': Noise, | |||||
'GradientLuminance': GradientLuminance, | |||||
'GaussianBlur': GaussianBlur, | |||||
'MotionBlur': MotionBlur, | |||||
'GradientBlur': GradientBlur, | |||||
'UniformNoise': UniformNoise, | |||||
'GaussianNoise': GaussianNoise, | |||||
'SaltAndPepperNoise': SaltAndPepperNoise, | |||||
'NaturalNoise': NaturalNoise, | |||||
'Translate': Translate, | 'Translate': Translate, | ||||
'Scale': Scale, | 'Scale': Scale, | ||||
'Shear': Shear, | 'Shear': Shear, | ||||
'Rotate': Rotate, | 'Rotate': Rotate, | ||||
'Perspective': Perspective, | |||||
'Curve': Curve, | |||||
'FGSM': FastGradientSignMethod, | 'FGSM': FastGradientSignMethod, | ||||
'PGD': ProjectedGradientDescent, | 'PGD': ProjectedGradientDescent, | ||||
'MDIIM': MomentumDiverseInputIterativeMethod} | 'MDIIM': MomentumDiverseInputIterativeMethod} | ||||
self._affine_trans_list = ['Translate', 'Scale', 'Shear', 'Rotate'] | |||||
self._pixel_value_trans_list = ['Contrast', 'Brightness', 'Blur', 'Noise'] | |||||
self._affine_trans_list = ['Translate', 'Scale', 'Shear', 'Rotate', 'Perspective', 'Curve'] | |||||
self._pixel_value_trans_list = ['Contrast', 'GradientLuminance', 'GaussianBlur', 'MotionBlur', 'GradientBlur', | |||||
'UniformNoise', 'GaussianNoise', 'SaltAndPepperNoise', 'NaturalNoise'] | |||||
self._attacks_list = ['FGSM', 'PGD', 'MDIIM'] | self._attacks_list = ['FGSM', 'PGD', 'MDIIM'] | ||||
self._attack_param_checklists = { | self._attack_param_checklists = { | ||||
'FGSM': {'eps': {'dtype': [float], 'range': [0, 1]}, | 'FGSM': {'eps': {'dtype': [float], 'range': [0, 1]}, | ||||
@@ -144,10 +214,11 @@ class Fuzzer: | |||||
'bounds': {'dtype': [tuple, list]}}, | 'bounds': {'dtype': [tuple, list]}}, | ||||
'PGD': {'eps': {'dtype': [float], 'range': [0, 1]}, | 'PGD': {'eps': {'dtype': [float], 'range': [0, 1]}, | ||||
'eps_iter': {'dtype': [float], 'range': [0, 1]}, | 'eps_iter': {'dtype': [float], 'range': [0, 1]}, | ||||
'nb_iter': {'dtype': [int], 'range': [0, 100000]}, | |||||
'nb_iter': {'dtype': [int]}, | |||||
'bounds': {'dtype': [tuple, list]}}, | 'bounds': {'dtype': [tuple, list]}}, | ||||
'MDIIM': {'eps': {'dtype': [float], 'range': [0, 1]}, | 'MDIIM': {'eps': {'dtype': [float], 'range': [0, 1]}, | ||||
'norm_level': {'dtype': [str, int], 'range': [1, 2, '1', '2', 'l1', 'l2', 'inf', 'np.inf']}, | |||||
'norm_level': {'dtype': [str, int], | |||||
'range': [1, 2, '1', '2', 'l1', 'l2', 'inf', 'linf', 'np.inf']}, | |||||
'prob': {'dtype': [float], 'range': [0, 1]}, | 'prob': {'dtype': [float], 'range': [0, 1]}, | ||||
'bounds': {'dtype': [tuple, list]}}} | 'bounds': {'dtype': [tuple, list]}}} | ||||
@@ -157,18 +228,26 @@ class Fuzzer: | |||||
Args: | Args: | ||||
mutate_config (list): Mutate configs. The format is | mutate_config (list): Mutate configs. The format is | ||||
[{'method': 'Blur', | |||||
'params': {'radius': [0.1, 0.2], 'auto_param': [True, False]}}, | |||||
{'method': 'Contrast', | |||||
'params': {'factor': [1, 1.5, 2]}}, | |||||
{'method': 'FGSM', | |||||
'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1]}}, | |||||
...]. | |||||
[{'method': 'GaussianBlur', | |||||
'params': {'ksize': [1, 2, 3, 5], 'auto_param': [True, False]}}, | |||||
{'method': 'UniformNoise', | |||||
'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}}, | |||||
{'method': 'GaussianNoise', | |||||
'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}}, | |||||
{'method': 'Contrast', | |||||
'params': {'alpha': [0.5, 1, 1.5], 'beta': [-10, 0, 10], 'auto_param': [False, True]}}, | |||||
{'method': 'Rotate', | |||||
'params': {'angle': [20, 90], 'auto_param': [False, True]}}, | |||||
{'method': 'FGSM', | |||||
'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1], 'bounds': [(0, 1)]}}] | |||||
...]. | |||||
The supported methods list is in `self._strategies`, and the params of each method must within the | The supported methods list is in `self._strategies`, and the params of each method must within the | ||||
range of optional parameters. Supported methods are grouped in three types: Firstly, pixel value based | range of optional parameters. Supported methods are grouped in three types: Firstly, pixel value based | ||||
transform methods include: 'Contrast', 'Brightness', 'Blur' and 'Noise'. Secondly, affine transform | transform methods include: 'Contrast', 'Brightness', 'Blur' and 'Noise'. Secondly, affine transform | ||||
methods include: 'Translate', 'Scale', 'Shear' and 'Rotate'. Thirdly, attack methods include: 'FGSM', | methods include: 'Translate', 'Scale', 'Shear' and 'Rotate'. Thirdly, attack methods include: 'FGSM', | ||||
'PGD' and 'MDIIM'. `mutate_config` must have method in the type of pixel value based transform methods. | |||||
'PGD' and 'MDIIM'. 'FGSM', 'PGD' and 'MDIIM'. are abbreviations of FastGradientSignMethod, | |||||
ProjectedGradientDescent and MomentumDiverseInputIterativeMethod. | |||||
`mutate_config` must have method in the type of pixel value based transform methods. | |||||
The way of setting parameters for first and second type methods can be seen in | The way of setting parameters for first and second type methods can be seen in | ||||
'mindarmour/fuzz_testing/image_transform.py'. For third type methods, the optional parameters refer to | 'mindarmour/fuzz_testing/image_transform.py'. For third type methods, the optional parameters refer to | ||||
`self._attack_param_checklists`. | `self._attack_param_checklists`. | ||||
@@ -278,7 +357,6 @@ class Fuzzer: | |||||
if only_pixel_trans: | if only_pixel_trans: | ||||
while strategy['method'] not in self._pixel_value_trans_list: | while strategy['method'] not in self._pixel_value_trans_list: | ||||
strategy = choice(mutate_config) | strategy = choice(mutate_config) | ||||
transform = mutates[strategy['method']] | |||||
params = strategy['params'] | params = strategy['params'] | ||||
method = strategy['method'] | method = strategy['method'] | ||||
selected_param = {} | selected_param = {} | ||||
@@ -290,9 +368,10 @@ class Fuzzer: | |||||
shear_keys = selected_param.keys() | shear_keys = selected_param.keys() | ||||
if 'factor_x' in shear_keys and 'factor_y' in shear_keys: | if 'factor_x' in shear_keys and 'factor_y' in shear_keys: | ||||
selected_param[choice(['factor_x', 'factor_y'])] = 0 | selected_param[choice(['factor_x', 'factor_y'])] = 0 | ||||
transform.set_params(**selected_param) | |||||
mutate_sample = transform.transform(seed[0]) | |||||
transform = mutates[strategy['method']](**selected_param) | |||||
mutate_sample = transform(seed[0]) | |||||
else: | else: | ||||
transform = mutates[strategy['method']] | |||||
for param_name in selected_param: | for param_name in selected_param: | ||||
transform.__setattr__('_' + str(param_name), selected_param[param_name]) | transform.__setattr__('_' + str(param_name), selected_param[param_name]) | ||||
mutate_sample = transform.generate(np.array([seed[0].astype(np.float32)]), np.array([seed[1]]))[0] | mutate_sample = transform.generate(np.array([seed[0].astype(np.float32)]), np.array([seed[1]]))[0] | ||||
@@ -360,6 +439,8 @@ class Fuzzer: | |||||
_ = check_param_bounds('bounds', param_value) | _ = check_param_bounds('bounds', param_value) | ||||
elif param_name == 'norm_level': | elif param_name == 'norm_level': | ||||
_ = check_norm_level(param_value) | _ = check_norm_level(param_value) | ||||
elif param_name == 'nb_iter': | |||||
_ = check_int_positive(param_name, param_value) | |||||
else: | else: | ||||
allow_type = self._attack_param_checklists[method][param_name]['dtype'] | allow_type = self._attack_param_checklists[method][param_name]['dtype'] | ||||
allow_range = self._attack_param_checklists[method][param_name]['range'] | allow_range = self._attack_param_checklists[method][param_name]['range'] | ||||
@@ -372,7 +453,8 @@ class Fuzzer: | |||||
for mutate in mutate_config: | for mutate in mutate_config: | ||||
method = mutate['method'] | method = mutate['method'] | ||||
if method not in self._attacks_list: | if method not in self._attacks_list: | ||||
mutates[method] = self._strategies[method]() | |||||
# mutates[method] = self._strategies[method]() | |||||
mutates[method] = self._strategies[method] | |||||
else: | else: | ||||
network = self._target_model._network | network = self._target_model._network | ||||
loss_fn = self._target_model._loss_fn | loss_fn = self._target_model._loss_fn | ||||
@@ -414,7 +496,6 @@ class Fuzzer: | |||||
else: | else: | ||||
attack_success_rate = None | attack_success_rate = None | ||||
metrics_report['Attack_success_rate'] = attack_success_rate | metrics_report['Attack_success_rate'] = attack_success_rate | ||||
metrics_report['Coverage_metrics'] = coverage.get_metrics(fuzz_samples) | metrics_report['Coverage_metrics'] = coverage.get_metrics(fuzz_samples) | ||||
return metrics_report | return metrics_report |
@@ -1,609 +0,0 @@ | |||||
# Copyright 2019 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
""" | |||||
Image transform | |||||
""" | |||||
import numpy as np | |||||
from PIL import Image, ImageEnhance, ImageFilter | |||||
from mindspore.dataset.vision.py_transforms_util import is_numpy, \ | |||||
to_pil, hwc_to_chw | |||||
from mindarmour.utils._check_param import check_param_multi_types, check_param_in_range, check_numpy_param | |||||
from mindarmour.utils.logger import LogUtil | |||||
LOGGER = LogUtil.get_instance() | |||||
TAG = 'Image Transformation' | |||||
def chw_to_hwc(img): | |||||
""" | |||||
Transpose the input image; shape (C, H, W) to shape (H, W, C). | |||||
Args: | |||||
img (numpy.ndarray): Image to be converted. | |||||
Returns: | |||||
img (numpy.ndarray), Converted image. | |||||
""" | |||||
if is_numpy(img): | |||||
return img.transpose(1, 2, 0).copy() | |||||
raise TypeError('img should be numpy.ndarray. Got {}'.format(type(img))) | |||||
def is_hwc(img): | |||||
""" | |||||
Check if the input image is shape (H, W, C). | |||||
Args: | |||||
img (numpy.ndarray): Image to be checked. | |||||
Returns: | |||||
Bool, True if input is shape (H, W, C). | |||||
""" | |||||
if is_numpy(img): | |||||
img_shape = np.shape(img) | |||||
if img_shape[2] == 3 and img_shape[1] > 3 and img_shape[0] > 3: | |||||
return True | |||||
return False | |||||
raise TypeError('img should be numpy.ndarray. Got {}'.format(type(img))) | |||||
def is_chw(img): | |||||
""" | |||||
Check if the input image is shape (H, W, C). | |||||
Args: | |||||
img (numpy.ndarray): Image to be checked. | |||||
Returns: | |||||
Bool, True if input is shape (H, W, C). | |||||
""" | |||||
if is_numpy(img): | |||||
img_shape = np.shape(img) | |||||
if img_shape[0] == 3 and img_shape[1] > 3 and img_shape[2] > 3: | |||||
return True | |||||
return False | |||||
raise TypeError('img should be numpy.ndarray. Got {}'.format(type(img))) | |||||
def is_rgb(img): | |||||
""" | |||||
Check if the input image is RGB. | |||||
Args: | |||||
img (numpy.ndarray): Image to be checked. | |||||
Returns: | |||||
Bool, True if input is RGB. | |||||
""" | |||||
if is_numpy(img): | |||||
img_shape = np.shape(img) | |||||
if len(np.shape(img)) == 3 and (img_shape[0] == 3 or img_shape[2] == 3): | |||||
return True | |||||
return False | |||||
raise TypeError('img should be numpy.ndarray. Got {}'.format(type(img))) | |||||
def is_normalized(img): | |||||
""" | |||||
Check if the input image is normalized between 0 to 1. | |||||
Args: | |||||
img (numpy.ndarray): Image to be checked. | |||||
Returns: | |||||
Bool, True if input is normalized between 0 to 1. | |||||
""" | |||||
if is_numpy(img): | |||||
minimal = np.min(img) | |||||
maximun = np.max(img) | |||||
if minimal >= 0 and maximun <= 1: | |||||
return True | |||||
return False | |||||
raise TypeError('img should be Numpy array. Got {}'.format(type(img))) | |||||
class ImageTransform: | |||||
""" | |||||
The abstract base class for all image transform classes. | |||||
""" | |||||
def __init__(self): | |||||
pass | |||||
def _check(self, image): | |||||
""" Check image format. If input image is RGB and its shape | |||||
is (C, H, W), it will be transposed to (H, W, C). If the value | |||||
of the image is not normalized , it will be normalized between 0 to 1.""" | |||||
rgb = is_rgb(image) | |||||
chw = False | |||||
gray3dim = False | |||||
normalized = is_normalized(image) | |||||
if rgb: | |||||
chw = is_chw(image) | |||||
if chw: | |||||
image = chw_to_hwc(image) | |||||
else: | |||||
image = image | |||||
else: | |||||
if len(np.shape(image)) == 3: | |||||
gray3dim = True | |||||
image = image[0] | |||||
else: | |||||
image = image | |||||
if normalized: | |||||
image = image*255 | |||||
return rgb, chw, normalized, gray3dim, np.uint8(image) | |||||
def _original_format(self, image, chw, normalized, gray3dim): | |||||
""" Return transformed image with original format. """ | |||||
if not is_numpy(image): | |||||
image = np.array(image) | |||||
if chw: | |||||
image = hwc_to_chw(image) | |||||
if normalized: | |||||
image = image / 255 | |||||
if gray3dim: | |||||
image = np.expand_dims(image, 0) | |||||
return image | |||||
def transform(self, image): | |||||
pass | |||||
class Contrast(ImageTransform): | |||||
""" | |||||
Contrast of an image. | |||||
Args: | |||||
factor (Union[float, int]): Control the contrast of an image. If 1.0, | |||||
gives the original image. If 0, gives a gray image. Default: 1. | |||||
""" | |||||
def __init__(self, factor=1): | |||||
super(Contrast, self).__init__() | |||||
self.set_params(factor) | |||||
def set_params(self, factor=1, auto_param=False): | |||||
""" | |||||
Set contrast parameters. | |||||
Args: | |||||
factor (Union[float, int]): Control the contrast of an image. If 1.0 | |||||
gives the original image. If 0 gives a gray image. Default: 1. | |||||
auto_param (bool): True if auto generate parameters. Default: False. | |||||
""" | |||||
if auto_param: | |||||
self.factor = np.random.uniform(-5, 5) | |||||
else: | |||||
self.factor = check_param_multi_types('factor', factor, [int, float]) | |||||
def transform(self, image): | |||||
""" | |||||
Transform the image. | |||||
Args: | |||||
image (numpy.ndarray): Original image to be transformed. | |||||
Returns: | |||||
numpy.ndarray, transformed image. | |||||
""" | |||||
image = check_numpy_param('image', image) | |||||
ori_dtype = image.dtype | |||||
_, chw, normalized, gray3dim, image = self._check(image) | |||||
image = to_pil(image) | |||||
img_contrast = ImageEnhance.Contrast(image) | |||||
trans_image = img_contrast.enhance(self.factor) | |||||
trans_image = self._original_format(trans_image, chw, normalized, | |||||
gray3dim) | |||||
return trans_image.astype(ori_dtype) | |||||
class Brightness(ImageTransform): | |||||
""" | |||||
Brightness of an image. | |||||
Args: | |||||
factor (Union[float, int]): Control the brightness of an image. If 1.0 | |||||
gives the original image. If 0 gives a black image. Default: 1. | |||||
""" | |||||
def __init__(self, factor=1): | |||||
super(Brightness, self).__init__() | |||||
self.set_params(factor) | |||||
def set_params(self, factor=1, auto_param=False): | |||||
""" | |||||
Set brightness parameters. | |||||
Args: | |||||
factor (Union[float, int]): Control the brightness of an image. If 1 | |||||
gives the original image. If 0 gives a black image. Default: 1. | |||||
auto_param (bool): True if auto generate parameters. Default: False. | |||||
""" | |||||
if auto_param: | |||||
self.factor = np.random.uniform(0, 5) | |||||
else: | |||||
self.factor = check_param_multi_types('factor', factor, [int, float]) | |||||
def transform(self, image): | |||||
""" | |||||
Transform the image. | |||||
Args: | |||||
image (numpy.ndarray): Original image to be transformed. | |||||
Returns: | |||||
numpy.ndarray, transformed image. | |||||
""" | |||||
image = check_numpy_param('image', image) | |||||
ori_dtype = image.dtype | |||||
_, chw, normalized, gray3dim, image = self._check(image) | |||||
image = to_pil(image) | |||||
img_contrast = ImageEnhance.Brightness(image) | |||||
trans_image = img_contrast.enhance(self.factor) | |||||
trans_image = self._original_format(trans_image, chw, normalized, | |||||
gray3dim) | |||||
return trans_image.astype(ori_dtype) | |||||
class Blur(ImageTransform): | |||||
""" | |||||
Blurs the image using Gaussian blur filter. | |||||
Args: | |||||
radius(Union[float, int]): Blur radius, 0 means no blur. Default: 0. | |||||
""" | |||||
def __init__(self, radius=0): | |||||
super(Blur, self).__init__() | |||||
self.set_params(radius) | |||||
def set_params(self, radius=0, auto_param=False): | |||||
""" | |||||
Set blur parameters. | |||||
Args: | |||||
radius (Union[float, int]): Blur radius, 0 means no blur. Default: 0. | |||||
auto_param (bool): True if auto generate parameters. Default: False. | |||||
""" | |||||
if auto_param: | |||||
self.radius = np.random.uniform(-1.5, 1.5) | |||||
else: | |||||
self.radius = check_param_multi_types('radius', radius, [int, float]) | |||||
def transform(self, image): | |||||
""" | |||||
Transform the image. | |||||
Args: | |||||
image (numpy.ndarray): Original image to be transformed. | |||||
Returns: | |||||
numpy.ndarray, transformed image. | |||||
""" | |||||
image = check_numpy_param('image', image) | |||||
ori_dtype = image.dtype | |||||
_, chw, normalized, gray3dim, image = self._check(image) | |||||
image = to_pil(image) | |||||
trans_image = image.filter(ImageFilter.GaussianBlur(radius=self.radius)) | |||||
trans_image = self._original_format(trans_image, chw, normalized, | |||||
gray3dim) | |||||
return trans_image.astype(ori_dtype) | |||||
class Noise(ImageTransform): | |||||
""" | |||||
Add noise of an image. | |||||
Args: | |||||
factor (float): factor is the ratio of pixels to add noise. | |||||
If 0 gives the original image. Default 0. | |||||
""" | |||||
def __init__(self, factor=0): | |||||
super(Noise, self).__init__() | |||||
self.set_params(factor) | |||||
def set_params(self, factor=0, auto_param=False): | |||||
""" | |||||
Set noise parameters. | |||||
Args: | |||||
factor (Union[float, int]): factor is the ratio of pixels to | |||||
add noise. If 0 gives the original image. Default 0. | |||||
auto_param (bool): True if auto generate parameters. Default: False. | |||||
""" | |||||
if auto_param: | |||||
self.factor = np.random.uniform(0, 1) | |||||
else: | |||||
self.factor = check_param_multi_types('factor', factor, [int, float]) | |||||
def transform(self, image): | |||||
""" | |||||
Transform the image. | |||||
Args: | |||||
image (numpy.ndarray): Original image to be transformed. | |||||
Returns: | |||||
numpy.ndarray, transformed image. | |||||
""" | |||||
image = check_numpy_param('image', image) | |||||
ori_dtype = image.dtype | |||||
_, chw, normalized, gray3dim, image = self._check(image) | |||||
noise = np.random.uniform(low=-1, high=1, size=np.shape(image)) | |||||
trans_image = np.copy(image) | |||||
threshold = 1 - self.factor | |||||
trans_image[noise < -threshold] = 0 | |||||
trans_image[noise > threshold] = 1 | |||||
trans_image = self._original_format(trans_image, chw, normalized, | |||||
gray3dim) | |||||
return trans_image.astype(ori_dtype) | |||||
class Translate(ImageTransform): | |||||
""" | |||||
Translate an image. | |||||
Args: | |||||
x_bias (Union[int, float]): X-direction translation, x = x + x_bias*image_length. | |||||
Default: 0. | |||||
y_bias (Union[int, float]): Y-direction translation, y = y + y_bias*image_wide. | |||||
Default: 0. | |||||
""" | |||||
def __init__(self, x_bias=0, y_bias=0): | |||||
super(Translate, self).__init__() | |||||
self.set_params(x_bias, y_bias) | |||||
def set_params(self, x_bias=0, y_bias=0, auto_param=False): | |||||
""" | |||||
Set translate parameters. | |||||
Args: | |||||
x_bias (Union[float, int]): X-direction translation, and x_bias should be in range of (-1, 1). Default: 0. | |||||
y_bias (Union[float, int]): Y-direction translation, and y_bias should be in range of (-1, 1). Default: 0. | |||||
auto_param (bool): True if auto generate parameters. Default: False. | |||||
""" | |||||
x_bias = check_param_in_range('x_bias', x_bias, -1, 1) | |||||
y_bias = check_param_in_range('y_bias', y_bias, -1, 1) | |||||
self.auto_param = auto_param | |||||
if auto_param: | |||||
self.x_bias = np.random.uniform(-0.3, 0.3) | |||||
self.y_bias = np.random.uniform(-0.3, 0.3) | |||||
else: | |||||
self.x_bias = check_param_multi_types('x_bias', x_bias, | |||||
[int, float]) | |||||
self.y_bias = check_param_multi_types('y_bias', y_bias, | |||||
[int, float]) | |||||
def transform(self, image): | |||||
""" | |||||
Transform the image. | |||||
Args: | |||||
image(numpy.ndarray): Original image to be transformed. | |||||
Returns: | |||||
numpy.ndarray, transformed image. | |||||
""" | |||||
image = check_numpy_param('image', image) | |||||
ori_dtype = image.dtype | |||||
_, chw, normalized, gray3dim, image = self._check(image) | |||||
img = to_pil(image) | |||||
image_shape = np.shape(image) | |||||
self.x_bias = image_shape[1]*self.x_bias | |||||
self.y_bias = image_shape[0]*self.y_bias | |||||
trans_image = img.transform(img.size, Image.AFFINE, | |||||
(1, 0, self.x_bias, 0, 1, self.y_bias)) | |||||
trans_image = self._original_format(trans_image, chw, normalized, | |||||
gray3dim) | |||||
return trans_image.astype(ori_dtype) | |||||
class Scale(ImageTransform): | |||||
""" | |||||
Scale an image in the middle. | |||||
Args: | |||||
factor_x (Union[float, int]): Rescale in X-direction, x=factor_x*x. | |||||
Default: 1. | |||||
factor_y (Union[float, int]): Rescale in Y-direction, y=factor_y*y. | |||||
Default: 1. | |||||
""" | |||||
def __init__(self, factor_x=1, factor_y=1): | |||||
super(Scale, self).__init__() | |||||
self.set_params(factor_x, factor_y) | |||||
def set_params(self, factor_x=1, factor_y=1, auto_param=False): | |||||
""" | |||||
Set scale parameters. | |||||
Args: | |||||
factor_x (Union[float, int]): Rescale in X-direction, x=factor_x*x. | |||||
Default: 1. | |||||
factor_y (Union[float, int]): Rescale in Y-direction, y=factor_y*y. | |||||
Default: 1. | |||||
auto_param (bool): True if auto generate parameters. Default: False. | |||||
""" | |||||
if auto_param: | |||||
self.factor_x = np.random.uniform(0.7, 3) | |||||
self.factor_y = np.random.uniform(0.7, 3) | |||||
else: | |||||
self.factor_x = check_param_multi_types('factor_x', factor_x, | |||||
[int, float]) | |||||
self.factor_y = check_param_multi_types('factor_y', factor_y, | |||||
[int, float]) | |||||
def transform(self, image): | |||||
""" | |||||
Transform the image. | |||||
Args: | |||||
image(numpy.ndarray): Original image to be transformed. | |||||
Returns: | |||||
numpy.ndarray, transformed image. | |||||
""" | |||||
image = check_numpy_param('image', image) | |||||
ori_dtype = image.dtype | |||||
rgb, chw, normalized, gray3dim, image = self._check(image) | |||||
if rgb: | |||||
h, w, _ = np.shape(image) | |||||
else: | |||||
h, w = np.shape(image) | |||||
move_x_centor = w / 2*(1 - self.factor_x) | |||||
move_y_centor = h / 2*(1 - self.factor_y) | |||||
img = to_pil(image) | |||||
trans_image = img.transform(img.size, Image.AFFINE, | |||||
(self.factor_x, 0, move_x_centor, | |||||
0, self.factor_y, move_y_centor)) | |||||
trans_image = self._original_format(trans_image, chw, normalized, | |||||
gray3dim) | |||||
return trans_image.astype(ori_dtype) | |||||
class Shear(ImageTransform): | |||||
""" | |||||
Shear an image, for each pixel (x, y) in the sheared image, the new value is | |||||
taken from a position (x+factor_x*y, factor_y*x+y) in the origin image. Then | |||||
the sheared image will be rescaled to fit original size. | |||||
Args: | |||||
factor_x (Union[float, int]): Shear factor of horizontal direction. | |||||
Default: 0. | |||||
factor_y (Union[float, int]): Shear factor of vertical direction. | |||||
Default: 0. | |||||
""" | |||||
def __init__(self, factor_x=0, factor_y=0): | |||||
super(Shear, self).__init__() | |||||
self.set_params(factor_x, factor_y) | |||||
def set_params(self, factor_x=0, factor_y=0, auto_param=False): | |||||
""" | |||||
Set shear parameters. | |||||
Args: | |||||
factor_x (Union[float, int]): Shear factor of horizontal direction. | |||||
Default: 0. | |||||
factor_y (Union[float, int]): Shear factor of vertical direction. | |||||
Default: 0. | |||||
auto_param (bool): True if auto generate parameters. Default: False. | |||||
""" | |||||
if factor_x != 0 and factor_y != 0: | |||||
msg = 'At least one of factor_x and factor_y is zero.' | |||||
LOGGER.error(TAG, msg) | |||||
raise ValueError(msg) | |||||
if auto_param: | |||||
if np.random.uniform(-1, 1) > 0: | |||||
self.factor_x = np.random.uniform(-2, 2) | |||||
self.factor_y = 0 | |||||
else: | |||||
self.factor_x = 0 | |||||
self.factor_y = np.random.uniform(-2, 2) | |||||
else: | |||||
self.factor_x = check_param_multi_types('factor', factor_x, | |||||
[int, float]) | |||||
self.factor_y = check_param_multi_types('factor', factor_y, | |||||
[int, float]) | |||||
def transform(self, image): | |||||
""" | |||||
Transform the image. | |||||
Args: | |||||
image(numpy.ndarray): Original image to be transformed. | |||||
Returns: | |||||
numpy.ndarray, transformed image. | |||||
""" | |||||
image = check_numpy_param('image', image) | |||||
ori_dtype = image.dtype | |||||
rgb, chw, normalized, gray3dim, image = self._check(image) | |||||
img = to_pil(image) | |||||
if rgb: | |||||
h, w, _ = np.shape(image) | |||||
else: | |||||
h, w = np.shape(image) | |||||
if self.factor_x != 0: | |||||
boarder_x = [0, -w, -self.factor_x*h, -w - self.factor_x*h] | |||||
min_x = min(boarder_x) | |||||
max_x = max(boarder_x) | |||||
scale = (max_x - min_x) / w | |||||
move_x_cen = (w - scale*w - scale*h*self.factor_x) / 2 | |||||
move_y_cen = h*(1 - scale) / 2 | |||||
else: | |||||
boarder_y = [0, -h, -self.factor_y*w, -h - self.factor_y*w] | |||||
min_y = min(boarder_y) | |||||
max_y = max(boarder_y) | |||||
scale = (max_y - min_y) / h | |||||
move_y_cen = (h - scale*h - scale*w*self.factor_y) / 2 | |||||
move_x_cen = w*(1 - scale) / 2 | |||||
trans_image = img.transform(img.size, Image.AFFINE, | |||||
(scale, scale*self.factor_x, move_x_cen, | |||||
scale*self.factor_y, scale, move_y_cen)) | |||||
trans_image = self._original_format(trans_image, chw, normalized, | |||||
gray3dim) | |||||
return trans_image.astype(ori_dtype) | |||||
class Rotate(ImageTransform): | |||||
""" | |||||
Rotate an image of degrees counter clockwise around its center. | |||||
Args: | |||||
angle(Union[float, int]): Degrees counter clockwise. Default: 0. | |||||
""" | |||||
def __init__(self, angle=0): | |||||
super(Rotate, self).__init__() | |||||
self.set_params(angle) | |||||
def set_params(self, angle=0, auto_param=False): | |||||
""" | |||||
Set rotate parameters. | |||||
Args: | |||||
angle(Union[float, int]): Degrees counter clockwise. Default: 0. | |||||
auto_param (bool): True if auto generate parameters. Default: False. | |||||
""" | |||||
if auto_param: | |||||
self.angle = np.random.uniform(0, 360) | |||||
else: | |||||
self.angle = check_param_multi_types('angle', angle, [int, float]) | |||||
def transform(self, image): | |||||
""" | |||||
Transform the image. | |||||
Args: | |||||
image(numpy.ndarray): Original image to be transformed. | |||||
Returns: | |||||
numpy.ndarray, transformed image. | |||||
""" | |||||
image = check_numpy_param('image', image) | |||||
ori_dtype = image.dtype | |||||
_, chw, normalized, gray3dim, image = self._check(image) | |||||
img = to_pil(image) | |||||
trans_image = img.rotate(self.angle, expand=False) | |||||
trans_image = self._original_format(trans_image, chw, normalized, | |||||
gray3dim) | |||||
return trans_image.astype(ori_dtype) |
@@ -154,13 +154,48 @@ class NeuronCoverage(CoverageMetrics): | |||||
incremental (bool): Metrics will be calculate in incremental way or not. Default: False. | incremental (bool): Metrics will be calculate in incremental way or not. Default: False. | ||||
batch_size (int): The number of samples in a fuzz test batch. Default: 32. | batch_size (int): The number of samples in a fuzz test batch. Default: 32. | ||||
Examples: | |||||
>>> import numpy as np | |||||
>>> from mindspore import nn | |||||
>>> from mindspore.nn import Cell | |||||
>>> from mindspore.train import Model | |||||
>>> from mindspore import context | |||||
>>> from mindspore.ops import TensorSummary | |||||
>>> from mindarmour.fuzz_testing import NeuronCoverage | |||||
>>> | |||||
>>> class Net(Cell): | |||||
>>> def __init__(self): | |||||
>>> super(Net, self).__init__() | |||||
>>> self._relu = nn.ReLU() | |||||
>>> self.summary = TensorSummary() | |||||
>>> | |||||
>>> def construct(self, inputs): | |||||
>>> self.summary('input', inputs) | |||||
>>> out = self._relu(inputs) | |||||
>>> self.summary('1', out) | |||||
>>> return out | |||||
>>> | |||||
>>> context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||||
>>> # load network | |||||
>>> net = Net() | |||||
>>> model = Model(net) | |||||
>>> | |||||
>>> # initialize fuzz test with training dataset | |||||
>>> training_data = (np.random.random((10000, 10))*20).astype(np.float32) | |||||
>>> | |||||
>>> # fuzz test with original test data | |||||
>>> # get test data | |||||
>>> test_data = (np.random.random((2000, 10))*20).astype(np.float32) | |||||
>>> test_labels = np.random.randint(0, 10, 2000).astype(np.int32) | |||||
>>> | |||||
>>> nc = NeuronCoverage(model, threshold=0.1) | |||||
>>> nc_metric = nc.get_metrics(test_data) | |||||
""" | """ | ||||
def __init__(self, model, threshold=0.1, incremental=False, batch_size=32): | def __init__(self, model, threshold=0.1, incremental=False, batch_size=32): | ||||
super(NeuronCoverage, self).__init__(model, incremental, batch_size) | super(NeuronCoverage, self).__init__(model, incremental, batch_size) | ||||
threshold = check_param_type('threshold', threshold, float) | threshold = check_param_type('threshold', threshold, float) | ||||
self.threshold = check_value_positive('threshold', threshold) | self.threshold = check_value_positive('threshold', threshold) | ||||
def get_metrics(self, dataset): | def get_metrics(self, dataset): | ||||
""" | """ | ||||
Get the metric of neuron coverage: the proportion of activated neurons to total neurons in the network. | Get the metric of neuron coverage: the proportion of activated neurons to total neurons in the network. | ||||
@@ -170,10 +205,6 @@ class NeuronCoverage(CoverageMetrics): | |||||
Returns: | Returns: | ||||
float, the metric of 'neuron coverage'. | float, the metric of 'neuron coverage'. | ||||
Examples: | |||||
>>> nc = NeuronCoverage(model, threshold=0.1) | |||||
>>> nc_metrics = nc.get_metrics(test_data) | |||||
""" | """ | ||||
dataset = check_numpy_param('dataset', dataset) | dataset = check_numpy_param('dataset', dataset) | ||||
batches = math.ceil(dataset.shape[0] / self.batch_size) | batches = math.ceil(dataset.shape[0] / self.batch_size) | ||||
@@ -203,6 +234,43 @@ class TopKNeuronCoverage(CoverageMetrics): | |||||
top_k (int): Neuron is activated when its output has the top k largest value in that hidden layers. Default: 3. | top_k (int): Neuron is activated when its output has the top k largest value in that hidden layers. Default: 3. | ||||
incremental (bool): Metrics will be calculate in incremental way or not. Default: False. | incremental (bool): Metrics will be calculate in incremental way or not. Default: False. | ||||
batch_size (int): The number of samples in a fuzz test batch. Default: 32. | batch_size (int): The number of samples in a fuzz test batch. Default: 32. | ||||
Examples: | |||||
>>> import numpy as np | |||||
>>> from mindspore import nn | |||||
>>> from mindspore.nn import Cell | |||||
>>> from mindspore.train import Model | |||||
>>> from mindspore import context | |||||
>>> from mindspore.ops import TensorSummary | |||||
>>> from mindarmour.fuzz_testing import TopKNeuronCoverage | |||||
>>> | |||||
>>> class Net(Cell): | |||||
>>> def __init__(self): | |||||
>>> super(Net, self).__init__() | |||||
>>> self._relu = nn.ReLU() | |||||
>>> self.summary = TensorSummary() | |||||
>>> | |||||
>>> def construct(self, inputs): | |||||
>>> self.summary('input', inputs) | |||||
>>> out = self._relu(inputs) | |||||
>>> self.summary('1', out) | |||||
>>> return out | |||||
>>> | |||||
>>> context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||||
>>> # load network | |||||
>>> net = Net() | |||||
>>> model = Model(net) | |||||
>>> | |||||
>>> # initialize fuzz test with training dataset | |||||
>>> training_data = (np.random.random((10000, 10))*20).astype(np.float32) | |||||
>>> | |||||
>>> # fuzz test with original test data | |||||
>>> # get test data | |||||
>>> test_data = (np.random.random((2000, 10))*20).astype(np.float32) | |||||
>>> test_labels = np.random.randint(0, 10, 2000).astype(np.int32) | |||||
>>> | |||||
>>> tknc = TopKNeuronCoverage(model, top_k=3) | |||||
>>> tknc_metrics = tknc.get_metrics(test_data) | |||||
""" | """ | ||||
def __init__(self, model, top_k=3, incremental=False, batch_size=32): | def __init__(self, model, top_k=3, incremental=False, batch_size=32): | ||||
super(TopKNeuronCoverage, self).__init__(model, incremental=incremental, batch_size=batch_size) | super(TopKNeuronCoverage, self).__init__(model, incremental=incremental, batch_size=batch_size) | ||||
@@ -217,10 +285,6 @@ class TopKNeuronCoverage(CoverageMetrics): | |||||
Returns: | Returns: | ||||
float, the metrics of 'top k neuron coverage'. | float, the metrics of 'top k neuron coverage'. | ||||
Examples: | |||||
>>> tknc = TopKNeuronCoverage(model, top_k=3) | |||||
>>> metrics = tknc.get_metrics(test_data) | |||||
""" | """ | ||||
dataset = check_numpy_param('dataset', dataset) | dataset = check_numpy_param('dataset', dataset) | ||||
batches = math.ceil(dataset.shape[0] / self.batch_size) | batches = math.ceil(dataset.shape[0] / self.batch_size) | ||||
@@ -252,6 +316,43 @@ class SuperNeuronActivateCoverage(CoverageMetrics): | |||||
train_dataset (numpy.ndarray): Training dataset used for determine the neurons' output boundaries. | train_dataset (numpy.ndarray): Training dataset used for determine the neurons' output boundaries. | ||||
incremental (bool): Metrics will be calculate in incremental way or not. Default: False. | incremental (bool): Metrics will be calculate in incremental way or not. Default: False. | ||||
batch_size (int): The number of samples in a fuzz test batch. Default: 32. | batch_size (int): The number of samples in a fuzz test batch. Default: 32. | ||||
Examples: | |||||
>>> import numpy as np | |||||
>>> from mindspore import nn | |||||
>>> from mindspore.nn import Cell | |||||
>>> from mindspore.train import Model | |||||
>>> from mindspore import context | |||||
>>> from mindspore.ops import TensorSummary | |||||
>>> from mindarmour.fuzz_testing import SuperNeuronActivateCoverage | |||||
>>> | |||||
>>> class Net(Cell): | |||||
>>> def __init__(self): | |||||
>>> super(Net, self).__init__() | |||||
>>> self._relu = nn.ReLU() | |||||
>>> self.summary = TensorSummary() | |||||
>>> | |||||
>>> def construct(self, inputs): | |||||
>>> self.summary('input', inputs) | |||||
>>> out = self._relu(inputs) | |||||
>>> self.summary('1', out) | |||||
>>> return out | |||||
>>> | |||||
>>> context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||||
>>> # load network | |||||
>>> net = Net() | |||||
>>> model = Model(net) | |||||
>>> | |||||
>>> # initialize fuzz test with training dataset | |||||
>>> training_data = (np.random.random((10000, 10))*20).astype(np.float32) | |||||
>>> | |||||
>>> # fuzz test with original test data | |||||
>>> # get test data | |||||
>>> test_data = (np.random.random((2000, 10))*20).astype(np.float32) | |||||
>>> test_labels = np.random.randint(0, 10, 2000).astype(np.int32) | |||||
>>> | |||||
>>> snac = SuperNeuronActivateCoverage(model, training_data) | |||||
>>> snac_metrics = snac.get_metrics(test_data) | |||||
""" | """ | ||||
def __init__(self, model, train_dataset, incremental=False, batch_size=32): | def __init__(self, model, train_dataset, incremental=False, batch_size=32): | ||||
super(SuperNeuronActivateCoverage, self).__init__(model, incremental=incremental, batch_size=batch_size) | super(SuperNeuronActivateCoverage, self).__init__(model, incremental=incremental, batch_size=batch_size) | ||||
@@ -267,10 +368,6 @@ class SuperNeuronActivateCoverage(CoverageMetrics): | |||||
Returns: | Returns: | ||||
float, the metric of 'strong neuron activation coverage'. | float, the metric of 'strong neuron activation coverage'. | ||||
Examples: | |||||
>>> snac = SuperNeuronActivateCoverage(model, train_dataset) | |||||
>>> metrics = snac.get_metrics(test_data) | |||||
""" | """ | ||||
dataset = check_numpy_param('dataset', dataset) | dataset = check_numpy_param('dataset', dataset) | ||||
if not self.incremental or not self._activate_table: | if not self.incremental or not self._activate_table: | ||||
@@ -303,6 +400,43 @@ class NeuronBoundsCoverage(SuperNeuronActivateCoverage): | |||||
train_dataset (numpy.ndarray): Training dataset used for determine the neurons' output boundaries. | train_dataset (numpy.ndarray): Training dataset used for determine the neurons' output boundaries. | ||||
incremental (bool): Metrics will be calculate in incremental way or not. Default: False. | incremental (bool): Metrics will be calculate in incremental way or not. Default: False. | ||||
batch_size (int): The number of samples in a fuzz test batch. Default: 32. | batch_size (int): The number of samples in a fuzz test batch. Default: 32. | ||||
Examples: | |||||
>>> import numpy as np | |||||
>>> from mindspore import nn | |||||
>>> from mindspore.nn import Cell | |||||
>>> from mindspore.train import Model | |||||
>>> from mindspore import context | |||||
>>> from mindspore.ops import TensorSummary | |||||
>>> from mindarmour.fuzz_testing import NeuronBoundsCoverage | |||||
>>> | |||||
>>> class Net(Cell): | |||||
>>> def __init__(self): | |||||
>>> super(Net, self).__init__() | |||||
>>> self._relu = nn.ReLU() | |||||
>>> self.summary = TensorSummary() | |||||
>>> | |||||
>>> def construct(self, inputs): | |||||
>>> self.summary('input', inputs) | |||||
>>> out = self._relu(inputs) | |||||
>>> self.summary('1', out) | |||||
>>> return out | |||||
>>> | |||||
>>> context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||||
>>> # load network | |||||
>>> net = Net() | |||||
>>> model = Model(net) | |||||
>>> | |||||
>>> # initialize fuzz test with training dataset | |||||
>>> training_data = (np.random.random((10000, 10))*20).astype(np.float32) | |||||
>>> | |||||
>>> # fuzz test with original test data | |||||
>>> # get test data | |||||
>>> test_data = (np.random.random((2000, 10))*20).astype(np.float32) | |||||
>>> test_labels = np.random.randint(0, 10, 2000).astype(np.int32) | |||||
>>> | |||||
>>> nbc = NeuronBoundsCoverage(model, training_data) | |||||
>>> nbc_metrics = nbc.get_metrics(test_data) | |||||
""" | """ | ||||
def __init__(self, model, train_dataset, incremental=False, batch_size=32): | def __init__(self, model, train_dataset, incremental=False, batch_size=32): | ||||
@@ -317,10 +451,6 @@ class NeuronBoundsCoverage(SuperNeuronActivateCoverage): | |||||
Returns: | Returns: | ||||
float, the metric of 'neuron boundary coverage'. | float, the metric of 'neuron boundary coverage'. | ||||
Examples: | |||||
>>> nbc = NeuronBoundsCoverage(model, train_dataset) | |||||
>>> metrics = nbc.get_metrics(test_data) | |||||
""" | """ | ||||
dataset = check_numpy_param('dataset', dataset) | dataset = check_numpy_param('dataset', dataset) | ||||
if not self.incremental or not self._activate_table: | if not self.incremental or not self._activate_table: | ||||
@@ -353,6 +483,43 @@ class KMultisectionNeuronCoverage(SuperNeuronActivateCoverage): | |||||
segmented_num (int): The number of segmented sections of neurons' output intervals. Default: 100. | segmented_num (int): The number of segmented sections of neurons' output intervals. Default: 100. | ||||
incremental (bool): Metrics will be calculate in incremental way or not. Default: False. | incremental (bool): Metrics will be calculate in incremental way or not. Default: False. | ||||
batch_size (int): The number of samples in a fuzz test batch. Default: 32. | batch_size (int): The number of samples in a fuzz test batch. Default: 32. | ||||
Examples: | |||||
>>> import numpy as np | |||||
>>> from mindspore import nn | |||||
>>> from mindspore.nn import Cell | |||||
>>> from mindspore.train import Model | |||||
>>> from mindspore import context | |||||
>>> from mindspore.ops import TensorSummary | |||||
>>> from mindarmour.fuzz_testing import KMultisectionNeuronCoverage | |||||
>>> | |||||
>>> class Net(Cell): | |||||
>>> def __init__(self): | |||||
>>> super(Net, self).__init__() | |||||
>>> self._relu = nn.ReLU() | |||||
>>> self.summary = TensorSummary() | |||||
>>> | |||||
>>> def construct(self, inputs): | |||||
>>> self.summary('input', inputs) | |||||
>>> out = self._relu(inputs) | |||||
>>> self.summary('1', out) | |||||
>>> return out | |||||
>>> | |||||
>>> context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||||
>>> # load network | |||||
>>> net = Net() | |||||
>>> model = Model(net) | |||||
>>> | |||||
>>> # initialize fuzz test with training dataset | |||||
>>> training_data = (np.random.random((10000, 10))*20).astype(np.float32) | |||||
>>> | |||||
>>> # fuzz test with original test data | |||||
>>> # get test data | |||||
>>> test_data = (np.random.random((2000, 10))*20).astype(np.float32) | |||||
>>> test_labels = np.random.randint(0, 10, 2000).astype(np.int32) | |||||
>>> | |||||
>>> kmnc = KMultisectionNeuronCoverage(model, training_data, segmented_num=100) | |||||
>>> kmnc_metrics = kmnc.get_metrics(test_data) | |||||
""" | """ | ||||
def __init__(self, model, train_dataset, segmented_num=100, incremental=False, batch_size=32): | def __init__(self, model, train_dataset, segmented_num=100, incremental=False, batch_size=32): | ||||
@@ -381,10 +548,6 @@ class KMultisectionNeuronCoverage(SuperNeuronActivateCoverage): | |||||
Returns: | Returns: | ||||
float, the metric of 'k-multisection neuron coverage'. | float, the metric of 'k-multisection neuron coverage'. | ||||
Examples: | |||||
>>> kmnc = KMultisectionNeuronCoverage(model, train_dataset, segmented_num=100) | |||||
>>> metrics = kmnc.get_metrics(test_data) | |||||
""" | """ | ||||
dataset = check_numpy_param('dataset', dataset) | dataset = check_numpy_param('dataset', dataset) | ||||
@@ -0,0 +1,37 @@ | |||||
# Copyright 2022 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
""" | |||||
This package include methods to generate natural perturbation samples. | |||||
""" | |||||
from .transformation import Translate, Scale, Shear, Rotate, Perspective, Curve | |||||
from .blur import GaussianBlur, MotionBlur, GradientBlur | |||||
from .luminance import Contrast, GradientLuminance | |||||
from .corruption import UniformNoise, GaussianNoise, SaltAndPepperNoise, NaturalNoise | |||||
__all__ = ['Translate', | |||||
'Scale', | |||||
'Shear', | |||||
'Rotate', | |||||
'Perspective', | |||||
'Curve', | |||||
'GaussianBlur', | |||||
'MotionBlur', | |||||
'GradientBlur', | |||||
'Contrast', | |||||
'GradientLuminance', | |||||
'UniformNoise', | |||||
'GaussianNoise', | |||||
'SaltAndPepperNoise', | |||||
'NaturalNoise'] |
@@ -0,0 +1,193 @@ | |||||
# Copyright 2022 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
""" | |||||
Image Blur | |||||
""" | |||||
import numpy as np | |||||
import cv2 | |||||
from mindarmour.natural_robustness.natural_perturb import _NaturalPerturb | |||||
from mindarmour.utils._check_param import check_param_multi_types, check_int_positive, check_param_type | |||||
from mindarmour.utils.logger import LogUtil | |||||
LOGGER = LogUtil.get_instance() | |||||
TAG = 'Image Blur' | |||||
class GaussianBlur(_NaturalPerturb): | |||||
""" | |||||
Blurs the image using Gaussian blur filter. | |||||
Args: | |||||
ksize (int): Size of gaussian kernel, this value must be non-negnative. | |||||
auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image. | |||||
Example: | |||||
>>> img = cv2.imread('1.png') | |||||
>>> img = np.array(img) | |||||
>>> ksize = 5 | |||||
>>> trans = GaussianBlur(ksize) | |||||
>>> dst = trans(img) | |||||
""" | |||||
def __init__(self, ksize=2, auto_param=False): | |||||
super(GaussianBlur, self).__init__() | |||||
ksize = check_int_positive('ksize', ksize) | |||||
if auto_param: | |||||
ksize = 2 * np.random.randint(0, 5) + 1 | |||||
else: | |||||
ksize = 2 * ksize + 1 | |||||
self.ksize = (ksize, ksize) | |||||
def __call__(self, image): | |||||
""" | |||||
Transform the image. | |||||
Args: | |||||
image (numpy.ndarray): Original image to be transformed. | |||||
Returns: | |||||
numpy.ndarray, transformed image. | |||||
""" | |||||
ori_dtype = image.dtype | |||||
_, chw, normalized, gray3dim, image = self._check(image) | |||||
new_img = cv2.GaussianBlur(image, self.ksize, 0) | |||||
new_img = self._original_format(new_img, chw, normalized, gray3dim) | |||||
return new_img.astype(ori_dtype) | |||||
class MotionBlur(_NaturalPerturb): | |||||
""" | |||||
Motion blur for a given image. | |||||
Args: | |||||
degree (int): Degree of blur. This value must be positive. Suggested value range in [1, 15]. | |||||
angle: (union[float, int]): Direction of motion blur. Angle=0 means up and down motion blur. Angle is | |||||
counterclockwise. | |||||
auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image. | |||||
Example: | |||||
>>> img = cv2.imread('1.png') | |||||
>>> img = np.array(img) | |||||
>>> angle = 0 | |||||
>>> degree = 5 | |||||
>>> trans = MotionBlur(degree=degree, angle=angle) | |||||
>>> new_img = trans(img) | |||||
""" | |||||
def __init__(self, degree=5, angle=45, auto_param=False): | |||||
super(MotionBlur, self).__init__() | |||||
self.degree = check_int_positive('degree', degree) | |||||
self.degree = check_param_multi_types('degree', degree, [float, int]) | |||||
auto_param = check_param_type('auto_param', auto_param, bool) | |||||
if auto_param: | |||||
self.degree = np.random.randint(1, 5) | |||||
self.angle = np.random.uniform(0, 360) | |||||
else: | |||||
self.angle = angle - 45 | |||||
def __call__(self, image): | |||||
""" | |||||
Motion blur for a given image. | |||||
Args: | |||||
image (numpy.ndarray): Original image. | |||||
Returns: | |||||
numpy.ndarray, image after motion blur. | |||||
""" | |||||
ori_dtype = image.dtype | |||||
_, chw, normalized, gray3dim, image = self._check(image) | |||||
matrix = cv2.getRotationMatrix2D((self.degree / 2, self.degree / 2), self.angle, 1) | |||||
motion_blur_kernel = np.diag(np.ones(self.degree)) | |||||
motion_blur_kernel = cv2.warpAffine(motion_blur_kernel, matrix, (self.degree, self.degree)) | |||||
motion_blur_kernel = motion_blur_kernel / self.degree | |||||
blurred = cv2.filter2D(image, -1, motion_blur_kernel) | |||||
# convert to uint8 | |||||
cv2.normalize(blurred, blurred, 0, 255, cv2.NORM_MINMAX) | |||||
blurred = self._original_format(blurred, chw, normalized, gray3dim) | |||||
return blurred.astype(ori_dtype) | |||||
class GradientBlur(_NaturalPerturb): | |||||
""" | |||||
Gradient blur. | |||||
Args: | |||||
point (union[tuple, list]): 2D coordinate of the Blur center point. | |||||
kernel_num (int): Number of blur kernels. Suggested value range in [1, 8]. | |||||
center (bool): Blurred or clear at the center of a specified point. | |||||
auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image. | |||||
Example: | |||||
>>> img = cv2.imread('xx.png') | |||||
>>> img = np.array(img) | |||||
>>> number = 5 | |||||
>>> h, w = img.shape[:2] | |||||
>>> point = (int(h / 5), int(w / 5)) | |||||
>>> center = True | |||||
>>> trans = GradientBlur(point, number, center) | |||||
>>> new_img = trans(img) | |||||
""" | |||||
def __init__(self, point, kernel_num=3, center=True, auto_param=False): | |||||
super(GradientBlur).__init__() | |||||
point = check_param_multi_types('point', point, [list, tuple]) | |||||
self.auto_param = check_param_type('auto_param', auto_param, bool) | |||||
self.point = tuple(point) | |||||
self.kernel_num = check_int_positive('kernel_num', kernel_num) | |||||
self.center = check_param_type('center', center, bool) | |||||
def _auto_param(self, h, w): | |||||
self.point = (int(np.random.uniform(0, h)), int(np.random.uniform(0, w))) | |||||
self.kernel_num = np.random.randint(1, 6) | |||||
self.center = np.random.choice([True, False]) | |||||
def __call__(self, image): | |||||
""" | |||||
Args: | |||||
image (numpy.ndarray): Original image. | |||||
Returns: | |||||
numpy.ndarray, gradient blurred image. | |||||
""" | |||||
ori_dtype = image.dtype | |||||
_, chw, normalized, gray3dim, image = self._check(image) | |||||
w, h = image.shape[:2] | |||||
if self.auto_param: | |||||
self._auto_param(h, w) | |||||
mask = np.zeros(image.shape, dtype=np.uint8) | |||||
masks = [] | |||||
radius = max(w - self.point[0], self.point[0], h - self.point[1], self.point[1]) | |||||
radius = int(radius / self.kernel_num) | |||||
for i in range(self.kernel_num): | |||||
circle = cv2.circle(mask.copy(), self.point, radius * (1 + i), (1, 1, 1), -1) | |||||
masks.append(circle) | |||||
blurs = [] | |||||
for i in range(3, 3 + 2 * self.kernel_num, 2): | |||||
ksize = (i, i) | |||||
blur = cv2.GaussianBlur(image, ksize, 0) | |||||
blurs.append(blur) | |||||
dst = image.copy() | |||||
if self.center: | |||||
for i in range(self.kernel_num): | |||||
dst = masks[i] * dst + (1 - masks[i]) * blurs[i] | |||||
else: | |||||
for i in range(self.kernel_num - 1, -1, -1): | |||||
dst = masks[i] * blurs[self.kernel_num - 1 - i] + (1 - masks[i]) * dst | |||||
dst = self._original_format(dst, chw, normalized, gray3dim) | |||||
return dst.astype(ori_dtype) |
@@ -0,0 +1,251 @@ | |||||
# Copyright 2022 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
""" | |||||
Image corruption. | |||||
""" | |||||
import math | |||||
import numpy as np | |||||
import cv2 | |||||
from mindarmour.natural_robustness.natural_perturb import _NaturalPerturb | |||||
from mindarmour.utils._check_param import check_param_multi_types, check_param_type | |||||
from mindarmour.utils.logger import LogUtil | |||||
LOGGER = LogUtil.get_instance() | |||||
TAG = 'Image corruption' | |||||
class UniformNoise(_NaturalPerturb): | |||||
""" | |||||
Add uniform noise of an image. | |||||
Args: | |||||
factor (float): Noise density, the proportion of noise points per unit pixel area. Suggested value range in | |||||
[0.001, 0.15]. | |||||
auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image. | |||||
Example: | |||||
>>> img = cv2.imread('1.png') | |||||
>>> img = np.array(img) | |||||
>>> factor = 0.1 | |||||
>>> trans = UniformNoise(factor) | |||||
>>> dst = trans(img) | |||||
""" | |||||
def __init__(self, factor=0.1, auto_param=False): | |||||
super(UniformNoise, self).__init__() | |||||
self.factor = check_param_multi_types('factor', factor, [int, float]) | |||||
check_param_type('auto_param', auto_param, bool) | |||||
if auto_param: | |||||
self.factor = np.random.uniform(0, 0.15) | |||||
def __call__(self, image): | |||||
""" | |||||
Transform the image. | |||||
Args: | |||||
image (numpy.ndarray): Original image to be transformed. | |||||
Returns: | |||||
numpy.ndarray, transformed image. | |||||
""" | |||||
ori_dtype = image.dtype | |||||
_, chw, normalized, gray3dim, image = self._check(image) | |||||
low, high = (0, 255) | |||||
weight = self.factor * (high - low) | |||||
noise = np.random.uniform(-weight, weight, size=image.shape) | |||||
trans_image = np.clip(image + noise, low, high) | |||||
trans_image = self._original_format(trans_image, chw, normalized, gray3dim) | |||||
return trans_image.astype(ori_dtype) | |||||
class GaussianNoise(_NaturalPerturb): | |||||
""" | |||||
Add gaussian noise of an image. | |||||
Args: | |||||
factor (float): Noise density, the proportion of noise points per unit pixel area. Suggested value range in | |||||
[0.001, 0.15]. | |||||
auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image. | |||||
Example: | |||||
>>> img = cv2.imread('1.png') | |||||
>>> img = np.array(img) | |||||
>>> factor = 0.1 | |||||
>>> trans = GaussianNoise(factor) | |||||
>>> dst = trans(img) | |||||
""" | |||||
def __init__(self, factor=0.1, auto_param=False): | |||||
super(GaussianNoise, self).__init__() | |||||
self.factor = check_param_multi_types('factor', factor, [int, float]) | |||||
check_param_type('auto_param', auto_param, bool) | |||||
if auto_param: | |||||
self.factor = np.random.uniform(0, 0.15) | |||||
def __call__(self, image): | |||||
""" | |||||
Transform the image. | |||||
Args: | |||||
image (numpy.ndarray): Original image to be transformed. | |||||
Returns: | |||||
numpy.ndarray, transformed image. | |||||
""" | |||||
ori_dtype = image.dtype | |||||
low, high = (0, 255) | |||||
_, chw, normalized, gray3dim, image = self._check(image) | |||||
std = self.factor / math.sqrt(3) * (high - low) | |||||
noise = np.random.normal(scale=std, size=image.shape) | |||||
trans_image = np.clip(image + noise, low, high) | |||||
trans_image = self._original_format(trans_image, chw, normalized, gray3dim) | |||||
return trans_image.astype(ori_dtype) | |||||
class SaltAndPepperNoise(_NaturalPerturb): | |||||
""" | |||||
Add salt and pepper noise of an image. | |||||
Args: | |||||
factor (float): Noise density, the proportion of noise points per unit pixel area. Suggested value range in | |||||
[0.001, 0.15]. | |||||
auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image. | |||||
Example: | |||||
>>> img = cv2.imread('1.png') | |||||
>>> img = np.array(img) | |||||
>>> factor = 0.1 | |||||
>>> trans = SaltAndPepperNoise(factor) | |||||
>>> dst = trans(img) | |||||
""" | |||||
def __init__(self, factor=0, auto_param=False): | |||||
super(SaltAndPepperNoise, self).__init__() | |||||
self.factor = check_param_multi_types('factor', factor, [int, float]) | |||||
check_param_type('auto_param', auto_param, bool) | |||||
if auto_param: | |||||
self.factor = np.random.uniform(0, 0.15) | |||||
def __call__(self, image): | |||||
""" | |||||
Transform the image. | |||||
Args: | |||||
image (numpy.ndarray): Original image to be transformed. | |||||
Returns: | |||||
numpy.ndarray, transformed image. | |||||
""" | |||||
ori_dtype = image.dtype | |||||
_, chw, normalized, gray3dim, image = self._check(image) | |||||
low, high = (0, 255) | |||||
noise = np.random.uniform(low=-1, high=1, size=(image.shape[0], image.shape[1])) | |||||
trans_image = np.copy(image) | |||||
threshold = 1 - self.factor | |||||
trans_image[noise < -threshold] = low | |||||
trans_image[noise > threshold] = high | |||||
trans_image = self._original_format(trans_image, chw, normalized, gray3dim) | |||||
return trans_image.astype(ori_dtype) | |||||
class NaturalNoise(_NaturalPerturb): | |||||
""" | |||||
Add natural noise to an image. | |||||
Args: | |||||
ratio (float): Noise density, the proportion of noise blocks per unit pixel area. Suggested value range in | |||||
[0.00001, 0.001]. | |||||
k_x_range (union[list, tuple]): Value range of the noise block length. | |||||
k_y_range (union[list, tuple]): Value range of the noise block width. | |||||
auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image. | |||||
Examples: | |||||
>>> img = cv2.imread('xx.png') | |||||
>>> img = np.array(img) | |||||
>>> ratio = 0.0002 | |||||
>>> k_x_range = (1, 5) | |||||
>>> k_y_range = (3, 25) | |||||
>>> trans = NaturalNoise(ratio, k_x_range, k_y_range) | |||||
>>> new_img = trans(img) | |||||
""" | |||||
def __init__(self, ratio=0.0002, k_x_range=(1, 5), k_y_range=(3, 25), auto_param=False): | |||||
super(NaturalNoise).__init__() | |||||
self.ratio = check_param_type('ratio', ratio, float) | |||||
k_x_range = check_param_multi_types('k_x_range', k_x_range, [list, tuple]) | |||||
k_y_range = check_param_multi_types('k_y_range', k_y_range, [list, tuple]) | |||||
self.k_x_range = tuple(k_x_range) | |||||
self.k_y_range = tuple(k_y_range) | |||||
self.auto_param = check_param_type('auto_param', auto_param, bool) | |||||
def __call__(self, image): | |||||
""" | |||||
Add natural noise to given image. | |||||
Args: | |||||
image (numpy.ndarray): Original image. | |||||
Returns: | |||||
numpy.ndarray, image with natural noise. | |||||
""" | |||||
ori_dtype = image.dtype | |||||
_, chw, normalized, gray3dim, image = self._check(image) | |||||
randon_range = 100 | |||||
w, h = image.shape[:2] | |||||
channel = len(np.shape(image)) | |||||
if self.auto_param: | |||||
self.ratio = np.random.uniform(0, 0.001) | |||||
self.k_x_range = (1, 0.1 * w) | |||||
self.k_y_range = (1, 0.1 * h) | |||||
for _ in range(5): | |||||
if channel == 3: | |||||
noise = np.ones((w, h, 3), dtype=np.uint8) * 255 | |||||
dst = np.ones((w, h, 3), dtype=np.uint8) * 255 | |||||
else: | |||||
noise = np.ones((w, h), dtype=np.uint8) * 255 | |||||
dst = np.ones((w, h), dtype=np.uint8) * 255 | |||||
rate = self.ratio / 5 | |||||
mask = np.random.uniform(size=(w, h)) < rate | |||||
noise[mask] = np.random.randint(0, randon_range) | |||||
k_x, k_y = np.random.randint(*self.k_x_range), np.random.randint(*self.k_y_range) | |||||
kernel = np.ones((k_x, k_y), np.uint8) | |||||
erode = cv2.erode(noise, kernel, iterations=1) | |||||
dst = erode * (erode < randon_range) + dst * (1 - erode < randon_range) | |||||
# Add black point | |||||
for _ in range(np.random.randint(math.ceil(k_x * k_y / 2))): | |||||
x = np.random.randint(-k_x, k_x) | |||||
y = np.random.randint(-k_y, k_y) | |||||
matrix = np.array([[1, 0, y], [0, 1, x]], dtype=np.float) | |||||
affine = cv2.warpAffine(noise, matrix, (h, w)) | |||||
dst = affine * (affine < randon_range) + dst * (1 - affine < randon_range) | |||||
# Add white point | |||||
for _ in range(int(k_x * k_y / 2)): | |||||
x = np.random.randint(-k_x / 2 - 1, k_x / 2 + 1) | |||||
y = np.random.randint(-k_y / 2 - 1, k_y / 2 + 1) | |||||
matrix = np.array([[1, 0, y], [0, 1, x]], dtype=np.float) | |||||
affine = cv2.warpAffine(noise, matrix, (h, w)) | |||||
white = affine < randon_range | |||||
dst[white] = 255 | |||||
mask = dst < randon_range | |||||
dst = image * (1 - mask) + dst * mask | |||||
dst = self._original_format(dst, chw, normalized, gray3dim) | |||||
return dst.astype(ori_dtype) |
@@ -0,0 +1,287 @@ | |||||
# Copyright 2022 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
""" | |||||
Image luminance. | |||||
""" | |||||
import math | |||||
import numpy as np | |||||
import cv2 | |||||
from mindarmour.natural_robustness.natural_perturb import _NaturalPerturb | |||||
from mindarmour.utils._check_param import check_param_multi_types, check_param_in_range, check_param_type, \ | |||||
check_value_non_negative | |||||
from mindarmour.utils.logger import LogUtil | |||||
LOGGER = LogUtil.get_instance() | |||||
TAG = 'Image Luminance' | |||||
class Contrast(_NaturalPerturb): | |||||
""" | |||||
Contrast of an image. | |||||
Args: | |||||
alpha (Union[float, int]): Control the contrast of an image. :math:`out_image = in_image*alpha+beta`. | |||||
Suggested value range in [0.2, 2]. | |||||
beta (Union[float, int]): Delta added to alpha. Default: 0. | |||||
auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image. | |||||
Example: | |||||
>>> img = cv2.imread('1.png') | |||||
>>> img = np.array(img) | |||||
>>> alpha = 0.1 | |||||
>>> beta = 1 | |||||
>>> trans = Contrast(alpha, beta) | |||||
>>> dst = trans(img) | |||||
""" | |||||
def __init__(self, alpha=1, beta=0, auto_param=False): | |||||
super(Contrast, self).__init__() | |||||
self.alpha = check_param_multi_types('factor', alpha, [int, float]) | |||||
self.beta = check_param_multi_types('factor', beta, [int, float]) | |||||
auto_param = check_param_type('auto_param', auto_param, bool) | |||||
if auto_param: | |||||
self.alpha = np.random.uniform(0.2, 2) | |||||
self.beta = np.random.uniform(-20, 20) | |||||
def __call__(self, image): | |||||
""" | |||||
Transform the image. | |||||
Args: | |||||
image (numpy.ndarray): Original image to be transformed. | |||||
Returns: | |||||
numpy.ndarray, transformed image. | |||||
""" | |||||
ori_dtype = image.dtype | |||||
_, chw, normalized, gray3dim, image = self._check(image) | |||||
dst = cv2.convertScaleAbs(image, alpha=self.alpha, beta=self.beta) | |||||
dst = self._original_format(dst, chw, normalized, gray3dim) | |||||
return dst.astype(ori_dtype) | |||||
def _circle_gradient_mask(img_src, color_start, color_end, scope=0.5, point=None): | |||||
""" | |||||
Generate circle gradient mask. | |||||
Args: | |||||
img_src (numpy.ndarray): Source image. | |||||
color_start (union([tuple, list])): Color of circle gradient center. | |||||
color_end (union([tuple, list])): Color of circle gradient edge. | |||||
scope (float): Range of the gradient. A larger value indicates a larger gradient range. | |||||
point (union([tuple, list]): Gradient center point. | |||||
Returns: | |||||
numpy.ndarray, gradients mask. | |||||
""" | |||||
if not isinstance(img_src, np.ndarray): | |||||
raise TypeError('`src` must be numpy.ndarray type, but got {0}.'.format(type(img_src))) | |||||
shape = img_src.shape | |||||
height, width = shape[:2] | |||||
rgb = False | |||||
if len(shape) == 3: | |||||
rgb = True | |||||
if point is None: | |||||
point = (height // 2, width // 2) | |||||
x, y = point | |||||
# upper left | |||||
bound_upper_left = math.ceil(math.sqrt(x ** 2 + y ** 2)) | |||||
# upper right | |||||
bound_upper_right = math.ceil(math.sqrt(height ** 2 + (width - y) ** 2)) | |||||
# lower left | |||||
bound_lower_left = math.ceil(math.sqrt((height - x) ** 2 + y ** 2)) | |||||
# lower right | |||||
bound_lower_right = math.ceil(math.sqrt((height - x) ** 2 + (width - y) ** 2)) | |||||
radius = max(bound_lower_left, bound_lower_right, bound_upper_left, bound_upper_right) * scope | |||||
img_grad = np.ones_like(img_src, dtype=np.uint8) * max(color_end) | |||||
# opencv use BGR format | |||||
grad_b = float(color_end[0] - color_start[0]) / radius | |||||
grad_g = float(color_end[1] - color_start[1]) / radius | |||||
grad_r = float(color_end[2] - color_start[2]) / radius | |||||
for i in range(height): | |||||
for j in range(width): | |||||
distance = math.ceil(math.sqrt((x - i) ** 2 + (y - j) ** 2)) | |||||
if distance >= radius: | |||||
continue | |||||
if rgb: | |||||
img_grad[i, j, 0] = color_start[0] + distance * grad_b | |||||
img_grad[i, j, 1] = color_start[1] + distance * grad_g | |||||
img_grad[i, j, 2] = color_start[2] + distance * grad_r | |||||
else: | |||||
img_grad[i, j] = color_start[0] + distance * grad_b | |||||
return img_grad.astype(np.uint8) | |||||
def _line_gradient_mask(image, start_pos=None, start_color=(0, 0, 0), end_color=(255, 255, 255), mode='horizontal'): | |||||
""" | |||||
Generate liner gradient mask. | |||||
Args: | |||||
image (numpy.ndarray): Original image. | |||||
start_pos (union[tuple, list]): 2D coordinate of gradient center. | |||||
start_color (union([tuple, list])): Color of circle gradient center. | |||||
end_color (union([tuple, list])): Color of circle gradient edge. | |||||
mode (str): Direction of gradient. Optional value is 'vertical' or 'horizontal'. | |||||
Returns: | |||||
numpy.ndarray, gradients mask. | |||||
""" | |||||
shape = image.shape | |||||
h, w = shape[:2] | |||||
rgb = False | |||||
if len(shape) == 3: | |||||
rgb = True | |||||
if start_pos is None: | |||||
start_pos = 0.5 | |||||
else: | |||||
if mode == 'horizontal': | |||||
if start_pos[0] > h: | |||||
start_pos = 1 | |||||
else: | |||||
start_pos = start_pos[0] / h | |||||
else: | |||||
if start_pos[1] > w: | |||||
start_pos = 1 | |||||
else: | |||||
start_pos = start_pos[1] / w | |||||
start_color = np.array(start_color) | |||||
end_color = np.array(end_color) | |||||
if mode == 'horizontal': | |||||
w_l = int(w * start_pos) | |||||
w_r = w - w_l | |||||
if w_l > w_r: | |||||
r_end_color = (end_color - start_color) / start_pos * (1 - start_pos) + start_color | |||||
left = np.linspace(end_color, start_color, w_l) | |||||
right = np.linspace(start_color, r_end_color, w_r) | |||||
else: | |||||
l_end_color = (end_color - start_color) / (1 - start_pos) * start_pos + start_color | |||||
left = np.linspace(l_end_color, start_color, w_l) | |||||
right = np.linspace(start_color, end_color, w_r) | |||||
line = np.concatenate((left, right), axis=0) | |||||
mask = np.reshape(np.tile(line, (h, 1)), (h, w, 3)) | |||||
else: | |||||
# 'vertical' | |||||
h_t = int(h * start_pos) | |||||
h_b = h - h_t | |||||
if h_t > h_b: | |||||
b_end_color = (end_color - start_color) / start_pos * (1 - start_pos) + start_color | |||||
top = np.linspace(end_color, start_color, h_t) | |||||
bottom = np.linspace(start_color, b_end_color, h_b) | |||||
else: | |||||
t_end_color = (end_color - start_color) / (1 - start_pos) * start_pos + start_color | |||||
top = np.linspace(t_end_color, start_color, h_t) | |||||
bottom = np.linspace(start_color, end_color, h_b) | |||||
line = np.concatenate((top, bottom), axis=0) | |||||
mask = np.reshape(np.tile(line, (w, 1)), (w, h, 3)) | |||||
mask = np.transpose(mask, [1, 0, 2]) | |||||
if not rgb: | |||||
mask = mask[:, :, 0] | |||||
return mask.astype(np.uint8) | |||||
class GradientLuminance(_NaturalPerturb): | |||||
""" | |||||
Gradient adjusts the luminance of picture. | |||||
Args: | |||||
color_start (union[tuple, list]): Color of gradient center. Default:(0, 0, 0). | |||||
color_end (union[tuple, list]): Color of gradient edge. Default:(255, 255, 255). | |||||
start_point (union[tuple, list]): 2D coordinate of gradient center. | |||||
scope (float): Range of the gradient. A larger value indicates a larger gradient range. Default: 0.3. | |||||
pattern (str): Dark or light, this value must be in ['light', 'dark']. | |||||
bright_rate (float): Control brightness of . A larger value indicates a larger gradient range. If parameter | |||||
'pattern' is 'light', Suggested value range in [0.1, 0.7], if parameter 'pattern' is 'dark', Suggested value | |||||
range in [0.1, 0.9]. | |||||
mode (str): Gradient mode, value must be in ['circle', 'horizontal', 'vertical']. | |||||
auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image. | |||||
Examples: | |||||
>>> img = cv2.imread('x.png') | |||||
>>> height, width = img.shape[:2] | |||||
>>> point = (height // 4, width // 2) | |||||
>>> start = (255, 255, 255) | |||||
>>> end = (0, 0, 0) | |||||
>>> scope = 0.3 | |||||
>>> pattern='light' | |||||
>>> bright_rate = 0.3 | |||||
>>> trans = GradientLuminance(start, end, point, scope, pattern, bright_rate, mode='circle') | |||||
>>> img_new = trans(img) | |||||
""" | |||||
def __init__(self, color_start=(0, 0, 0), color_end=(255, 255, 255), start_point=(10, 10), scope=0.5, | |||||
pattern='light', bright_rate=0.3, mode='circle', auto_param=False): | |||||
super(GradientLuminance, self).__init__() | |||||
self.color_start = check_param_multi_types('color_start', color_start, [list, tuple]) | |||||
self.color_end = check_param_multi_types('color_end', color_end, [list, tuple]) | |||||
self.start_point = check_param_multi_types('start_point', start_point, [list, tuple]) | |||||
self.scope = check_value_non_negative('scope', scope) | |||||
self.bright_rate = check_param_type('bright_rate', bright_rate, float) | |||||
self.bright_rate = check_param_in_range('bright_rate', bright_rate, 0, 1) | |||||
self.auto_param = check_param_type('auto_param', auto_param, bool) | |||||
if pattern in ['light', 'dark']: | |||||
self.pattern = pattern | |||||
else: | |||||
msg = "Value of param pattern must be in ['light', 'dark']" | |||||
LOGGER.error(TAG, msg) | |||||
raise ValueError(msg) | |||||
if mode in ['circle', 'horizontal', 'vertical']: | |||||
self.mode = mode | |||||
else: | |||||
msg = "Value of param mode must be in ['circle', 'horizontal', 'vertical']" | |||||
LOGGER.error(TAG, msg) | |||||
raise ValueError(msg) | |||||
def _set_auto_param(self, w, h): | |||||
self.color_start = (np.random.uniform(0, 255),) * 3 | |||||
self.color_end = (np.random.uniform(0, 255),) * 3 | |||||
self.start_point = (np.random.uniform(0, w), np.random.uniform(0, h)) | |||||
self.scope = np.random.uniform(0, 1) | |||||
self.bright_rate = np.random.uniform(0.1, 0.9) | |||||
self.pattern = np.random.choice(['light', 'dark']) | |||||
self.mode = np.random.choice(['circle', 'horizontal', 'vertical']) | |||||
def __call__(self, image): | |||||
""" | |||||
Gradient adjusts the luminance of picture. | |||||
Args: | |||||
image (numpy.ndarray): Original image. | |||||
Returns: | |||||
numpy.ndarray, image with perlin noise. | |||||
""" | |||||
ori_dtype = image.dtype | |||||
_, chw, normalized, gray3dim, image = self._check(image) | |||||
w, h = image.shape[:2] | |||||
if self.auto_param: | |||||
self._set_auto_param(w, h) | |||||
if self.mode == 'circle': | |||||
mask = _circle_gradient_mask(image, self.color_start, self.color_end, self.scope, self.start_point) | |||||
else: | |||||
mask = _line_gradient_mask(image, self.start_point, self.color_start, self.color_end, mode=self.mode) | |||||
if self.pattern == 'light': | |||||
img_new = cv2.addWeighted(image, 1, mask, self.bright_rate, 0.0) | |||||
else: | |||||
img_new = cv2.addWeighted(image, self.bright_rate, mask, 1 - self.bright_rate, 0.0) | |||||
img_new = self._original_format(img_new, chw, normalized, gray3dim) | |||||
return img_new.astype(ori_dtype) |
@@ -0,0 +1,159 @@ | |||||
# Copyright 2022 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
""" | |||||
Base class for image natural perturbation. | |||||
""" | |||||
import numpy as np | |||||
from mindspore.dataset.vision.py_transforms_util import is_numpy, hwc_to_chw | |||||
from mindarmour.utils.logger import LogUtil | |||||
LOGGER = LogUtil.get_instance() | |||||
TAG = 'Image Transformation' | |||||
def _chw_to_hwc(img): | |||||
""" | |||||
Transpose the input image; shape (C, H, W) to shape (H, W, C). | |||||
Args: | |||||
img (numpy.ndarray): Image to be converted. | |||||
Returns: | |||||
img (numpy.ndarray), Converted image. | |||||
""" | |||||
if is_numpy(img): | |||||
return img.transpose(1, 2, 0).copy() | |||||
raise TypeError('img should be numpy.ndarray. Got {}'.format(type(img))) | |||||
def _is_hwc(img): | |||||
""" | |||||
Check if the input image is shape (H, W, C). | |||||
Args: | |||||
img (numpy.ndarray): Image to be checked. | |||||
Returns: | |||||
Bool, True if input is shape (H, W, C). | |||||
""" | |||||
if is_numpy(img): | |||||
img_shape = np.shape(img) | |||||
if img_shape[2] == 3 and img_shape[1] > 3 and img_shape[0] > 3: | |||||
return True | |||||
return False | |||||
raise TypeError('img should be numpy.ndarray. Got {}'.format(type(img))) | |||||
def _is_chw(img): | |||||
""" | |||||
Check if the input image is shape (H, W, C). | |||||
Args: | |||||
img (numpy.ndarray): Image to be checked. | |||||
Returns: | |||||
Bool, True if input is shape (H, W, C). | |||||
""" | |||||
if is_numpy(img): | |||||
img_shape = np.shape(img) | |||||
if img_shape[0] == 3 and img_shape[1] > 3 and img_shape[2] > 3: | |||||
return True | |||||
return False | |||||
raise TypeError('img should be numpy.ndarray. Got {}'.format(type(img))) | |||||
def _is_rgb(img): | |||||
""" | |||||
Check if the input image is RGB. | |||||
Args: | |||||
img (numpy.ndarray): Image to be checked. | |||||
Returns: | |||||
Bool, True if input is RGB. | |||||
""" | |||||
if is_numpy(img): | |||||
img_shape = np.shape(img) | |||||
if len(np.shape(img)) == 3 and (img_shape[0] == 3 or img_shape[2] == 3): | |||||
return True | |||||
return False | |||||
raise TypeError('img should be numpy.ndarray. Got {}'.format(type(img))) | |||||
def _is_normalized(img): | |||||
""" | |||||
Check if the input image is normalized between 0 to 1. | |||||
Args: | |||||
img (numpy.ndarray): Image to be checked. | |||||
Returns: | |||||
Bool, True if input is normalized between 0 to 1. | |||||
""" | |||||
if is_numpy(img): | |||||
minimal = np.min(img) | |||||
maximum = np.max(img) | |||||
if minimal >= 0 and maximum <= 1: | |||||
return True | |||||
return False | |||||
raise TypeError('img should be Numpy array. Got {}'.format(type(img))) | |||||
class _NaturalPerturb: | |||||
""" | |||||
The abstract base class for all image natural perturbation classes. | |||||
""" | |||||
def __init__(self): | |||||
pass | |||||
def _check(self, image): | |||||
""" Check image format. If input image is RGB and its shape | |||||
is (C, H, W), it will be transposed to (H, W, C). If the value | |||||
of the image is not normalized , it will be rescaled between 0 to 255.""" | |||||
rgb = _is_rgb(image) | |||||
chw = False | |||||
gray3dim = False | |||||
normalized = _is_normalized(image) | |||||
if rgb: | |||||
chw = _is_chw(image) | |||||
if chw: | |||||
image = _chw_to_hwc(image) | |||||
else: | |||||
image = image | |||||
else: | |||||
if len(np.shape(image)) == 3: | |||||
gray3dim = True | |||||
image = image[0] | |||||
else: | |||||
image = image | |||||
if normalized: | |||||
image = image * 255 | |||||
return rgb, chw, normalized, gray3dim, np.uint8(image) | |||||
def _original_format(self, image, chw, normalized, gray3dim): | |||||
""" Return image with original format. """ | |||||
if not is_numpy(image): | |||||
image = np.array(image) | |||||
if chw: | |||||
image = hwc_to_chw(image) | |||||
if normalized: | |||||
image = image / 255 | |||||
if gray3dim: | |||||
image = np.expand_dims(image, 0) | |||||
return image | |||||
def __call__(self, image): | |||||
pass |
@@ -0,0 +1,365 @@ | |||||
# Copyright 2022 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
""" | |||||
Image transformation. | |||||
""" | |||||
import math | |||||
import numpy as np | |||||
import cv2 | |||||
from mindarmour.natural_robustness.natural_perturb import _NaturalPerturb | |||||
from mindarmour.utils._check_param import check_param_multi_types, check_param_type, check_value_non_negative | |||||
from mindarmour.utils.logger import LogUtil | |||||
LOGGER = LogUtil.get_instance() | |||||
TAG = 'Image Transformation' | |||||
class Translate(_NaturalPerturb): | |||||
""" | |||||
Translate an image. | |||||
Args: | |||||
x_bias (Union[int, float]): X-direction translation, x = x + x_bias*image_length. Suggested value range | |||||
in [-0.1, 0.1]. | |||||
y_bias (Union[int, float]): Y-direction translation, y = y + y_bias*image_wide. Suggested value range | |||||
in [-0.1, 0.1]. | |||||
auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image. | |||||
Example: | |||||
>>> img = cv2.imread('1.png') | |||||
>>> img = np.array(img) | |||||
>>> x_bias = 0.1 | |||||
>>> y_bias = 0.1 | |||||
>>> trans = Translate(x_bias, y_bias) | |||||
>>> dst = trans(img) | |||||
""" | |||||
def __init__(self, x_bias=0, y_bias=0, auto_param=False): | |||||
super(Translate, self).__init__() | |||||
self.x_bias = check_param_multi_types('x_bias', x_bias, [int, float]) | |||||
self.y_bias = check_param_multi_types('y_bias', y_bias, [int, float]) | |||||
if auto_param: | |||||
self.x_bias = np.random.uniform(-0.1, 0.1) | |||||
self.y_bias = np.random.uniform(-0.1, 0.1) | |||||
def __call__(self, image): | |||||
""" | |||||
Transform the image. | |||||
Args: | |||||
image (numpy.ndarray): Original image to be transformed. | |||||
Returns: | |||||
numpy.ndarray, transformed image. | |||||
""" | |||||
ori_dtype = image.dtype | |||||
_, chw, normalized, gray3dim, image = self._check(image) | |||||
h, w = image.shape[:2] | |||||
matrix = np.array([[1, 0, self.x_bias * w], [0, 1, self.y_bias * h]], dtype=np.float) | |||||
new_img = cv2.warpAffine(image, matrix, (w, h)) | |||||
new_img = self._original_format(new_img, chw, normalized, gray3dim) | |||||
return new_img.astype(ori_dtype) | |||||
class Scale(_NaturalPerturb): | |||||
""" | |||||
Scale an image in the middle. | |||||
Args: | |||||
factor_x (Union[float, int]): Rescale in X-direction, x=factor_x*x. Suggested value range in [0.5, 1] and | |||||
abs(factor_y - factor_x) < 0.5. | |||||
factor_y (Union[float, int]): Rescale in Y-direction, y=factor_y*y. Suggested value range in [0.5, 1] and | |||||
abs(factor_y - factor_x) < 0.5. | |||||
auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image. | |||||
Example: | |||||
>>> img = cv2.imread('1.png') | |||||
>>> img = np.array(img) | |||||
>>> factor_x = 0.7 | |||||
>>> factor_y = 0.6 | |||||
>>> trans = Scale(factor_x, factor_y) | |||||
>>> dst = trans(img) | |||||
""" | |||||
def __init__(self, factor_x=1, factor_y=1, auto_param=False): | |||||
super(Scale, self).__init__() | |||||
self.factor_x = check_param_multi_types('factor_x', factor_x, [int, float]) | |||||
self.factor_y = check_param_multi_types('factor_y', factor_y, [int, float]) | |||||
auto_param = check_param_type('auto_param', auto_param, bool) | |||||
if auto_param: | |||||
self.factor_x = np.random.uniform(0.5, 1) | |||||
self.factor_y = np.random.uniform(0.5, 1) | |||||
def __call__(self, image): | |||||
""" | |||||
Transform the image. | |||||
Args: | |||||
image (numpy.ndarray): Original image to be transformed. | |||||
Returns: | |||||
numpy.ndarray, transformed image. | |||||
""" | |||||
ori_dtype = image.dtype | |||||
_, chw, normalized, gray3dim, image = self._check(image) | |||||
h, w = image.shape[:2] | |||||
matrix = np.array([[self.factor_x, 0, 0], [0, self.factor_y, 0]], dtype=np.float) | |||||
new_img = cv2.warpAffine(image, matrix, (w, h)) | |||||
new_img = self._original_format(new_img, chw, normalized, gray3dim) | |||||
return new_img.astype(ori_dtype) | |||||
class Shear(_NaturalPerturb): | |||||
""" | |||||
Shear an image, for each pixel (x, y) in the sheared image, the new value is taken from a position | |||||
(x+factor_x*y, factor_y*x+y) in the origin image. Then the sheared image will be rescaled to fit original size. | |||||
Args: | |||||
factor (Union[float, int]): Shear rate in shear direction. Suggested value range in [0.05, 0.5]. | |||||
direction (str): Direction of deformation. Optional value is 'vertical' or 'horizontal'. | |||||
auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image. | |||||
Example: | |||||
>>> img = cv2.imread('1.png') | |||||
>>> img = np.array(img) | |||||
>>> factor = 0.2 | |||||
>>> trans = Shear(factor, direction='horizontal') | |||||
>>> dst = trans(img) | |||||
""" | |||||
def __init__(self, factor=0.2, direction='horizontal', auto_param=False): | |||||
super(Shear, self).__init__() | |||||
self.factor = check_param_multi_types('factor', factor, [int, float]) | |||||
if direction not in ['horizontal', 'vertical']: | |||||
msg = "'direction must be in ['horizontal', 'vertical'], but got {}".format(direction) | |||||
raise ValueError(msg) | |||||
self.direction = direction | |||||
auto_param = check_param_type('auto_params', auto_param, bool) | |||||
if auto_param: | |||||
self.factor = np.random.uniform(0.05, 0.5) | |||||
self.direction = np.random.choice(['horizontal', 'vertical']) | |||||
def __call__(self, image): | |||||
""" | |||||
Transform the image. | |||||
Args: | |||||
image (numpy.ndarray): Original image to be transformed. | |||||
Returns: | |||||
numpy.ndarray, transformed image. | |||||
""" | |||||
ori_dtype = image.dtype | |||||
_, chw, normalized, gray3dim, image = self._check(image) | |||||
h, w = image.shape[:2] | |||||
if self.direction == 'horizontal': | |||||
matrix = np.array([[1, self.factor, 0], [0, 1, 0]], dtype=np.float) | |||||
nw = int(w + self.factor * h) | |||||
nh = h | |||||
else: | |||||
matrix = np.array([[1, 0, 0], [self.factor, 1, 0]], dtype=np.float) | |||||
nw = w | |||||
nh = int(h + self.factor * w) | |||||
new_img = cv2.warpAffine(image, matrix, (nw, nh)) | |||||
new_img = cv2.resize(new_img, (w, h)) | |||||
new_img = self._original_format(new_img, chw, normalized, gray3dim) | |||||
return new_img.astype(ori_dtype) | |||||
class Rotate(_NaturalPerturb): | |||||
""" | |||||
Rotate an image of counter clockwise around its center. | |||||
Args: | |||||
angle (Union[float, int]): Degrees of counter clockwise. Suggested value range in [-60, 60]. | |||||
auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image. | |||||
Example: | |||||
>>> img = cv2.imread('1.png') | |||||
>>> img = np.array(img) | |||||
>>> angle = 20 | |||||
>>> trans = Rotate(angle) | |||||
>>> dst = trans(img) | |||||
""" | |||||
def __init__(self, angle=20, auto_param=False): | |||||
super(Rotate, self).__init__() | |||||
self.angle = check_param_multi_types('angle', angle, [int, float]) | |||||
auto_param = check_param_type('auto_param', auto_param, bool) | |||||
if auto_param: | |||||
self.angle = np.random.uniform(0, 360) | |||||
def __call__(self, image): | |||||
""" | |||||
Transform the image. | |||||
Args: | |||||
image (numpy.ndarray): Original image to be transformed. | |||||
Returns: | |||||
numpy.ndarray, rotated image. | |||||
""" | |||||
ori_dtype = image.dtype | |||||
_, chw, normalized, gray3dim, image = self._check(image) | |||||
h, w = image.shape[:2] | |||||
center = (w // 2, h // 2) | |||||
matrix = cv2.getRotationMatrix2D(center, -self.angle, 1.0) | |||||
cos = np.abs(matrix[0, 0]) | |||||
sin = np.abs(matrix[0, 1]) | |||||
# Calculate new edge after rotated | |||||
nw = int((h * sin) + (w * cos)) | |||||
nh = int((h * cos) + (w * sin)) | |||||
# Adjust move distance of rotate matrix. | |||||
matrix[0, 2] += (nw / 2) - center[0] | |||||
matrix[1, 2] += (nh / 2) - center[1] | |||||
rotate = cv2.warpAffine(image, matrix, (nw, nh)) | |||||
rotate = cv2.resize(rotate, (w, h)) | |||||
new_img = self._original_format(rotate, chw, normalized, gray3dim) | |||||
return new_img.astype(ori_dtype) | |||||
class Perspective(_NaturalPerturb): | |||||
""" | |||||
Perform perspective transformation on a given picture. | |||||
Args: | |||||
ori_pos (list): Four points in original image. | |||||
dst_pos (list): The point coordinates of the 4 points in ori_pos after perspective transformation. | |||||
auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image. | |||||
Example: | |||||
>>> img = cv2.imread('1.png') | |||||
>>> img = np.array(img) | |||||
>>> ori_pos = [[0, 0], [0, 800], [800, 0], [800, 800]] | |||||
>>> dst_pos = [[50, 0], [0, 800], [780, 0], [800, 800]] | |||||
>>> trans = Perspective(ori_pos, dst_pos) | |||||
>>> dst = trans(img) | |||||
""" | |||||
def __init__(self, ori_pos, dst_pos, auto_param=False): | |||||
super(Perspective, self).__init__() | |||||
ori_pos = check_param_type('ori_pos', ori_pos, list) | |||||
dst_pos = check_param_type('dst_pos', dst_pos, list) | |||||
self.ori_pos = np.float32(ori_pos) | |||||
self.dst_pos = np.float32(dst_pos) | |||||
self.auto_param = check_param_type('auto_param', auto_param, bool) | |||||
def _set_auto_param(self, w, h): | |||||
self.ori_pos = [[h * 0.25, w * 0.25], [h * 0.25, w * 0.75], [h * 0.75, w * 0.25], [h * 0.75, w * 0.75]] | |||||
self.dst_pos = [[np.random.uniform(0, h * 0.5), np.random.uniform(0, w * 0.5)], | |||||
[np.random.uniform(0, h * 0.5), np.random.uniform(w * 0.5, w)], | |||||
[np.random.uniform(h * 0.5, h), np.random.uniform(0, w * 0.5)], | |||||
[np.random.uniform(h * 0.5, h), np.random.uniform(w * 0.5, w)]] | |||||
self.ori_pos = np.float32(self.ori_pos) | |||||
self.dst_pos = np.float32(self.dst_pos) | |||||
def __call__(self, image): | |||||
""" | |||||
Transform the image. | |||||
Args: | |||||
image (numpy.ndarray): Original image to be transformed. | |||||
Returns: | |||||
numpy.ndarray, transformed image. | |||||
""" | |||||
ori_dtype = image.dtype | |||||
_, chw, normalized, gray3dim, image = self._check(image) | |||||
h, w = image.shape[:2] | |||||
if self.auto_param: | |||||
self._set_auto_param(w, h) | |||||
matrix = cv2.getPerspectiveTransform(self.ori_pos, self.dst_pos) | |||||
new_img = cv2.warpPerspective(image, matrix, (w, h)) | |||||
new_img = self._original_format(new_img, chw, normalized, gray3dim) | |||||
return new_img.astype(ori_dtype) | |||||
class Curve(_NaturalPerturb): | |||||
""" | |||||
Curve picture using sin method. | |||||
Args: | |||||
curves (union[float, int]): Divide width to curves of `2*math.pi`, which means how many curve cycles. Suggested | |||||
value range in [0.1. 5]. | |||||
depth (union[float, int]): Amplitude of sin method. Suggested value not exceed 1/10 of the length of the | |||||
picture. | |||||
mode (str): Direction of deformation. Optional value is 'vertical' or 'horizontal'. | |||||
auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image. | |||||
Examples: | |||||
>>> img = cv2.imread('x.png') | |||||
>>> curves =1 | |||||
>>> depth = 10 | |||||
>>> trans = Curve(curves, depth, mode='vertical') | |||||
>>> img_new = trans(img) | |||||
""" | |||||
def __init__(self, curves=3, depth=10, mode='vertical', auto_param=False): | |||||
super(Curve).__init__() | |||||
self.curves = check_value_non_negative('curves', curves) | |||||
self.depth = check_value_non_negative('depth', depth) | |||||
if mode in ['vertical', 'horizontal']: | |||||
self.mode = mode | |||||
else: | |||||
msg = "Value of param mode must be in ['vertical', 'horizontal']" | |||||
LOGGER.error(TAG, msg) | |||||
raise ValueError(msg) | |||||
self.auto_param = check_param_type('auto_param', auto_param, bool) | |||||
def _set_auto_params(self, height, width): | |||||
if self.auto_param: | |||||
self.curves = np.random.uniform(1, 5) | |||||
self.mode = np.random.choice(['vertical', 'horizontal']) | |||||
if self.mode == 'vertical': | |||||
self.depth = np.random.uniform(1, 0.1 * width) | |||||
else: | |||||
self.depth = np.random.uniform(1, 0.1 * height) | |||||
def __call__(self, image): | |||||
""" | |||||
Curve picture using sin method. | |||||
Args: | |||||
image (numpy.ndarray): Original image. | |||||
Returns: | |||||
numpy.ndarray, curved image. | |||||
""" | |||||
ori_dtype = image.dtype | |||||
_, chw, normalized, gray3dim, image = self._check(image) | |||||
shape = image.shape | |||||
height, width = shape[:2] | |||||
if self.mode == 'vertical': | |||||
if len(shape) == 3: | |||||
image = np.transpose(image, [1, 0, 2]) | |||||
else: | |||||
image = np.transpose(image, [1, 0]) | |||||
src_x = np.zeros((height, width), np.float32) | |||||
src_y = np.zeros((height, width), np.float32) | |||||
for y in range(height): | |||||
for x in range(width): | |||||
src_x[y, x] = x | |||||
src_y[y, x] = y + self.depth * math.sin(x / (width / self.curves / 2 / math.pi)) | |||||
img_new = cv2.remap(image, src_x, src_y, cv2.INTER_LINEAR) | |||||
if self.mode == 'vertical': | |||||
if len(shape) == 3: | |||||
img_new = np.transpose(img_new, [1, 0, 2]) | |||||
else: | |||||
img_new = np.transpose(image, [1, 0]) | |||||
new_img = self._original_format(img_new, chw, normalized, gray3dim) | |||||
return new_img.astype(ori_dtype) |
@@ -1,4 +1,5 @@ | |||||
# Copyright 2021 Huawei Technologies Co., Ltd | # Copyright 2021 Huawei Technologies Co., Ltd | ||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
# you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
# You may obtain a copy of the License at | # You may obtain a copy of the License at | ||||
@@ -183,7 +183,7 @@ def check_pair_numpy_param(inputs_name, inputs, labels_name, labels): | |||||
def check_equal_length(para_name1, value1, para_name2, value2): | def check_equal_length(para_name1, value1, para_name2, value2): | ||||
"""Check weather the two parameters have equal length.""" | """Check weather the two parameters have equal length.""" | ||||
if len(value1) != len(value2): | if len(value1) != len(value2): | ||||
msg = 'The dimension of {0} must equal to the {1}, but got {0} is {2}, {1} is {3}'\ | |||||
msg = 'The dimension of {0} must equal to the {1}, but got {0} is {2}, {1} is {3}' \ | |||||
.format(para_name1, para_name2, len(value1), len(value2)) | .format(para_name1, para_name2, len(value1), len(value2)) | ||||
LOGGER.error(TAG, msg) | LOGGER.error(TAG, msg) | ||||
raise ValueError(msg) | raise ValueError(msg) | ||||
@@ -193,7 +193,7 @@ def check_equal_length(para_name1, value1, para_name2, value2): | |||||
def check_equal_shape(para_name1, value1, para_name2, value2): | def check_equal_shape(para_name1, value1, para_name2, value2): | ||||
"""Check weather the two parameters have equal shape.""" | """Check weather the two parameters have equal shape.""" | ||||
if value1.shape != value2.shape: | if value1.shape != value2.shape: | ||||
msg = 'The shape of {0} must equal to the {1}, but got {0} is {2}, {1} is {3}'.\ | |||||
msg = 'The shape of {0} must equal to the {1}, but got {0} is {2}, {1} is {3}'. \ | |||||
format(para_name1, para_name2, value1.shape, value2.shape) | format(para_name1, para_name2, value1.shape, value2.shape) | ||||
LOGGER.error(TAG, msg) | LOGGER.error(TAG, msg) | ||||
raise ValueError(msg) | raise ValueError(msg) | ||||
@@ -204,7 +204,7 @@ def check_norm_level(norm_level): | |||||
"""Check norm_level of regularization.""" | """Check norm_level of regularization.""" | ||||
if not isinstance(norm_level, (int, str)): | if not isinstance(norm_level, (int, str)): | ||||
msg = 'Type of norm_level must be in [int, str], but got {}'.format(type(norm_level)) | msg = 'Type of norm_level must be in [int, str], but got {}'.format(type(norm_level)) | ||||
accept_norm = [1, 2, '1', '2', 'l1', 'l2', 'inf', 'linf', np.inf] | |||||
accept_norm = [1, 2, '1', '2', 'l1', 'l2', 'inf', 'linf', 'np.inf', np.inf] | |||||
if norm_level not in accept_norm: | if norm_level not in accept_norm: | ||||
msg = 'norm_level must be in {}, but got {}'.format(accept_norm, norm_level) | msg = 'norm_level must be in {}, but got {}'.format(accept_norm, norm_level) | ||||
LOGGER.error(TAG, msg) | LOGGER.error(TAG, msg) | ||||
@@ -224,8 +224,7 @@ def normalize_value(value, norm_level): | |||||
numpy.ndarray, normalized value. | numpy.ndarray, normalized value. | ||||
Raises: | Raises: | ||||
NotImplementedError: If norm_level is not in [1, 2 , np.inf, '1', '2', | |||||
'inf', 'l1', 'l2'] | |||||
NotImplementedError: If norm_level is not in [1, 2 , np.inf, '1', '2', 'inf', 'l1', 'l2'] | |||||
""" | """ | ||||
norm_level = check_norm_level(norm_level) | norm_level = check_norm_level(norm_level) | ||||
ori_shape = value.shape | ori_shape = value.shape | ||||
@@ -237,7 +236,7 @@ def normalize_value(value, norm_level): | |||||
elif norm_level in (2, '2', 'l2'): | elif norm_level in (2, '2', 'l2'): | ||||
norm = np.linalg.norm(value_reshape, ord=2, axis=1, keepdims=True) + avoid_zero_div | norm = np.linalg.norm(value_reshape, ord=2, axis=1, keepdims=True) + avoid_zero_div | ||||
norm_value = value_reshape / norm | norm_value = value_reshape / norm | ||||
elif norm_level in (np.inf, 'inf'): | |||||
elif norm_level in (np.inf, 'inf', 'np.inf', 'linf'): | |||||
norm = np.max(abs(value_reshape), axis=1, keepdims=True) + avoid_zero_div | norm = np.max(abs(value_reshape), axis=1, keepdims=True) + avoid_zero_div | ||||
norm_value = value_reshape / norm | norm_value = value_reshape / norm | ||||
else: | else: | ||||
@@ -75,11 +75,11 @@ def test_lenet_mnist_coverage_cpu(): | |||||
model = Model(net) | model = Model(net) | ||||
# initialize fuzz test with training dataset | # initialize fuzz test with training dataset | ||||
training_data = (np.random.random((10000, 10))*20).astype(np.float32) | |||||
training_data = (np.random.random((10000, 10)) * 20).astype(np.float32) | |||||
# fuzz test with original test data | # fuzz test with original test data | ||||
# get test data | # get test data | ||||
test_data = (np.random.random((2000, 10))*20).astype(np.float32) | |||||
test_data = (np.random.random((2000, 10)) * 20).astype(np.float32) | |||||
test_labels = np.random.randint(0, 10, 2000).astype(np.int32) | test_labels = np.random.randint(0, 10, 2000).astype(np.int32) | ||||
nc = NeuronCoverage(model, threshold=0.1) | nc = NeuronCoverage(model, threshold=0.1) | ||||
@@ -118,6 +118,7 @@ def test_lenet_mnist_coverage_cpu(): | |||||
print('NC of adv data is: ', nc_metric) | print('NC of adv data is: ', nc_metric) | ||||
print('TKNC of adv data is: ', tknc_metrics) | print('TKNC of adv data is: ', tknc_metrics) | ||||
@pytest.mark.level0 | @pytest.mark.level0 | ||||
@pytest.mark.platform_arm_ascend_training | @pytest.mark.platform_arm_ascend_training | ||||
@pytest.mark.platform_x86_ascend_training | @pytest.mark.platform_x86_ascend_training | ||||
@@ -130,11 +131,11 @@ def test_lenet_mnist_coverage_ascend(): | |||||
model = Model(net) | model = Model(net) | ||||
# initialize fuzz test with training dataset | # initialize fuzz test with training dataset | ||||
training_data = (np.random.random((10000, 10))*20).astype(np.float32) | |||||
training_data = (np.random.random((10000, 10)) * 20).astype(np.float32) | |||||
# fuzz test with original test data | # fuzz test with original test data | ||||
# get test data | # get test data | ||||
test_data = (np.random.random((2000, 10))*20).astype(np.float32) | |||||
test_data = (np.random.random((2000, 10)) * 20).astype(np.float32) | |||||
nc = NeuronCoverage(model, threshold=0.1) | nc = NeuronCoverage(model, threshold=0.1) | ||||
nc_metric = nc.get_metrics(test_data) | nc_metric = nc.get_metrics(test_data) | ||||
@@ -99,15 +99,17 @@ def test_fuzzing_ascend(): | |||||
model = Model(net) | model = Model(net) | ||||
batch_size = 8 | batch_size = 8 | ||||
num_classe = 10 | num_classe = 10 | ||||
mutate_config = [{'method': 'Blur', | |||||
'params': {'auto_param': [True]}}, | |||||
mutate_config = [{'method': 'GaussianBlur', | |||||
'params': {'ksize': [1, 2, 3, 5], | |||||
'auto_param': [True, False]}}, | |||||
{'method': 'UniformNoise', | |||||
'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}}, | |||||
{'method': 'Contrast', | {'method': 'Contrast', | ||||
'params': {'factor': [2, 1]}}, | |||||
{'method': 'Translate', | |||||
'params': {'x_bias': [0.1, 0.3], 'y_bias': [0.2]}}, | |||||
'params': {'alpha': [0.5, 1, 1.5], 'beta': [-10, 0, 10], 'auto_param': [False, True]}}, | |||||
{'method': 'Rotate', | |||||
'params': {'angle': [20, 90], 'auto_param': [False, True]}}, | |||||
{'method': 'FGSM', | {'method': 'FGSM', | ||||
'params': {'eps': [0.1, 0.2, 0.3], 'alpha': [0.1]}} | |||||
] | |||||
'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1], 'bounds': [(0, 1)]}}] | |||||
train_images = np.random.rand(32, 1, 32, 32).astype(np.float32) | train_images = np.random.rand(32, 1, 32, 32).astype(np.float32) | ||||
# fuzz test with original test data | # fuzz test with original test data | ||||
@@ -142,15 +144,17 @@ def test_fuzzing_cpu(): | |||||
model = Model(net) | model = Model(net) | ||||
batch_size = 8 | batch_size = 8 | ||||
num_classe = 10 | num_classe = 10 | ||||
mutate_config = [{'method': 'Blur', | |||||
'params': {'auto_param': [True]}}, | |||||
mutate_config = [{'method': 'GaussianBlur', | |||||
'params': {'ksize': [1, 2, 3, 5], | |||||
'auto_param': [True, False]}}, | |||||
{'method': 'UniformNoise', | |||||
'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}}, | |||||
{'method': 'Contrast', | {'method': 'Contrast', | ||||
'params': {'factor': [2, 1]}}, | |||||
{'method': 'Translate', | |||||
'params': {'x_bias': [0.1, 0.3], 'y_bias': [0.2]}}, | |||||
'params': {'alpha': [0.5, 1, 1.5], 'beta': [-10, 0, 10], 'auto_param': [False, True]}}, | |||||
{'method': 'Rotate', | |||||
'params': {'angle': [20, 90], 'auto_param': [False, True]}}, | |||||
{'method': 'FGSM', | {'method': 'FGSM', | ||||
'params': {'eps': [0.1, 0.2, 0.3], 'alpha': [0.1]}} | |||||
] | |||||
'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1], 'bounds': [(0, 1)]}}] | |||||
# initialize fuzz test with training dataset | # initialize fuzz test with training dataset | ||||
train_images = np.random.rand(32, 1, 32, 32).astype(np.float32) | train_images = np.random.rand(32, 1, 32, 32).astype(np.float32) | ||||
@@ -1,126 +0,0 @@ | |||||
# Copyright 2019 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
""" | |||||
Image transform test. | |||||
""" | |||||
import numpy as np | |||||
import pytest | |||||
from mindarmour.utils.logger import LogUtil | |||||
from mindarmour.fuzz_testing.image_transform import Contrast, Brightness, \ | |||||
Blur, Noise, Translate, Scale, Shear, Rotate | |||||
LOGGER = LogUtil.get_instance() | |||||
TAG = 'Image transform test' | |||||
LOGGER.set_level('INFO') | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_x86_cpu | |||||
@pytest.mark.env_onecard | |||||
@pytest.mark.component_mindarmour | |||||
def test_contrast(): | |||||
image = (np.random.rand(32, 32)).astype(np.float32) | |||||
trans = Contrast() | |||||
trans.set_params(auto_param=True) | |||||
_ = trans.transform(image) | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_x86_cpu | |||||
@pytest.mark.env_onecard | |||||
@pytest.mark.component_mindarmour | |||||
def test_brightness(): | |||||
image = (np.random.rand(32, 32)).astype(np.float32) | |||||
trans = Brightness() | |||||
trans.set_params(auto_param=True) | |||||
_ = trans.transform(image) | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_x86_cpu | |||||
@pytest.mark.platform_x86_ascend_training | |||||
@pytest.mark.platform_arm_ascend_training | |||||
@pytest.mark.env_onecard | |||||
@pytest.mark.component_mindarmour | |||||
def test_blur(): | |||||
image = (np.random.rand(32, 32)).astype(np.float32) | |||||
trans = Blur() | |||||
trans.set_params(auto_param=True) | |||||
_ = trans.transform(image) | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_x86_cpu | |||||
@pytest.mark.platform_x86_ascend_training | |||||
@pytest.mark.platform_arm_ascend_training | |||||
@pytest.mark.env_onecard | |||||
@pytest.mark.component_mindarmour | |||||
def test_noise(): | |||||
image = (np.random.rand(32, 32)).astype(np.float32) | |||||
trans = Noise() | |||||
trans.set_params(auto_param=True) | |||||
_ = trans.transform(image) | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_x86_cpu | |||||
@pytest.mark.platform_x86_ascend_training | |||||
@pytest.mark.platform_arm_ascend_training | |||||
@pytest.mark.env_onecard | |||||
@pytest.mark.component_mindarmour | |||||
def test_translate(): | |||||
image = (np.random.rand(32, 32)).astype(np.float32) | |||||
trans = Translate() | |||||
trans.set_params(auto_param=True) | |||||
_ = trans.transform(image) | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_x86_cpu | |||||
@pytest.mark.platform_x86_ascend_training | |||||
@pytest.mark.platform_arm_ascend_training | |||||
@pytest.mark.env_onecard | |||||
@pytest.mark.component_mindarmour | |||||
def test_shear(): | |||||
image = (np.random.rand(32, 32)).astype(np.float32) | |||||
trans = Shear() | |||||
trans.set_params(auto_param=True) | |||||
_ = trans.transform(image) | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_x86_cpu | |||||
@pytest.mark.platform_x86_ascend_training | |||||
@pytest.mark.platform_arm_ascend_training | |||||
@pytest.mark.env_onecard | |||||
@pytest.mark.component_mindarmour | |||||
def test_scale(): | |||||
image = (np.random.rand(32, 32)).astype(np.float32) | |||||
trans = Scale() | |||||
trans.set_params(auto_param=True) | |||||
_ = trans.transform(image) | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_x86_cpu | |||||
@pytest.mark.platform_x86_ascend_training | |||||
@pytest.mark.platform_arm_ascend_training | |||||
@pytest.mark.env_onecard | |||||
@pytest.mark.component_mindarmour | |||||
def test_rotate(): | |||||
image = (np.random.rand(32, 32)).astype(np.float32) | |||||
trans = Rotate() | |||||
trans.set_params(auto_param=True) | |||||
_ = trans.transform(image) |
@@ -0,0 +1,577 @@ | |||||
# Copyright 2022 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
"""Example for natural robustness methods.""" | |||||
import pytest | |||||
import numpy as np | |||||
from mindspore import context | |||||
from mindarmour.natural_robustness import Translate, Curve, Perspective, Scale, Shear, Rotate, SaltAndPepperNoise, \ | |||||
NaturalNoise, GaussianNoise, UniformNoise, MotionBlur, GaussianBlur, GradientBlur, Contrast, GradientLuminance | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_x86_cpu | |||||
@pytest.mark.env_card | |||||
@pytest.mark.component_mindarmour | |||||
def test_perspective(): | |||||
""" | |||||
Feature: Test image perspective. | |||||
Description: Image will be transform for given perspective projection. | |||||
Expectation: success. | |||||
""" | |||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||||
image = np.random.random((32, 32, 3)) | |||||
ori_pos = [[0, 0], [0, 800], [800, 0], [800, 800]] | |||||
dst_pos = [[50, 0], [0, 800], [780, 0], [800, 800]] | |||||
trans = Perspective(ori_pos, dst_pos) | |||||
dst = trans(image) | |||||
print(dst) | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_x86_cpu | |||||
@pytest.mark.env_card | |||||
@pytest.mark.component_mindarmour | |||||
def test_uniform_noise(): | |||||
""" | |||||
Feature: Test image uniform noise. | |||||
Description: Add uniform image in image. | |||||
Expectation: success. | |||||
""" | |||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||||
image = np.random.random((32, 32, 3)) | |||||
trans = UniformNoise(factor=0.1) | |||||
dst = trans(image) | |||||
print(dst) | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_x86_cpu | |||||
@pytest.mark.env_card | |||||
@pytest.mark.component_mindarmour | |||||
def test_gaussian_noise(): | |||||
""" | |||||
Feature: Test image gaussian noise. | |||||
Description: Add gaussian image in image. | |||||
Expectation: success. | |||||
""" | |||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||||
image = np.random.random((32, 32, 3)) | |||||
trans = GaussianNoise(factor=0.1) | |||||
dst = trans(image) | |||||
print(dst) | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_x86_cpu | |||||
@pytest.mark.env_card | |||||
@pytest.mark.component_mindarmour | |||||
def test_contrast(): | |||||
""" | |||||
Feature: Test image contrast. | |||||
Description: Adjust image contrast. | |||||
Expectation: success. | |||||
""" | |||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||||
image = np.random.random((32, 32, 3)) | |||||
trans = Contrast(alpha=0.3, beta=0) | |||||
dst = trans(image) | |||||
print(dst) | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_x86_cpu | |||||
@pytest.mark.env_card | |||||
@pytest.mark.component_mindarmour | |||||
def test_gaussian_blur(): | |||||
""" | |||||
Feature: Test image gaussian blur. | |||||
Description: Add gaussian blur to image. | |||||
Expectation: success. | |||||
""" | |||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||||
image = np.random.random((32, 32, 3)) | |||||
trans = GaussianBlur(ksize=5) | |||||
dst = trans(image) | |||||
print(dst) | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_x86_cpu | |||||
@pytest.mark.env_card | |||||
@pytest.mark.component_mindarmour | |||||
def test_salt_and_pepper_noise(): | |||||
""" | |||||
Feature: Test image salt and pepper noise. | |||||
Description: Add salt and pepper to image. | |||||
Expectation: success. | |||||
""" | |||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||||
image = np.random.random((32, 32, 3)) | |||||
trans = SaltAndPepperNoise(factor=0.01) | |||||
dst = trans(image) | |||||
print(dst) | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_x86_cpu | |||||
@pytest.mark.env_card | |||||
@pytest.mark.component_mindarmour | |||||
def test_translate(): | |||||
""" | |||||
Feature: Test image translate. | |||||
Description: Translate an image. | |||||
Expectation: success. | |||||
""" | |||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||||
image = np.random.random((32, 32, 3)) | |||||
trans = Translate(x_bias=0.1, y_bias=0.1) | |||||
dst = trans(image) | |||||
print(dst) | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_x86_cpu | |||||
@pytest.mark.env_card | |||||
@pytest.mark.component_mindarmour | |||||
def test_scale(): | |||||
""" | |||||
Feature: Test image scale. | |||||
Description: Scale an image. | |||||
Expectation: success. | |||||
""" | |||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||||
image = np.random.random((32, 32, 3)) | |||||
trans = Scale(factor_x=0.7, factor_y=0.7) | |||||
dst = trans(image) | |||||
print(dst) | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_x86_cpu | |||||
@pytest.mark.env_card | |||||
@pytest.mark.component_mindarmour | |||||
def test_shear(): | |||||
""" | |||||
Feature: Test image shear. | |||||
Description: Shear an image. | |||||
Expectation: success. | |||||
""" | |||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||||
image = np.random.random((32, 32, 3)) | |||||
trans = Shear(factor=0.2) | |||||
dst = trans(image) | |||||
print(dst) | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_x86_cpu | |||||
@pytest.mark.env_card | |||||
@pytest.mark.component_mindarmour | |||||
def test_rotate(): | |||||
""" | |||||
Feature: Test image rotate. | |||||
Description: Rotate an image. | |||||
Expectation: success. | |||||
""" | |||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||||
image = np.random.random((32, 32, 3)) | |||||
trans = Rotate(angle=20) | |||||
dst = trans(image) | |||||
print(dst) | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_x86_cpu | |||||
@pytest.mark.env_card | |||||
@pytest.mark.component_mindarmour | |||||
def test_curve(): | |||||
""" | |||||
Feature: Test image curve. | |||||
Description: Transform an image with curve. | |||||
Expectation: success. | |||||
""" | |||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||||
image = np.random.random((32, 32, 3)) | |||||
trans = Curve(curves=1.5, depth=1.5, mode='horizontal') | |||||
dst = trans(image) | |||||
print(dst) | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_x86_cpu | |||||
@pytest.mark.env_card | |||||
@pytest.mark.component_mindarmour | |||||
def test_natural_noise(): | |||||
""" | |||||
Feature: Test natural noise. | |||||
Description: Add natural noise to an. | |||||
Expectation: success. | |||||
""" | |||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||||
image = np.random.random((32, 32, 3)) | |||||
trans = NaturalNoise(ratio=0.0001, k_x_range=(1, 30), k_y_range=(1, 10), auto_param=True) | |||||
dst = trans(image) | |||||
print(dst) | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_x86_cpu | |||||
@pytest.mark.env_card | |||||
@pytest.mark.component_mindarmour | |||||
def test_gradient_luminance(): | |||||
""" | |||||
Feature: Test gradient luminance. | |||||
Description: Adjust image luminance. | |||||
Expectation: success. | |||||
""" | |||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||||
image = np.random.random((32, 32, 3)) | |||||
height, width = image.shape[:2] | |||||
point = (height // 4, width // 2) | |||||
start = (255, 255, 255) | |||||
end = (0, 0, 0) | |||||
scope = 0.3 | |||||
bright_rate = 0.4 | |||||
trans = GradientLuminance(start, end, start_point=point, scope=scope, pattern='dark', bright_rate=bright_rate, | |||||
mode='horizontal') | |||||
dst = trans(image) | |||||
print(dst) | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_x86_cpu | |||||
@pytest.mark.env_card | |||||
@pytest.mark.component_mindarmour | |||||
def test_motion_blur(): | |||||
""" | |||||
Feature: Test motion blur. | |||||
Description: Add motion blur to an image. | |||||
Expectation: success. | |||||
""" | |||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||||
image = np.random.random((32, 32, 3)) | |||||
angle = -10.5 | |||||
i = 3 | |||||
trans = MotionBlur(degree=i, angle=angle) | |||||
dst = trans(image) | |||||
print(dst) | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_x86_cpu | |||||
@pytest.mark.env_card | |||||
@pytest.mark.component_mindarmour | |||||
def test_gradient_blur(): | |||||
""" | |||||
Feature: Test gradient blur. | |||||
Description: Add gradient blur to an image. | |||||
Expectation: success. | |||||
""" | |||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||||
image = np.random.random((32, 32, 3)) | |||||
number = 10 | |||||
h, w = image.shape[:2] | |||||
point = (int(h / 5), int(w / 5)) | |||||
center = False | |||||
trans = GradientBlur(point, number, center) | |||||
dst = trans(image) | |||||
print(dst) | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_arm_ascend_training | |||||
@pytest.mark.platform_x86_ascend_training | |||||
@pytest.mark.env_card | |||||
@pytest.mark.component_mindarmour | |||||
def test_perspective_ascend(): | |||||
""" | |||||
Feature: Test image perspective. | |||||
Description: Image will be transform for given perspective projection. | |||||
Expectation: success. | |||||
""" | |||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
image = np.random.random((32, 32, 3)) | |||||
ori_pos = [[0, 0], [0, 800], [800, 0], [800, 800]] | |||||
dst_pos = [[50, 0], [0, 800], [780, 0], [800, 800]] | |||||
trans = Perspective(ori_pos, dst_pos) | |||||
dst = trans(image) | |||||
print(dst) | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_arm_ascend_training | |||||
@pytest.mark.platform_x86_ascend_training | |||||
@pytest.mark.env_card | |||||
@pytest.mark.component_mindarmour | |||||
def test_uniform_noise_ascend(): | |||||
""" | |||||
Feature: Test image uniform noise. | |||||
Description: Add uniform image in image. | |||||
Expectation: success. | |||||
""" | |||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
image = np.random.random((32, 32, 3)) | |||||
trans = UniformNoise(factor=0.1) | |||||
dst = trans(image) | |||||
print(dst) | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_arm_ascend_training | |||||
@pytest.mark.platform_x86_ascend_training | |||||
@pytest.mark.env_card | |||||
@pytest.mark.component_mindarmour | |||||
def test_gaussian_noise_ascend(): | |||||
""" | |||||
Feature: Test image gaussian noise. | |||||
Description: Add gaussian image in image. | |||||
Expectation: success. | |||||
""" | |||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
image = np.random.random((32, 32, 3)) | |||||
trans = GaussianNoise(factor=0.1) | |||||
dst = trans(image) | |||||
print(dst) | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_arm_ascend_training | |||||
@pytest.mark.platform_x86_ascend_training | |||||
@pytest.mark.env_card | |||||
@pytest.mark.component_mindarmour | |||||
def test_contrast_ascend(): | |||||
""" | |||||
Feature: Test image contrast. | |||||
Description: Adjust image contrast. | |||||
Expectation: success. | |||||
""" | |||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
image = np.random.random((32, 32, 3)) | |||||
trans = Contrast(alpha=0.3, beta=0) | |||||
dst = trans(image) | |||||
print(dst) | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_arm_ascend_training | |||||
@pytest.mark.platform_x86_ascend_training | |||||
@pytest.mark.env_card | |||||
@pytest.mark.component_mindarmour | |||||
def test_gaussian_blur_ascend(): | |||||
""" | |||||
Feature: Test image gaussian blur. | |||||
Description: Add gaussian blur to image. | |||||
Expectation: success. | |||||
""" | |||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
image = np.random.random((32, 32, 3)) | |||||
trans = GaussianBlur(ksize=5) | |||||
dst = trans(image) | |||||
print(dst) | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_arm_ascend_training | |||||
@pytest.mark.platform_x86_ascend_training | |||||
@pytest.mark.env_card | |||||
@pytest.mark.component_mindarmour | |||||
def test_salt_and_pepper_noise_ascend(): | |||||
""" | |||||
Feature: Test image salt and pepper noise. | |||||
Description: Add salt and pepper to image. | |||||
Expectation: success. | |||||
""" | |||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
image = np.random.random((32, 32, 3)) | |||||
trans = SaltAndPepperNoise(factor=0.01) | |||||
dst = trans(image) | |||||
print(dst) | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_arm_ascend_training | |||||
@pytest.mark.platform_x86_ascend_training | |||||
@pytest.mark.env_card | |||||
@pytest.mark.component_mindarmour | |||||
def test_translate_ascend(): | |||||
""" | |||||
Feature: Test image translate. | |||||
Description: Translate an image. | |||||
Expectation: success. | |||||
""" | |||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
image = np.random.random((32, 32, 3)) | |||||
trans = Translate(x_bias=0.1, y_bias=0.1) | |||||
dst = trans(image) | |||||
print(dst) | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_arm_ascend_training | |||||
@pytest.mark.platform_x86_ascend_training | |||||
@pytest.mark.env_card | |||||
@pytest.mark.component_ascend_mindarmour | |||||
def test_scale_ascend(): | |||||
""" | |||||
Feature: Test image scale. | |||||
Description: Scale an image. | |||||
Expectation: success. | |||||
""" | |||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
image = np.random.random((32, 32, 3)) | |||||
trans = Scale(factor_x=0.7, factor_y=0.7) | |||||
dst = trans(image) | |||||
print(dst) | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_arm_ascend_training | |||||
@pytest.mark.platform_x86_ascend_training | |||||
@pytest.mark.env_card | |||||
@pytest.mark.component_mindarmour | |||||
def test_shear_ascend(): | |||||
""" | |||||
Feature: Test image shear. | |||||
Description: Shear an image. | |||||
Expectation: success. | |||||
""" | |||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
image = np.random.random((32, 32, 3)) | |||||
trans = Shear(factor=0.2) | |||||
dst = trans(image) | |||||
print(dst) | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_arm_ascend_training | |||||
@pytest.mark.platform_x86_ascend_training | |||||
@pytest.mark.env_card | |||||
@pytest.mark.component_mindarmour | |||||
def test_rotate_ascend(): | |||||
""" | |||||
Feature: Test image rotate. | |||||
Description: Rotate an image. | |||||
Expectation: success. | |||||
""" | |||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
image = np.random.random((32, 32, 3)) | |||||
trans = Rotate(angle=20) | |||||
dst = trans(image) | |||||
print(dst) | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_arm_ascend_training | |||||
@pytest.mark.platform_x86_ascend_training | |||||
@pytest.mark.env_card | |||||
@pytest.mark.component_mindarmour | |||||
def test_curve_ascend(): | |||||
""" | |||||
Feature: Test image curve. | |||||
Description: Transform an image with curve. | |||||
Expectation: success. | |||||
""" | |||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
image = np.random.random((32, 32, 3)) | |||||
trans = Curve(curves=1.5, depth=1.5, mode='horizontal') | |||||
dst = trans(image) | |||||
print(dst) | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_arm_ascend_training | |||||
@pytest.mark.platform_x86_ascend_training | |||||
@pytest.mark.env_card | |||||
@pytest.mark.component_mindarmour | |||||
def test_natural_noise_ascend(): | |||||
""" | |||||
Feature: Test natural noise. | |||||
Description: Add natural noise to an. | |||||
Expectation: success. | |||||
""" | |||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
image = np.random.random((32, 32, 3)) | |||||
trans = NaturalNoise(ratio=0.0001, k_x_range=(1, 30), k_y_range=(1, 10), auto_param=True) | |||||
dst = trans(image) | |||||
print(dst) | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_arm_ascend_training | |||||
@pytest.mark.platform_x86_ascend_training | |||||
@pytest.mark.env_card | |||||
@pytest.mark.component_mindarmour | |||||
def test_gradient_luminance_ascend(): | |||||
""" | |||||
Feature: Test gradient luminance. | |||||
Description: Adjust image luminance. | |||||
Expectation: success. | |||||
""" | |||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
image = np.random.random((32, 32, 3)) | |||||
height, width = image.shape[:2] | |||||
point = (height // 4, width // 2) | |||||
start = (255, 255, 255) | |||||
end = (0, 0, 0) | |||||
scope = 0.3 | |||||
bright_rate = 0.4 | |||||
trans = GradientLuminance(start, end, start_point=point, scope=scope, pattern='dark', bright_rate=bright_rate, | |||||
mode='horizontal') | |||||
dst = trans(image) | |||||
print(dst) | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_arm_ascend_training | |||||
@pytest.mark.platform_x86_ascend_training | |||||
@pytest.mark.env_card | |||||
@pytest.mark.component_mindarmour | |||||
def test_motion_blur_ascend(): | |||||
""" | |||||
Feature: Test motion blur. | |||||
Description: Add motion blur to an image. | |||||
Expectation: success. | |||||
""" | |||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
image = np.random.random((32, 32, 3)) | |||||
angle = -10.5 | |||||
i = 3 | |||||
trans = MotionBlur(degree=i, angle=angle) | |||||
dst = trans(image) | |||||
print(dst) | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_arm_ascend_training | |||||
@pytest.mark.platform_x86_ascend_training | |||||
@pytest.mark.env_card | |||||
@pytest.mark.component_mindarmour | |||||
def test_gradient_blur_ascend(): | |||||
""" | |||||
Feature: Test gradient blur. | |||||
Description: Add gradient blur to an image. | |||||
Expectation: success. | |||||
""" | |||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
image = np.random.random((32, 32, 3)) | |||||
number = 10 | |||||
h, w = image.shape[:2] | |||||
point = (int(h / 5), int(w / 5)) | |||||
center = False | |||||
trans = GradientBlur(point, number, center) | |||||
dst = trans(image) | |||||
print(dst) |