diff --git a/.gitignore b/.gitignore index ef64633..f255ef9 100644 --- a/.gitignore +++ b/.gitignore @@ -19,7 +19,6 @@ example/mnist_demo/model/ example/cifar_demo/model/ example/dog_cat_demo/model/ mindarmour.egg-info/ -*model/ *MNIST/ *out.data/ *defensed_model/ diff --git a/examples/ai_fuzzer/fuzz_testing_and_model_enhense.py b/examples/ai_fuzzer/fuzz_testing_and_model_enhense.py index 84b3e30..5f21bf6 100644 --- a/examples/ai_fuzzer/fuzz_testing_and_model_enhense.py +++ b/examples/ai_fuzzer/fuzz_testing_and_model_enhense.py @@ -27,7 +27,7 @@ from mindspore.nn.optim.momentum import Momentum from mindarmour.adv_robustness.defenses import AdversarialDefense from mindarmour.fuzz_testing import Fuzzer -from mindarmour.fuzz_testing import ModelCoverageMetrics +from mindarmour.fuzz_testing import KMultisectionNeuronCoverage from mindarmour.utils.logger import LogUtil from examples.common.dataset.data_processing import generate_mnist_dataset @@ -38,33 +38,66 @@ TAG = 'Fuzz_testing and enhance model' LOGGER.set_level('INFO') +def split_dataset(image, label, proportion): + """ + Split the generated fuzz data into train and test set. + """ + indices = np.arange(len(image)) + random.shuffle(indices) + train_length = int(len(image) * proportion) + train_image = [image[i] for i in indices[:train_length]] + train_label = [label[i] for i in indices[:train_length]] + test_image = [image[i] for i in indices[:train_length]] + test_label = [label[i] for i in indices[:train_length]] + return train_image, train_label, test_image, test_label + + def example_lenet_mnist_fuzzing(): """ An example of fuzz testing and then enhance the non-robustness model. """ # upload trained network - ckpt_path = '../common/networks/lenet5/trained_ckpt_file/lenet_m1-10_1250.ckpt' + ckpt_path = '../common/networks/lenet5/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' net = LeNet5() load_dict = load_checkpoint(ckpt_path) load_param_into_net(net, load_dict) model = Model(net) - mutate_config = [{'method': 'Blur', - 'params': {'auto_param': [True]}}, - {'method': 'Contrast', - 'params': {'auto_param': [True]}}, - {'method': 'Translate', - 'params': {'auto_param': [True]}}, - {'method': 'Brightness', - 'params': {'auto_param': [True]}}, - {'method': 'Noise', - 'params': {'auto_param': [True]}}, - {'method': 'Scale', - 'params': {'auto_param': [True]}}, - {'method': 'Shear', - 'params': {'auto_param': [True]}}, - {'method': 'FGSM', - 'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1]}} - ] + mutate_config = [ + {'method': 'GaussianBlur', + 'params': {'ksize': [1, 2, 3, 5], 'auto_param': [True, False]}}, + {'method': 'MotionBlur', + 'params': {'degree': [1, 2, 5], 'angle': [45, 10, 100, 140, 210, 270, 300], 'auto_param': [True]}}, + {'method': 'GradientBlur', + 'params': {'point': [[10, 10]], 'auto_param': [True]}}, + {'method': 'UniformNoise', + 'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}}, + {'method': 'GaussianNoise', + 'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}}, + {'method': 'SaltAndPepperNoise', + 'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}}, + {'method': 'NaturalNoise', + 'params': {'ratio': [0.1], 'k_x_range': [(1, 3), (1, 5)], 'k_y_range': [(1, 5)], 'auto_param': [False, True]}}, + {'method': 'Contrast', + 'params': {'alpha': [0.5, 1, 1.5], 'beta': [-10, 0, 10], 'auto_param': [False, True]}}, + {'method': 'GradientLuminance', + 'params': {'color_start': [(0, 0, 0)], 'color_end': [(255, 255, 255)], 'start_point': [(10, 10)], + 'scope': [0.5], 'pattern': ['light'], 'bright_rate': [0.3], 'mode': ['circle'], + 'auto_param': [False, True]}}, + {'method': 'Translate', + 'params': {'x_bias': [0, 0.05, -0.05], 'y_bias': [0, -0.05, 0.05], 'auto_param': [False, True]}}, + {'method': 'Scale', + 'params': {'factor_x': [1, 0.9], 'factor_y': [1, 0.9], 'auto_param': [False, True]}}, + {'method': 'Shear', + 'params': {'factor': [0.2, 0.1], 'direction': ['horizontal', 'vertical'], 'auto_param': [False, True]}}, + {'method': 'Rotate', + 'params': {'angle': [20, 90], 'auto_param': [False, True]}}, + {'method': 'Perspective', + 'params': {'ori_pos': [[[0, 0], [0, 800], [800, 0], [800, 800]]], + 'dst_pos': [[[50, 0], [0, 800], [780, 0], [800, 800]]], 'auto_param': [False, True]}}, + {'method': 'Curve', + 'params': {'curves': [5], 'depth': [2], 'mode': ['vertical'], 'auto_param': [False, True]}}, + {'method': 'FGSM', + 'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1], 'bounds': [(0, 1)]}}] # get training data data_list = "../common/dataset/MNIST/train" @@ -75,49 +108,36 @@ def example_lenet_mnist_fuzzing(): images = data[0].astype(np.float32) train_images.append(images) train_images = np.concatenate(train_images, axis=0) - neuron_num = 10 - segmented_num = 1000 - - # initialize fuzz test with training dataset - model_coverage_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images) + segmented_num = 100 # fuzz test with original test data - # get test data data_list = "../common/dataset/MNIST/test" - batch_size = 32 - init_samples = 5000 - max_iters = 50000 + batch_size = batch_size + init_samples = 50 + max_iters = 500 mutate_num_per_seed = 10 - ds = generate_mnist_dataset(data_list, batch_size, num_samples=init_samples, - sparse=False) + ds = generate_mnist_dataset(data_list, batch_size=batch_size, num_samples=init_samples, sparse=False) test_images = [] test_labels = [] for data in ds.create_tuple_iterator(output_numpy=True): - images = data[0].astype(np.float32) - labels = data[1] - test_images.append(images) - test_labels.append(labels) + test_images.append(data[0].astype(np.float32)) + test_labels.append(data[1]) test_images = np.concatenate(test_images, axis=0) test_labels = np.concatenate(test_labels, axis=0) - initial_seeds = [] + + coverage = KMultisectionNeuronCoverage(model, train_images, segmented_num=segmented_num, incremental=True) + kmnc = coverage.get_metrics(test_images[:100]) + print('kmnc: ', kmnc) # make initial seeds + initial_seeds = [] for img, label in zip(test_images, test_labels): initial_seeds.append([img, label]) - model_coverage_test.calculate_coverage( - np.array(test_images[:100]).astype(np.float32)) - LOGGER.info(TAG, 'KMNC of test dataset before fuzzing is : %s', - model_coverage_test.get_kmnc()) - LOGGER.info(TAG, 'NBC of test dataset before fuzzing is : %s', - model_coverage_test.get_nbc()) - LOGGER.info(TAG, 'SNAC of test dataset before fuzzing is : %s', - model_coverage_test.get_snac()) - - model_fuzz_test = Fuzzer(model, train_images, 10, 1000) + model_fuzz_test = Fuzzer(model) gen_samples, gt, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, - initial_seeds, - eval_metrics='auto', + initial_seeds, coverage, + evaluate=True, max_iters=max_iters, mutate_num_per_seed=mutate_num_per_seed) @@ -125,24 +145,10 @@ def example_lenet_mnist_fuzzing(): for key in metrics: LOGGER.info(TAG, key + ': %s', metrics[key]) - def split_dataset(image, label, proportion): - """ - Split the generated fuzz data into train and test set. - """ - indices = np.arange(len(image)) - random.shuffle(indices) - train_length = int(len(image) * proportion) - train_image = [image[i] for i in indices[:train_length]] - train_label = [label[i] for i in indices[:train_length]] - test_image = [image[i] for i in indices[:train_length]] - test_label = [label[i] for i in indices[:train_length]] - return train_image, train_label, test_image, test_label - - train_image, train_label, test_image, test_label = split_dataset( - gen_samples, gt, 0.7) + train_image, train_label, test_image, test_label = split_dataset(gen_samples, gt, 0.7) # load model B and test it on the test set - ckpt_path = '../common/networks/lenet5/trained_ckpt_file/lenet_m2-10_1250.ckpt' + ckpt_path = '../common/networks/lenet5/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' net = LeNet5() load_dict = load_checkpoint(ckpt_path) load_param_into_net(net, load_dict) @@ -154,12 +160,11 @@ def example_lenet_mnist_fuzzing(): # enhense model robustness lr = 0.001 momentum = 0.9 - loss_fn = SoftmaxCrossEntropyWithLogits(Sparse=True) + loss_fn = SoftmaxCrossEntropyWithLogits(sparse=True) optimizer = Momentum(net.trainable_params(), lr, momentum) adv_defense = AdversarialDefense(net, loss_fn, optimizer) - adv_defense.batch_defense(np.array(train_image).astype(np.float32), - np.argmax(train_label, axis=1).astype(np.int32)) + adv_defense.batch_defense(np.array(train_image).astype(np.float32), np.argmax(train_label, axis=1).astype(np.int32)) preds_en = net(Tensor(test_image, dtype=mindspore.float32)).asnumpy() acc_en = np.sum(np.argmax(preds_en, axis=1) == np.argmax(test_label, axis=1)) / len(test_label) print('Accuracy of enhensed model on test set is ', acc_en) @@ -167,5 +172,5 @@ def example_lenet_mnist_fuzzing(): if __name__ == '__main__': # device_target can be "CPU", "GPU" or "Ascend" - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") example_lenet_mnist_fuzzing() diff --git a/examples/ai_fuzzer/lenet5_mnist_fuzzing.py b/examples/ai_fuzzer/lenet5_mnist_fuzzing.py index 3a30b6d..53321c3 100644 --- a/examples/ai_fuzzer/lenet5_mnist_fuzzing.py +++ b/examples/ai_fuzzer/lenet5_mnist_fuzzing.py @@ -35,24 +35,50 @@ def test_lenet_mnist_fuzzing(): load_dict = load_checkpoint(ckpt_path) load_param_into_net(net, load_dict) model = Model(net) - mutate_config = [{'method': 'Blur', - 'params': {'radius': [0.1, 0.2, 0.3], - 'auto_param': [True, False]}}, - {'method': 'Contrast', - 'params': {'auto_param': [True]}}, - {'method': 'Translate', - 'params': {'auto_param': [True]}}, - {'method': 'Brightness', - 'params': {'auto_param': [True]}}, - {'method': 'Noise', - 'params': {'auto_param': [True]}}, - {'method': 'Scale', - 'params': {'auto_param': [True]}}, - {'method': 'Shear', - 'params': {'auto_param': [True]}}, - {'method': 'FGSM', - 'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1], 'bounds': [(0, 1)]}} - ] + mutate_config = [ + {'method': 'GaussianBlur', + 'params': {'ksize': [1, 2, 3, 5], + 'auto_param': [True, False]}}, + {'method': 'MotionBlur', + 'params': {'degree': [1, 2, 5], 'angle': [45, 10, 100, 140, 210, 270, 300], 'auto_param': [True]}}, + {'method': 'GradientBlur', + 'params': {'point': [[10, 10]], 'auto_param': [True]}}, + {'method': 'UniformNoise', + 'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}}, + {'method': 'GaussianNoise', + 'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}}, + {'method': 'SaltAndPepperNoise', + 'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}}, + {'method': 'NaturalNoise', + 'params': {'ratio': [0.1, 0.2, 0.3], 'k_x_range': [(1, 3), (1, 5)], 'k_y_range': [(1, 5)], + 'auto_param': [False, True]}}, + {'method': 'Contrast', + 'params': {'alpha': [0.5, 1, 1.5], 'beta': [-10, 0, 10], 'auto_param': [False, True]}}, + {'method': 'GradientLuminance', + 'params': {'color_start': [(0, 0, 0)], 'color_end': [(255, 255, 255)], 'start_point': [(10, 10)], + 'scope': [0.5], 'pattern': ['light'], 'bright_rate': [0.3], 'mode': ['circle'], + 'auto_param': [False, True]}}, + {'method': 'Translate', + 'params': {'x_bias': [0, 0.05, -0.05], 'y_bias': [0, -0.05, 0.05], 'auto_param': [False, True]}}, + {'method': 'Scale', + 'params': {'factor_x': [1, 0.9], 'factor_y': [1, 0.9], 'auto_param': [False, True]}}, + {'method': 'Shear', + 'params': {'factor': [0.2, 0.1], 'direction': ['horizontal', 'vertical'], 'auto_param': [False, True]}}, + {'method': 'Rotate', + 'params': {'angle': [20, 90], 'auto_param': [False, True]}}, + {'method': 'Perspective', + 'params': {'ori_pos': [[[0, 0], [0, 800], [800, 0], [800, 800]]], + 'dst_pos': [[[50, 0], [0, 800], [780, 0], [800, 800]]], 'auto_param': [False, True]}}, + {'method': 'Curve', + 'params': {'curves': [5], 'depth': [2], 'mode': ['vertical'], 'auto_param': [False, True]}}, + {'method': 'FGSM', + 'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1], 'bounds': [(0, 1)]}}, + {'method': 'PGD', + 'params': {'eps': [0.1, 0.2, 0.4], 'eps_iter': [0.05, 0.1], 'nb_iter': [1, 3]}}, + {'method': 'MDIIM', + 'params': {'eps': [0.1, 0.2, 0.4], 'prob': [0.5, 0.1], + 'norm_level': [1, 2, '1', '2', 'l1', 'l2', 'inf', 'np.inf', 'linf']}} + ] # get training data data_list = "../common/dataset/MNIST/train" @@ -88,7 +114,10 @@ def test_lenet_mnist_fuzzing(): print('KMNC of initial seeds is: ', kmnc) initial_seeds = initial_seeds[:100] model_fuzz_test = Fuzzer(model) - _, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds, coverage, evaluate=True, max_iters=10, + _, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, + initial_seeds, coverage, + evaluate=True, + max_iters=10, mutate_num_per_seed=20) if metrics: diff --git a/examples/natural_robustness/natural_robustness_example.py b/examples/natural_robustness/natural_robustness_example.py new file mode 100644 index 0000000..1daf4f1 --- /dev/null +++ b/examples/natural_robustness/natural_robustness_example.py @@ -0,0 +1,176 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Example for natural robustness methods.""" + +import numpy as np +import cv2 + +from mindarmour.natural_robustness import Translate, Curve, Perspective, Scale, Shear, Rotate, SaltAndPepperNoise, \ + NaturalNoise, GaussianNoise, UniformNoise, MotionBlur, GaussianBlur, GradientBlur, Contrast, GradientLuminance + + +def test_perspective(image): + """Test perspective.""" + ori_pos = [[0, 0], [0, 800], [800, 0], [800, 800]] + dst_pos = [[50, 0], [0, 800], [780, 0], [800, 800]] + trans = Perspective(ori_pos, dst_pos) + dst = trans(image) + cv2.imshow('dst', dst) + cv2.waitKey() + + +def test_uniform_noise(image): + """Test uniform noise.""" + trans = UniformNoise(factor=0.1) + dst = trans(image) + cv2.imshow('dst', dst) + cv2.waitKey() + + +def test_gaussian_noise(image): + """Test gaussian noise.""" + trans = GaussianNoise(factor=0.1) + dst = trans(image) + cv2.imshow('dst', dst) + cv2.waitKey() + + +def test_contrast(image): + """Test contrast.""" + trans = Contrast(alpha=0.3, beta=0) + dst = trans(image) + cv2.imshow('dst', dst) + cv2.waitKey() + + +def test_gaussian_blur(image): + """Test gaussian blur.""" + trans = GaussianBlur(ksize=5) + dst = trans(image) + cv2.imshow('dst', dst) + cv2.waitKey() + + +def test_salt_and_pepper_noise(image): + """Test salt and pepper noise.""" + trans = SaltAndPepperNoise(factor=0.01) + dst = trans(image) + cv2.imshow('dst', dst) + cv2.waitKey() + + +def test_translate(image): + """Test translate.""" + trans = Translate(x_bias=0.1, y_bias=0.1) + dst = trans(image) + cv2.imshow('dst', dst) + cv2.waitKey() + + +def test_scale(image): + """Test scale.""" + trans = Scale(factor_x=0.7, factor_y=0.7) + dst = trans(image) + cv2.imshow('dst', dst) + cv2.waitKey() + + +def test_shear(image): + """Test shear.""" + trans = Shear(factor=0.2) + dst = trans(image) + cv2.imshow('dst', dst) + cv2.waitKey() + + +def test_rotate(image): + """Test rotate.""" + trans = Rotate(angle=20) + dst = trans(image) + cv2.imshow('dst', dst) + cv2.waitKey() + + +def test_curve(image): + """Test curve.""" + trans = Curve(curves=1.5, depth=1.5, mode='horizontal') + dst = trans(image) + cv2.imshow('dst', dst) + cv2.waitKey() + + +def test_natural_noise(image): + """Test natural noise.""" + trans = NaturalNoise(ratio=0.0001, k_x_range=(1, 30), k_y_range=(1, 10), auto_param=True) + dst = trans(image) + cv2.imshow('dst', dst) + cv2.waitKey() + + +def test_gradient_luminance(image): + """Test gradient luminance.""" + height, width = image.shape[:2] + point = (height // 4, width // 2) + start = (255, 255, 255) + end = (0, 0, 0) + scope = 0.3 + bright_rate = 0.4 + trans = GradientLuminance(start, end, start_point=point, scope=scope, pattern='dark', bright_rate=bright_rate, + mode='horizontal') + dst = trans(image) + cv2.imshow('dst', dst) + cv2.waitKey() + + +def test_motion_blur(image): + """Test motion blur.""" + angle = -10.5 + i = 3 + trans = MotionBlur(degree=i, angle=angle) + dst = trans(image) + cv2.imshow('dst', dst) + cv2.waitKey() + + +def test_gradient_blur(image): + """Test gradient blur.""" + number = 10 + h, w = image.shape[:2] + point = (int(h / 5), int(w / 5)) + center = False + trans = GradientBlur(point, number, center) + dst = trans(image) + cv2.imshow('dst', dst) + cv2.waitKey() + + +if __name__ == '__main__': + img = cv2.imread('1.jpeg') + img = np.array(img) + test_uniform_noise(img) + test_gaussian_noise(img) + test_motion_blur(img) + test_gradient_blur(img) + test_gradient_luminance(img) + test_natural_noise(img) + test_curve(img) + test_rotate(img) + test_shear(img) + test_scale(img) + test_translate(img) + test_salt_and_pepper_noise(img) + test_gaussian_blur(img) + test_constract(img) + test_perspective(img) diff --git a/examples/natural_robustness/serving/README.md b/examples/natural_robustness/serving/README.md new file mode 100644 index 0000000..60c2b54 --- /dev/null +++ b/examples/natural_robustness/serving/README.md @@ -0,0 +1,206 @@ +# 自然扰动样本生成serving + +提供自然扰动样本生成在线服务。客户端传入图片和扰动参数,服务端返回扰动后的图片数据。 + +## 环境准备 + +硬件环境:Ascend 910,GPU + +操作系统:Linux-x86_64 + +软件环境: + +1. python 3.7.5或python 3.9.0 + +2. 安装MindSpore 1.5.0可以参考[MindSpore安装页面](https://www.mindspore.cn/install) + +3. 安装MindSpore Serving 1.5.0可以参考[MindSpore Serving 安装页面](https://www.mindspore.cn/serving/docs/zh-CN/r1.5/serving_install.html) + +4. 安装serving分支的MindArmour: + + - 从Gitee下载源码 + + `git clone https://gitee.com/mindspore/mindarmour.git` + + - 编译并安装MindArmour + + `python setup.py install` + +### 文件结构说明 + +```bash +serving +├── server +│ ├── serving_server.py # 启动serving服务脚本 +│ ├── export_model +│ │ └── add_model.py # 生成模型文件脚本 +│ └── perturbation +│ └── serverable_config.py # 服务端接收客户端数据后的处理脚本 +└── client + ├── serving_client.py # 启动客户端脚本 + └── perturb_config.py # 扰动方法配置文件 +``` + +## 脚本说明及使用 + +### 导出模型 + +在`server/export_model`目录下,使用[add_model.py](https://gitee.com/mindspore/serving/blob/r1.5/example/tensor_add/export_model/add_model.py),构造了一个只有Add算子的tensor加法网络。使用命令 + +```bash +python add_model.py +``` + +在`perturbation`模型文件夹下生成`tensor_add.mindir`模型文件。 + +该服务实际上并没有使用到模型,但目前版本的serving需要有一个模型,serving升级后这部分会删除。 + +### 部署Serving推理服务 + +1. #### `servable_config.py`说明。 + + ```python + ··· + + # 客户端可以请求的方法,包含4个返回值:"results", "file_names", "file_length", "names_dict" + + @register.register_method(output_names=["results", "file_names", "file_length", "names_dict"]) + def natural_perturbation(img, perturb_config, methods_number, outputs_number): + """method natural_perturbation data flow definition, only preprocessing and call model""" + res = register.add_stage(perturb, img, perturb_config, methods_number, outputs_number, outputs_count=4) + return res + ``` + + 方法`natural_perturbation`为对外提供服务的接口。 + + **输入:** + + - img:输入为图片,格式为bytes。 + - perturb_config:扰动配置项,具体配置参考`perturb_config.py`。 + - methods_number:每次扰动随机从配置项中选择方法的个数。 + - outputs_number:对于每张图片,生成的扰动图片数量。 + + **输出**res中包含4个参数: + + - results:拼接后的图像bytes; + + - file_names:图像名,格式为`xxx.png`,其中‘xxx’为A-Za-z中随机选择20个字符构成的字符串。 + + - file_length:每张图片的bytes长度。 + + - names_dict: 图片名和图片使用扰动方法构成的字典。格式为: + + ```bash + { + picture1.png: [[method1, parameters of method1], [method2, parameters of method2], ...]], + picture2.png: [[method3, parameters of method3], [method4, parameters of method4], ...]], + ... + } + ``` + +2. #### 启动server。 + + ```python + ··· + + def start(): + servable_dir = os.path.dirname(os.path.realpath(sys.argv[0])) + # 服务配置 + servable_config = server.ServableStartConfig(servable_directory=servable_dir, servable_name="perturbation", device_ids=(0, 1), num_parallel_workers=4) + # 启动服务 + server.start_servables(servable_configs=servable_config) + + # 启动启动gRPC服务,用于客户端和服务端之间通信 + server.start_grpc_server(address="0.0.0.0:5500", max_msg_mb_size=200) # ip和最大的传输数据量,单位MB + # 启动启动Restful服务,用于客户端和服务端之间通信 + server.start_restful_server(address="0.0.0.0:5500") + ``` + + gRPC传输性能更好,Restful更适合用于web服务,根据需要选择。 + + 执行命令`python serverong_server.py`启动服务。 + + 当服务端打印日志`Serving RESTful server start success, listening on 0.0.0.0:5500`时,表示Serving RESTful服务启动成功,推理模型已成功加载。 + +### 客户端进行推理 + +1. 在`perturb_config.py`中设置扰动方法及参数。下面是个例子: + + ```python + PerturbConfig = [{"method": "Contrast", "params": {"alpha": 1.5, "beta": 0}}, + {"method": "GaussianBlur", "params": {"ksize": 5}}, + {"method": "SaltAndPepperNoise", "params": {"factor": 0.05}}, + {"method": "Translate", "params": {"x_bias": 0.1, "y_bias": -0.2}}, + {"method": "Scale", "params": {"factor_x": 0.7, "factor_y": 0.7}}, + {"method": "Shear", "params": {"factor": 2, "director": "horizontal"}}, + {"method": "Rotate", "params": {"angle": 40}}, + {"method": "MotionBlur", "params": {"degree": 5, "angle": 45}}, + {"method": "GradientBlur", "params": {"point": [50, 100], "kernel_num": 3, "center": True}}, + {"method": "GradientLuminance", + "params": {"color_start": [255, 255, 255], + "color_end": [0, 0, 0], + "start_point": [100, 150], "scope": 0.3, + "bright_rate": 0.3, "pattern": "light", + "mode": "circle"}}, + {"method": "Curve", "params": {"curves": 5, "depth": 10, + "mode": "vertical"}}, + {"method": "Perspective", + "params": {"ori_pos": [[0, 0], [0, 800], [800, 0], [800, 800]], + "dst_pos": [[50, 0], [0, 800], [780, 0], [800, 800]]}}, + ] + ``` + + 其中`method`为扰动方法名,`params`为对应方法的参数。可用的扰动方法及对应参数可在`mindarmour/natural_robustness/natural_noise.py`中查询。 + +2. 在`serving_client.py`中写客户端的处理脚本,包含输入输出的处理、服务端的调用,可以参考下面的例子。 + + ```python + ··· + + def perturb(perturb_config): + """invoke servable perturbation method natural_perturbation""" + + # 请求的服务端ip及端口、请求的服务名、请求的方法名 + client = Client("10.175.122.87:5500", "perturbation", "natural_perturbation") + + # 输入数据 + instances = [] + img_path = '/root/mindarmour/example/adversarial/test_data/1.png' + result_path = '/root/mindarmour/example/adv/result/' + methods_number = 2 + outputs_number = 3 + img = cv2.imread(img_path) + img = cv2.imencode('.png', img)[1].tobytes() # 图片传输用bytes格式,不支持numpy.ndarray格式 + perturb_config = json.dumps(perturb_config) # 配置方法转成json格式 + instances.append({"img": img, 'perturb_config': perturb_config, "methods_number": methods_number, + "outputs_number": outputs_number}) # instances中可添加多个输入 + + # 请求服务,返回结果 + result = client.infer(instances) + + # 对服务请求得到的结果进行处理,将返回的图片字节流存成图片 + file_names = result[0]['file_names'].split(';') + length = result[0]['file_length'].tolist() + before = 0 + for name, leng in zip(file_names, length): + res_img = result[0]['results'] + res_img = res_img[before:before + leng] + before = before + leng + print('name: ', name) + image = Image.open(BytesIO(res_img)) + image.save(os.path.join(result_path, name)) + + names_dict = result[0]['names_dict'] + with open('names_dict.json', 'w') as file: + file.write(names_dict) + ``` + + 启动client前,需将服务端的IP地址改成部署server的IP地址,图片路径、结果存储路基替换成用户数据路径。 + + 目前serving数据传输支持的数据类型包括:python的int、float、bool、str、bytes,numpy number, numpy array object。 + + 输入命令`python serving_client.py`开启客户端,如果对应目录下生成扰动样本图片则说明serving服务正确执行。 + + ### 其他 + + 在`serving_logs`目录下可以查看运行日志,辅助debug。 diff --git a/examples/natural_robustness/serving/client/perturb_config.py b/examples/natural_robustness/serving/client/perturb_config.py new file mode 100644 index 0000000..9c6df39 --- /dev/null +++ b/examples/natural_robustness/serving/client/perturb_config.py @@ -0,0 +1,41 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Configuration of natural robustness methods for server. +""" + +perturb_configs = [{"method": "Contrast", "params": {"alpha": 1.5, "beta": 0}}, + {"method": "GaussianBlur", "params": {"ksize": 5}}, + {"method": "SaltAndPepperNoise", "params": {"factor": 0.05}}, + {"method": "Translate", "params": {"x_bias": 0.1, "y_bias": -0.2}}, + {"method": "Scale", "params": {"factor_x": 0.7, "factor_y": 0.7}}, + {"method": "Shear", "params": {"factor": 2, "direction": "horizontal"}}, + {"method": "Rotate", "params": {"angle": 40}}, + {"method": "MotionBlur", "params": {"degree": 5, "angle": 45}}, + {"method": "GradientBlur", "params": {"point": [50, 100], "kernel_num": 3, "center": True}}, + {"method": "GradientLuminance", "params": {"color_start": [255, 255, 255], "color_end": [0, 0, 0], + "start_point": [100, 150], "scope": 0.3, + "bright_rate": 0.3, "pattern": "light", + "mode": "circle"}}, + {"method": "GradientLuminance", "params": {"color_start": [255, 255, 255], + "color_end": [0, 0, 0], "start_point": [150, 200], + "scope": 0.3, "pattern": "light", "mode": "horizontal"}}, + {"method": "GradientLuminance", "params": {"color_start": [255, 255, 255], "color_end": [0, 0, 0], + "start_point": [150, 200], "scope": 0.3, + "pattern": "light", "mode": "vertical"}}, + {"method": "Curve", "params": {"curves": 10, "depth": 10, "mode": "vertical"}}, + {"method": "Perspective", "params": {"ori_pos": [[0, 0], [0, 800], [800, 0], [800, 800]], + "dst_pos": [[50, 0], [0, 800], [780, 0], [800, 800]]}}, + ] diff --git a/examples/natural_robustness/serving/client/serving_client.py b/examples/natural_robustness/serving/client/serving_client.py new file mode 100644 index 0000000..52fb504 --- /dev/null +++ b/examples/natural_robustness/serving/client/serving_client.py @@ -0,0 +1,61 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""The client of example add.""" +import os +import json +from io import BytesIO + +import cv2 +from PIL import Image +from mindspore_serving.client import Client + +from perturb_config import perturb_configs + + +def perturb(perturb_config): + """Invoke servable perturbation method natural_perturbation""" + client = Client("0.0.0.0:5500", "perturbation", "natural_perturbation") + instances = [] + img_path = 'test_data/1.png' + result_path = 'result/' + if not os.path.exists(result_path): + os.mkdir(result_path) + methods_number = 2 + outputs_number = 10 + img = cv2.imread(img_path) + img = cv2.imencode('.png', img)[1].tobytes() + perturb_config = json.dumps(perturb_config) + instances.append({"img": img, 'perturb_config': perturb_config, "methods_number": methods_number, + "outputs_number": outputs_number}) + + result = client.infer(instances) + + file_names = result[0]['file_names'].split(';') + length = result[0]['file_length'].tolist() + before = 0 + for name, leng in zip(file_names, length): + res_img = result[0]['results'] + res_img = res_img[before:before + leng] + before = before + leng + print('name: ', name) + image = Image.open(BytesIO(res_img)) + image.save(os.path.join(result_path, name)) + names_dict = result[0]['names_dict'] + with open('names_dict.json', 'w') as file: + file.write(names_dict) + + +if __name__ == '__main__': + perturb(perturb_configs) diff --git a/examples/natural_robustness/serving/server/export_model/add_model.py b/examples/natural_robustness/serving/server/export_model/add_model.py new file mode 100644 index 0000000..7a14404 --- /dev/null +++ b/examples/natural_robustness/serving/server/export_model/add_model.py @@ -0,0 +1,58 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""add model generator""" +import os +from shutil import copyfile +import numpy as np + +import mindspore.context as context +import mindspore.nn as nn +import mindspore.ops as ops +import mindspore as ms + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + + +class Net(nn.Cell): + """Define Net of add""" + + def __init__(self): + super(Net, self).__init__() + self.add = ops.Add() + + def construct(self, x_, y_): + """construct add net""" + return self.add(x_, y_) + + +def export_net(): + """Export add net of 2x2 + 2x2, and copy output model `tensor_add.mindir` to directory ../add/1""" + x = np.ones([2, 2]).astype(np.float32) + y = np.ones([2, 2]).astype(np.float32) + add = Net() + ms.export(add, ms.Tensor(x), ms.Tensor(y), file_name='tensor_add', file_format='MINDIR') + dst_dir = '../perturbation/1' + try: + os.mkdir(dst_dir) + except OSError: + pass + + dst_file = os.path.join(dst_dir, 'tensor_add.mindir') + copyfile('tensor_add.mindir', dst_file) + print("copy tensor_add.mindir to " + dst_dir + " success") + + +if __name__ == "__main__": + export_net() diff --git a/examples/natural_robustness/serving/server/perturbation/servable_config.py b/examples/natural_robustness/serving/server/perturbation/servable_config.py new file mode 100644 index 0000000..f343c2a --- /dev/null +++ b/examples/natural_robustness/serving/server/perturbation/servable_config.py @@ -0,0 +1,109 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""perturbation servable config""" +import json +import copy +import random +from io import BytesIO +import cv2 +import numpy as np +from PIL import Image +from mindspore_serving.server import register +from mindarmour.natural_robustness import Contrast, GaussianBlur, SaltAndPepperNoise, Scale, Shear, \ + Translate, Rotate, MotionBlur, GradientBlur, GradientLuminance, NaturalNoise, Curve, Perspective + + +CHARACTERS = [chr(i) for i in range(65, 91)]+[chr(j) for j in range(97, 123)] + +methods_dict = {'Contrast': Contrast, + 'GaussianBlur': GaussianBlur, + 'SaltAndPepperNoise': SaltAndPepperNoise, + 'Translate': Translate, + 'Scale': Scale, + 'Shear': Shear, + 'Rotate': Rotate, + 'MotionBlur': MotionBlur, + 'GradientBlur': GradientBlur, + 'GradientLuminance': GradientLuminance, + 'NaturalNoise': NaturalNoise, + 'Curve': Curve, + 'Perspective': Perspective} + + +def check_inputs(img, perturb_config, methods_number, outputs_number): + """Check inputs.""" + if not np.any(img): + raise ValueError("img cannot be empty.") + img = Image.open(BytesIO(img)) + img = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR) + + config = json.loads(perturb_config) + if not config: + raise ValueError("perturb_config cannot be empty.") + for item in config: + if item['method'] not in methods_dict.keys(): + raise ValueError("{} is not a valid method.".format(item['method'])) + + methods_number = int(methods_number) + if methods_number < 1: + raise ValueError("methods_number must more than 0.") + outputs_number = int(outputs_number) + if outputs_number < 1: + raise ValueError("outputs_number must more than 0.") + + return img, config, methods_number, outputs_number + + +def perturb(img, perturb_config, methods_number, outputs_number): + """Perturb given image.""" + img, config, methods_number, outputs_number = check_inputs(img, perturb_config, methods_number, outputs_number) + res_img_bytes = b'' + file_names = [] + file_length = [] + names_dict = {} + for _ in range(outputs_number): + dst = copy.deepcopy(img) + used_methods = [] + for _ in range(methods_number): + item = np.random.choice(config) + method_name = item['method'] + method = methods_dict[method_name] + params = item['params'] + dst = method(**params)(img) + method_params = params + + used_methods.append([method_name, method_params]) + name = ''.join(random.sample(CHARACTERS, 20)) + name += '.png' + file_names.append(name) + names_dict[name] = used_methods + + res_img = cv2.imencode('.png', dst)[1].tobytes() + res_img_bytes += res_img + file_length.append(len(res_img)) + + names_dict = json.dumps(names_dict) + + return res_img_bytes, ';'.join(file_names), file_length, names_dict + + +model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) + + +@register.register_method(output_names=["results", "file_names", "file_length", "names_dict"]) +def natural_perturbation(img, perturb_config, methods_number, outputs_number): + """method natural_perturbation data flow definition, only preprocessing and call model""" + res = register.add_stage(perturb, img, perturb_config, methods_number, outputs_number, outputs_count=4) + return res diff --git a/examples/natural_robustness/serving/server/serving_server.py b/examples/natural_robustness/serving/server/serving_server.py new file mode 100644 index 0000000..8ccc705 --- /dev/null +++ b/examples/natural_robustness/serving/server/serving_server.py @@ -0,0 +1,35 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""The server of example perturbation""" + +import os +import sys +from mindspore_serving import server + + +def start(): + """Start server.""" + servable_dir = os.path.dirname(os.path.realpath(sys.argv[0])) + + servable_config = server.ServableStartConfig(servable_directory=servable_dir, servable_name="perturbation", + device_ids=(0, 1), num_parallel_workers=4) + server.start_servables(servable_configs=servable_config) + + server.start_grpc_server(address="0.0.0.0:5500", max_msg_mb_size=200) + # server.start_restful_server(address="0.0.0.0:5500") + + +if __name__ == "__main__": + start() diff --git a/mindarmour/adv_robustness/attacks/iterative_gradient_method.py b/mindarmour/adv_robustness/attacks/iterative_gradient_method.py index d95e5a8..a866039 100644 --- a/mindarmour/adv_robustness/attacks/iterative_gradient_method.py +++ b/mindarmour/adv_robustness/attacks/iterative_gradient_method.py @@ -92,7 +92,7 @@ def _projection(values, eps, norm_level): return proj_flat.reshape(values.shape) if norm_level in (2, '2'): return eps*normalize_value(values, norm_level) - if norm_level in (np.inf, 'inf'): + if norm_level in (np.inf, 'inf', 'linf', 'np.inf'): return eps*np.sign(values) msg = 'Values of `norm_level` different from 1, 2 and `np.inf` are ' \ 'currently not supported.' @@ -277,7 +277,7 @@ class MomentumIterativeMethod(IterativeGradientMethod): nb_iter (int): Number of iteration. Default: 5. decay_factor (float): Decay factor in iterations. Default: 1.0. norm_level (Union[int, numpy.inf]): Order of the norm. Possible values: - np.inf, 1 or 2. Default: 'inf'. + 1, 2, '1', '2', 'l1', 'l2', 'inf', 'np.inf', np.inf and 'linf'. Default: 'inf'. loss_fn (Loss): Loss function for optimization. If None, the input network \ is already equipped with loss function. Default: None. @@ -423,7 +423,7 @@ class ProjectedGradientDescent(BasicIterativeMethod): attack. Default: False. nb_iter (int): Number of iteration. Default: 5. norm_level (Union[int, numpy.inf]): Order of the norm. Possible values: - np.inf, 1 or 2. Default: 'inf'. + 1, 2, '1', '2', 'l1', 'l2', 'inf', 'np.inf', np.inf and 'linf'. Default: 'inf'. loss_fn (Loss): Loss function for optimization. If None, the input network \ is already equipped with loss function. Default: None. @@ -577,7 +577,7 @@ class MomentumDiverseInputIterativeMethod(MomentumIterativeMethod): is_targeted (bool): If True, targeted attack. If False, untargeted attack. Default: False. norm_level (Union[int, numpy.inf]): Order of the norm. Possible values: - np.inf, 1 or 2. Default: 'l1'. + 1, 2, '1', '2', 'l1', 'l2', 'inf', 'np.inf', np.inf and 'linf'. Default: 'l1'. prob (float): Transformation probability. Default: 0.5. loss_fn (Loss): Loss function for optimization. If None, the input network \ is already equipped with loss function. Default: None. diff --git a/mindarmour/fuzz_testing/fuzzing.py b/mindarmour/fuzz_testing/fuzzing.py index 93dafce..508e133 100644 --- a/mindarmour/fuzz_testing/fuzzing.py +++ b/mindarmour/fuzz_testing/fuzzing.py @@ -24,10 +24,10 @@ from mindspore import nn from mindarmour.utils._check_param import check_model, check_numpy_param, check_param_multi_types, check_norm_level, \ check_param_in_range, check_param_type, check_int_positive, check_param_bounds from mindarmour.utils.logger import LogUtil -from ..adv_robustness.attacks import FastGradientSignMethod, \ +from mindarmour.adv_robustness.attacks import FastGradientSignMethod, \ MomentumDiverseInputIterativeMethod, ProjectedGradientDescent -from .image_transform import Contrast, Brightness, Blur, \ - Noise, Translate, Scale, Shear, Rotate +from mindarmour.natural_robustness import GaussianBlur, MotionBlur, GradientBlur, UniformNoise, GaussianNoise, \ + SaltAndPepperNoise, NaturalNoise, Contrast, GradientLuminance, Translate, Scale, Shear, Rotate, Perspective, Curve from .model_coverage_metrics import CoverageMetrics, KMultisectionNeuronCoverage LOGGER = LogUtil.get_instance() @@ -104,17 +104,79 @@ class Fuzzer: target_model (Model): Target fuzz model. Examples: + >>> import numpy as np + >>> from mindspore import context + >>> from mindspore import nn + >>> from mindspore.common.initializer import TruncatedNormal + >>> from mindspore.ops import operations as P + >>> from mindspore.train import Model + >>> from mindspore.ops import TensorSummary + >>> from mindarmour.fuzz_testing import Fuzzer + >>> from mindarmour.fuzz_testing import KMultisectionNeuronCoverage + >>> + >>> class Net(nn.Cell): + >>> def __init__(self): + >>> super(Net, self).__init__() + >>> self.conv1 = nn.Conv2d(1, 6, 5, padding=0, weight_init=TruncatedNormal(0.02), pad_mode="valid") + >>> self.conv2 = nn.Conv2d(6, 16, 5, padding=0, weight_init=TruncatedNormal(0.02), pad_mode="valid") + >>> self.fc1 = nn.Dense(16 * 5 * 5, 120, TruncatedNormal(0.02), TruncatedNormal(0.02)) + >>> self.fc2 = nn.Dense(120, 84, TruncatedNormal(0.02), TruncatedNormal(0.02)) + >>> self.fc3 = nn.Dense(84, 10, TruncatedNormal(0.02), TruncatedNormal(0.02)) + >>> self.relu = nn.ReLU() + >>> self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) + >>> self.reshape = P.Reshape() + >>> self.summary = TensorSummary() + >>> + >>> def construct(self, x): + >>> x = self.conv1(x) + >>> x = self.relu(x) + >>> self.summary('conv1', x) + >>> x = self.max_pool2d(x) + >>> x = self.conv2(x) + >>> x = self.relu(x) + >>> self.summary('conv2', x) + >>> x = self.max_pool2d(x) + >>> x = self.reshape(x, (-1, 16 * 5 * 5)) + >>> x = self.fc1(x) + >>> x = self.relu(x) + >>> self.summary('fc1', x) + >>> x = self.fc2(x) + >>> x = self.relu(x) + >>> self.summary('fc2', x) + >>> x = self.fc3(x) + >>> self.summary('fc3', x) + >>> return x + >>> >>> net = Net() >>> model = Model(net) - >>> mutate_config = [{'method': 'Blur', - ... 'params': {'auto_param': [True]}}, + >>> mutate_config = [{'method': 'GaussianBlur', + ... 'params': {'ksize': [1, 2, 3, 5], 'auto_param': [True, False]}}, + ... {'method': 'MotionBlur', + ... 'params': {'degree': [1, 2, 5], 'angle': [45, 10, 100, 140, 210, 270, 300], + ... 'auto_param': [True]}}, + ... {'method': 'UniformNoise', + ... 'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}}, + ... {'method': 'GaussianNoise', + ... 'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}}, ... {'method': 'Contrast', - ... 'params': {'factor': [2]}}, - ... {'method': 'Translate', - ... 'params': {'x_bias': [0.1, 0.2], 'y_bias': [0.2]}}, + ... 'params': {'alpha': [0.5, 1, 1.5], 'beta': [-10, 0, 10], 'auto_param': [False, True]}}, + ... {'method': 'Rotate', + ... 'params': {'angle': [20, 90], 'auto_param': [False, True]}}, ... {'method': 'FGSM', - ... 'params': {'eps': [0.1, 0.2, 0.3], 'alpha': [0.1]}}] - >>> nc = KMultisectionNeuronCoverage(model, train_images, segmented_num=100) + ... 'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1], 'bounds': [(0, 1)]}}] + >>> batch_size = 8 + >>> num_classe = 10 + >>> train_images = np.random.rand(32, 1, 32, 32).astype(np.float32) + >>> test_images = np.random.rand(batch_size, 1, 32, 32).astype(np.float32) + >>> test_labels = np.random.randint(num_classe, size=batch_size).astype(np.int32) + >>> test_labels = (np.eye(num_classe)[test_labels]).astype(np.float32) + >>> initial_seeds = [] + >>> # make initial seeds + >>> for img, label in zip(test_images, test_labels): + >>> initial_seeds.append([img, label]) + + >>> initial_seeds = initial_seeds[:10] + >>> nc = KMultisectionNeuronCoverage(model, train_images, segmented_num=100, incremental=True) >>> model_fuzz_test = Fuzzer(model) >>> samples, gt_labels, preds, strategies, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds, ... nc, max_iters=100) @@ -125,18 +187,26 @@ class Fuzzer: # Allowed mutate strategies so far. self._strategies = {'Contrast': Contrast, - 'Brightness': Brightness, - 'Blur': Blur, - 'Noise': Noise, + 'GradientLuminance': GradientLuminance, + 'GaussianBlur': GaussianBlur, + 'MotionBlur': MotionBlur, + 'GradientBlur': GradientBlur, + 'UniformNoise': UniformNoise, + 'GaussianNoise': GaussianNoise, + 'SaltAndPepperNoise': SaltAndPepperNoise, + 'NaturalNoise': NaturalNoise, 'Translate': Translate, 'Scale': Scale, 'Shear': Shear, 'Rotate': Rotate, + 'Perspective': Perspective, + 'Curve': Curve, 'FGSM': FastGradientSignMethod, 'PGD': ProjectedGradientDescent, 'MDIIM': MomentumDiverseInputIterativeMethod} - self._affine_trans_list = ['Translate', 'Scale', 'Shear', 'Rotate'] - self._pixel_value_trans_list = ['Contrast', 'Brightness', 'Blur', 'Noise'] + self._affine_trans_list = ['Translate', 'Scale', 'Shear', 'Rotate', 'Perspective', 'Curve'] + self._pixel_value_trans_list = ['Contrast', 'GradientLuminance', 'GaussianBlur', 'MotionBlur', 'GradientBlur', + 'UniformNoise', 'GaussianNoise', 'SaltAndPepperNoise', 'NaturalNoise'] self._attacks_list = ['FGSM', 'PGD', 'MDIIM'] self._attack_param_checklists = { 'FGSM': {'eps': {'dtype': [float], 'range': [0, 1]}, @@ -144,10 +214,11 @@ class Fuzzer: 'bounds': {'dtype': [tuple, list]}}, 'PGD': {'eps': {'dtype': [float], 'range': [0, 1]}, 'eps_iter': {'dtype': [float], 'range': [0, 1]}, - 'nb_iter': {'dtype': [int], 'range': [0, 100000]}, + 'nb_iter': {'dtype': [int]}, 'bounds': {'dtype': [tuple, list]}}, 'MDIIM': {'eps': {'dtype': [float], 'range': [0, 1]}, - 'norm_level': {'dtype': [str, int], 'range': [1, 2, '1', '2', 'l1', 'l2', 'inf', 'np.inf']}, + 'norm_level': {'dtype': [str, int], + 'range': [1, 2, '1', '2', 'l1', 'l2', 'inf', 'linf', 'np.inf']}, 'prob': {'dtype': [float], 'range': [0, 1]}, 'bounds': {'dtype': [tuple, list]}}} @@ -157,18 +228,26 @@ class Fuzzer: Args: mutate_config (list): Mutate configs. The format is - [{'method': 'Blur', - 'params': {'radius': [0.1, 0.2], 'auto_param': [True, False]}}, - {'method': 'Contrast', - 'params': {'factor': [1, 1.5, 2]}}, - {'method': 'FGSM', - 'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1]}}, - ...]. + [{'method': 'GaussianBlur', + 'params': {'ksize': [1, 2, 3, 5], 'auto_param': [True, False]}}, + {'method': 'UniformNoise', + 'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}}, + {'method': 'GaussianNoise', + 'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}}, + {'method': 'Contrast', + 'params': {'alpha': [0.5, 1, 1.5], 'beta': [-10, 0, 10], 'auto_param': [False, True]}}, + {'method': 'Rotate', + 'params': {'angle': [20, 90], 'auto_param': [False, True]}}, + {'method': 'FGSM', + 'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1], 'bounds': [(0, 1)]}}] + ...]. The supported methods list is in `self._strategies`, and the params of each method must within the range of optional parameters. Supported methods are grouped in three types: Firstly, pixel value based transform methods include: 'Contrast', 'Brightness', 'Blur' and 'Noise'. Secondly, affine transform methods include: 'Translate', 'Scale', 'Shear' and 'Rotate'. Thirdly, attack methods include: 'FGSM', - 'PGD' and 'MDIIM'. `mutate_config` must have method in the type of pixel value based transform methods. + 'PGD' and 'MDIIM'. 'FGSM', 'PGD' and 'MDIIM'. are abbreviations of FastGradientSignMethod, + ProjectedGradientDescent and MomentumDiverseInputIterativeMethod. + `mutate_config` must have method in the type of pixel value based transform methods. The way of setting parameters for first and second type methods can be seen in 'mindarmour/fuzz_testing/image_transform.py'. For third type methods, the optional parameters refer to `self._attack_param_checklists`. @@ -278,7 +357,6 @@ class Fuzzer: if only_pixel_trans: while strategy['method'] not in self._pixel_value_trans_list: strategy = choice(mutate_config) - transform = mutates[strategy['method']] params = strategy['params'] method = strategy['method'] selected_param = {} @@ -290,9 +368,10 @@ class Fuzzer: shear_keys = selected_param.keys() if 'factor_x' in shear_keys and 'factor_y' in shear_keys: selected_param[choice(['factor_x', 'factor_y'])] = 0 - transform.set_params(**selected_param) - mutate_sample = transform.transform(seed[0]) + transform = mutates[strategy['method']](**selected_param) + mutate_sample = transform(seed[0]) else: + transform = mutates[strategy['method']] for param_name in selected_param: transform.__setattr__('_' + str(param_name), selected_param[param_name]) mutate_sample = transform.generate(np.array([seed[0].astype(np.float32)]), np.array([seed[1]]))[0] @@ -360,6 +439,8 @@ class Fuzzer: _ = check_param_bounds('bounds', param_value) elif param_name == 'norm_level': _ = check_norm_level(param_value) + elif param_name == 'nb_iter': + _ = check_int_positive(param_name, param_value) else: allow_type = self._attack_param_checklists[method][param_name]['dtype'] allow_range = self._attack_param_checklists[method][param_name]['range'] @@ -372,7 +453,8 @@ class Fuzzer: for mutate in mutate_config: method = mutate['method'] if method not in self._attacks_list: - mutates[method] = self._strategies[method]() + # mutates[method] = self._strategies[method]() + mutates[method] = self._strategies[method] else: network = self._target_model._network loss_fn = self._target_model._loss_fn @@ -414,7 +496,6 @@ class Fuzzer: else: attack_success_rate = None metrics_report['Attack_success_rate'] = attack_success_rate - metrics_report['Coverage_metrics'] = coverage.get_metrics(fuzz_samples) return metrics_report diff --git a/mindarmour/fuzz_testing/image_transform.py b/mindarmour/fuzz_testing/image_transform.py deleted file mode 100644 index 52a1136..0000000 --- a/mindarmour/fuzz_testing/image_transform.py +++ /dev/null @@ -1,609 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Image transform -""" -import numpy as np -from PIL import Image, ImageEnhance, ImageFilter - -from mindspore.dataset.vision.py_transforms_util import is_numpy, \ - to_pil, hwc_to_chw -from mindarmour.utils._check_param import check_param_multi_types, check_param_in_range, check_numpy_param -from mindarmour.utils.logger import LogUtil - -LOGGER = LogUtil.get_instance() -TAG = 'Image Transformation' - - -def chw_to_hwc(img): - """ - Transpose the input image; shape (C, H, W) to shape (H, W, C). - - Args: - img (numpy.ndarray): Image to be converted. - - Returns: - img (numpy.ndarray), Converted image. - """ - if is_numpy(img): - return img.transpose(1, 2, 0).copy() - raise TypeError('img should be numpy.ndarray. Got {}'.format(type(img))) - - -def is_hwc(img): - """ - Check if the input image is shape (H, W, C). - - Args: - img (numpy.ndarray): Image to be checked. - - Returns: - Bool, True if input is shape (H, W, C). - """ - if is_numpy(img): - img_shape = np.shape(img) - if img_shape[2] == 3 and img_shape[1] > 3 and img_shape[0] > 3: - return True - return False - raise TypeError('img should be numpy.ndarray. Got {}'.format(type(img))) - - -def is_chw(img): - """ - Check if the input image is shape (H, W, C). - - Args: - img (numpy.ndarray): Image to be checked. - - Returns: - Bool, True if input is shape (H, W, C). - """ - if is_numpy(img): - img_shape = np.shape(img) - if img_shape[0] == 3 and img_shape[1] > 3 and img_shape[2] > 3: - return True - return False - raise TypeError('img should be numpy.ndarray. Got {}'.format(type(img))) - - -def is_rgb(img): - """ - Check if the input image is RGB. - - Args: - img (numpy.ndarray): Image to be checked. - - Returns: - Bool, True if input is RGB. - """ - if is_numpy(img): - img_shape = np.shape(img) - if len(np.shape(img)) == 3 and (img_shape[0] == 3 or img_shape[2] == 3): - return True - return False - raise TypeError('img should be numpy.ndarray. Got {}'.format(type(img))) - - -def is_normalized(img): - """ - Check if the input image is normalized between 0 to 1. - - Args: - img (numpy.ndarray): Image to be checked. - - Returns: - Bool, True if input is normalized between 0 to 1. - """ - if is_numpy(img): - minimal = np.min(img) - maximun = np.max(img) - if minimal >= 0 and maximun <= 1: - return True - return False - raise TypeError('img should be Numpy array. Got {}'.format(type(img))) - - -class ImageTransform: - """ - The abstract base class for all image transform classes. - """ - - def __init__(self): - pass - - def _check(self, image): - """ Check image format. If input image is RGB and its shape - is (C, H, W), it will be transposed to (H, W, C). If the value - of the image is not normalized , it will be normalized between 0 to 1.""" - rgb = is_rgb(image) - chw = False - gray3dim = False - normalized = is_normalized(image) - if rgb: - chw = is_chw(image) - if chw: - image = chw_to_hwc(image) - else: - image = image - else: - if len(np.shape(image)) == 3: - gray3dim = True - image = image[0] - else: - image = image - if normalized: - image = image*255 - return rgb, chw, normalized, gray3dim, np.uint8(image) - - def _original_format(self, image, chw, normalized, gray3dim): - """ Return transformed image with original format. """ - if not is_numpy(image): - image = np.array(image) - if chw: - image = hwc_to_chw(image) - if normalized: - image = image / 255 - if gray3dim: - image = np.expand_dims(image, 0) - return image - - def transform(self, image): - pass - - -class Contrast(ImageTransform): - """ - Contrast of an image. - - Args: - factor (Union[float, int]): Control the contrast of an image. If 1.0, - gives the original image. If 0, gives a gray image. Default: 1. - """ - - def __init__(self, factor=1): - super(Contrast, self).__init__() - self.set_params(factor) - - def set_params(self, factor=1, auto_param=False): - """ - Set contrast parameters. - - Args: - factor (Union[float, int]): Control the contrast of an image. If 1.0 - gives the original image. If 0 gives a gray image. Default: 1. - auto_param (bool): True if auto generate parameters. Default: False. - """ - if auto_param: - self.factor = np.random.uniform(-5, 5) - else: - self.factor = check_param_multi_types('factor', factor, [int, float]) - - def transform(self, image): - """ - Transform the image. - - Args: - image (numpy.ndarray): Original image to be transformed. - - Returns: - numpy.ndarray, transformed image. - """ - image = check_numpy_param('image', image) - ori_dtype = image.dtype - _, chw, normalized, gray3dim, image = self._check(image) - image = to_pil(image) - img_contrast = ImageEnhance.Contrast(image) - trans_image = img_contrast.enhance(self.factor) - trans_image = self._original_format(trans_image, chw, normalized, - gray3dim) - - return trans_image.astype(ori_dtype) - - -class Brightness(ImageTransform): - """ - Brightness of an image. - - Args: - factor (Union[float, int]): Control the brightness of an image. If 1.0 - gives the original image. If 0 gives a black image. Default: 1. - """ - - def __init__(self, factor=1): - super(Brightness, self).__init__() - self.set_params(factor) - - def set_params(self, factor=1, auto_param=False): - """ - Set brightness parameters. - - Args: - factor (Union[float, int]): Control the brightness of an image. If 1 - gives the original image. If 0 gives a black image. Default: 1. - auto_param (bool): True if auto generate parameters. Default: False. - """ - if auto_param: - self.factor = np.random.uniform(0, 5) - else: - self.factor = check_param_multi_types('factor', factor, [int, float]) - - def transform(self, image): - """ - Transform the image. - - Args: - image (numpy.ndarray): Original image to be transformed. - - Returns: - numpy.ndarray, transformed image. - """ - image = check_numpy_param('image', image) - ori_dtype = image.dtype - _, chw, normalized, gray3dim, image = self._check(image) - image = to_pil(image) - img_contrast = ImageEnhance.Brightness(image) - trans_image = img_contrast.enhance(self.factor) - trans_image = self._original_format(trans_image, chw, normalized, - gray3dim) - return trans_image.astype(ori_dtype) - - -class Blur(ImageTransform): - """ - Blurs the image using Gaussian blur filter. - - Args: - radius(Union[float, int]): Blur radius, 0 means no blur. Default: 0. - """ - - def __init__(self, radius=0): - super(Blur, self).__init__() - self.set_params(radius) - - def set_params(self, radius=0, auto_param=False): - """ - Set blur parameters. - - Args: - radius (Union[float, int]): Blur radius, 0 means no blur. Default: 0. - auto_param (bool): True if auto generate parameters. Default: False. - """ - if auto_param: - self.radius = np.random.uniform(-1.5, 1.5) - else: - self.radius = check_param_multi_types('radius', radius, [int, float]) - - def transform(self, image): - """ - Transform the image. - - Args: - image (numpy.ndarray): Original image to be transformed. - - Returns: - numpy.ndarray, transformed image. - """ - image = check_numpy_param('image', image) - ori_dtype = image.dtype - _, chw, normalized, gray3dim, image = self._check(image) - image = to_pil(image) - trans_image = image.filter(ImageFilter.GaussianBlur(radius=self.radius)) - trans_image = self._original_format(trans_image, chw, normalized, - gray3dim) - return trans_image.astype(ori_dtype) - - -class Noise(ImageTransform): - """ - Add noise of an image. - - Args: - factor (float): factor is the ratio of pixels to add noise. - If 0 gives the original image. Default 0. - """ - - def __init__(self, factor=0): - super(Noise, self).__init__() - self.set_params(factor) - - def set_params(self, factor=0, auto_param=False): - """ - Set noise parameters. - - Args: - factor (Union[float, int]): factor is the ratio of pixels to - add noise. If 0 gives the original image. Default 0. - auto_param (bool): True if auto generate parameters. Default: False. - """ - if auto_param: - self.factor = np.random.uniform(0, 1) - else: - self.factor = check_param_multi_types('factor', factor, [int, float]) - - def transform(self, image): - """ - Transform the image. - - Args: - image (numpy.ndarray): Original image to be transformed. - - Returns: - numpy.ndarray, transformed image. - """ - image = check_numpy_param('image', image) - ori_dtype = image.dtype - _, chw, normalized, gray3dim, image = self._check(image) - noise = np.random.uniform(low=-1, high=1, size=np.shape(image)) - trans_image = np.copy(image) - threshold = 1 - self.factor - trans_image[noise < -threshold] = 0 - trans_image[noise > threshold] = 1 - trans_image = self._original_format(trans_image, chw, normalized, - gray3dim) - return trans_image.astype(ori_dtype) - - -class Translate(ImageTransform): - """ - Translate an image. - - Args: - x_bias (Union[int, float]): X-direction translation, x = x + x_bias*image_length. - Default: 0. - y_bias (Union[int, float]): Y-direction translation, y = y + y_bias*image_wide. - Default: 0. - """ - - def __init__(self, x_bias=0, y_bias=0): - super(Translate, self).__init__() - self.set_params(x_bias, y_bias) - - def set_params(self, x_bias=0, y_bias=0, auto_param=False): - """ - Set translate parameters. - - Args: - x_bias (Union[float, int]): X-direction translation, and x_bias should be in range of (-1, 1). Default: 0. - y_bias (Union[float, int]): Y-direction translation, and y_bias should be in range of (-1, 1). Default: 0. - auto_param (bool): True if auto generate parameters. Default: False. - """ - x_bias = check_param_in_range('x_bias', x_bias, -1, 1) - y_bias = check_param_in_range('y_bias', y_bias, -1, 1) - self.auto_param = auto_param - if auto_param: - self.x_bias = np.random.uniform(-0.3, 0.3) - self.y_bias = np.random.uniform(-0.3, 0.3) - else: - self.x_bias = check_param_multi_types('x_bias', x_bias, - [int, float]) - self.y_bias = check_param_multi_types('y_bias', y_bias, - [int, float]) - - def transform(self, image): - """ - Transform the image. - - Args: - image(numpy.ndarray): Original image to be transformed. - - Returns: - numpy.ndarray, transformed image. - """ - image = check_numpy_param('image', image) - ori_dtype = image.dtype - _, chw, normalized, gray3dim, image = self._check(image) - img = to_pil(image) - image_shape = np.shape(image) - self.x_bias = image_shape[1]*self.x_bias - self.y_bias = image_shape[0]*self.y_bias - trans_image = img.transform(img.size, Image.AFFINE, - (1, 0, self.x_bias, 0, 1, self.y_bias)) - trans_image = self._original_format(trans_image, chw, normalized, - gray3dim) - return trans_image.astype(ori_dtype) - - -class Scale(ImageTransform): - """ - Scale an image in the middle. - - Args: - factor_x (Union[float, int]): Rescale in X-direction, x=factor_x*x. - Default: 1. - factor_y (Union[float, int]): Rescale in Y-direction, y=factor_y*y. - Default: 1. - """ - - def __init__(self, factor_x=1, factor_y=1): - super(Scale, self).__init__() - self.set_params(factor_x, factor_y) - - def set_params(self, factor_x=1, factor_y=1, auto_param=False): - - """ - Set scale parameters. - - Args: - factor_x (Union[float, int]): Rescale in X-direction, x=factor_x*x. - Default: 1. - factor_y (Union[float, int]): Rescale in Y-direction, y=factor_y*y. - Default: 1. - auto_param (bool): True if auto generate parameters. Default: False. - """ - if auto_param: - self.factor_x = np.random.uniform(0.7, 3) - self.factor_y = np.random.uniform(0.7, 3) - else: - self.factor_x = check_param_multi_types('factor_x', factor_x, - [int, float]) - self.factor_y = check_param_multi_types('factor_y', factor_y, - [int, float]) - - def transform(self, image): - """ - Transform the image. - - Args: - image(numpy.ndarray): Original image to be transformed. - - Returns: - numpy.ndarray, transformed image. - """ - image = check_numpy_param('image', image) - ori_dtype = image.dtype - rgb, chw, normalized, gray3dim, image = self._check(image) - if rgb: - h, w, _ = np.shape(image) - else: - h, w = np.shape(image) - move_x_centor = w / 2*(1 - self.factor_x) - move_y_centor = h / 2*(1 - self.factor_y) - img = to_pil(image) - trans_image = img.transform(img.size, Image.AFFINE, - (self.factor_x, 0, move_x_centor, - 0, self.factor_y, move_y_centor)) - trans_image = self._original_format(trans_image, chw, normalized, - gray3dim) - return trans_image.astype(ori_dtype) - - -class Shear(ImageTransform): - """ - Shear an image, for each pixel (x, y) in the sheared image, the new value is - taken from a position (x+factor_x*y, factor_y*x+y) in the origin image. Then - the sheared image will be rescaled to fit original size. - - Args: - factor_x (Union[float, int]): Shear factor of horizontal direction. - Default: 0. - factor_y (Union[float, int]): Shear factor of vertical direction. - Default: 0. - - """ - - def __init__(self, factor_x=0, factor_y=0): - super(Shear, self).__init__() - self.set_params(factor_x, factor_y) - - def set_params(self, factor_x=0, factor_y=0, auto_param=False): - """ - Set shear parameters. - - Args: - factor_x (Union[float, int]): Shear factor of horizontal direction. - Default: 0. - factor_y (Union[float, int]): Shear factor of vertical direction. - Default: 0. - auto_param (bool): True if auto generate parameters. Default: False. - """ - if factor_x != 0 and factor_y != 0: - msg = 'At least one of factor_x and factor_y is zero.' - LOGGER.error(TAG, msg) - raise ValueError(msg) - if auto_param: - if np.random.uniform(-1, 1) > 0: - self.factor_x = np.random.uniform(-2, 2) - self.factor_y = 0 - else: - self.factor_x = 0 - self.factor_y = np.random.uniform(-2, 2) - else: - self.factor_x = check_param_multi_types('factor', factor_x, - [int, float]) - self.factor_y = check_param_multi_types('factor', factor_y, - [int, float]) - - def transform(self, image): - """ - Transform the image. - - Args: - image(numpy.ndarray): Original image to be transformed. - - Returns: - numpy.ndarray, transformed image. - """ - image = check_numpy_param('image', image) - ori_dtype = image.dtype - rgb, chw, normalized, gray3dim, image = self._check(image) - img = to_pil(image) - if rgb: - h, w, _ = np.shape(image) - else: - h, w = np.shape(image) - if self.factor_x != 0: - boarder_x = [0, -w, -self.factor_x*h, -w - self.factor_x*h] - min_x = min(boarder_x) - max_x = max(boarder_x) - scale = (max_x - min_x) / w - move_x_cen = (w - scale*w - scale*h*self.factor_x) / 2 - move_y_cen = h*(1 - scale) / 2 - else: - boarder_y = [0, -h, -self.factor_y*w, -h - self.factor_y*w] - min_y = min(boarder_y) - max_y = max(boarder_y) - scale = (max_y - min_y) / h - move_y_cen = (h - scale*h - scale*w*self.factor_y) / 2 - move_x_cen = w*(1 - scale) / 2 - trans_image = img.transform(img.size, Image.AFFINE, - (scale, scale*self.factor_x, move_x_cen, - scale*self.factor_y, scale, move_y_cen)) - trans_image = self._original_format(trans_image, chw, normalized, - gray3dim) - return trans_image.astype(ori_dtype) - - -class Rotate(ImageTransform): - """ - Rotate an image of degrees counter clockwise around its center. - - Args: - angle(Union[float, int]): Degrees counter clockwise. Default: 0. - """ - - def __init__(self, angle=0): - super(Rotate, self).__init__() - self.set_params(angle) - - def set_params(self, angle=0, auto_param=False): - """ - Set rotate parameters. - - Args: - angle(Union[float, int]): Degrees counter clockwise. Default: 0. - auto_param (bool): True if auto generate parameters. Default: False. - """ - if auto_param: - self.angle = np.random.uniform(0, 360) - else: - self.angle = check_param_multi_types('angle', angle, [int, float]) - - def transform(self, image): - """ - Transform the image. - - Args: - image(numpy.ndarray): Original image to be transformed. - - Returns: - numpy.ndarray, transformed image. - """ - image = check_numpy_param('image', image) - ori_dtype = image.dtype - _, chw, normalized, gray3dim, image = self._check(image) - img = to_pil(image) - trans_image = img.rotate(self.angle, expand=False) - trans_image = self._original_format(trans_image, chw, normalized, - gray3dim) - return trans_image.astype(ori_dtype) diff --git a/mindarmour/fuzz_testing/model_coverage_metrics.py b/mindarmour/fuzz_testing/model_coverage_metrics.py index f71ee7a..4beab03 100644 --- a/mindarmour/fuzz_testing/model_coverage_metrics.py +++ b/mindarmour/fuzz_testing/model_coverage_metrics.py @@ -154,13 +154,48 @@ class NeuronCoverage(CoverageMetrics): incremental (bool): Metrics will be calculate in incremental way or not. Default: False. batch_size (int): The number of samples in a fuzz test batch. Default: 32. + Examples: + >>> import numpy as np + >>> from mindspore import nn + >>> from mindspore.nn import Cell + >>> from mindspore.train import Model + >>> from mindspore import context + >>> from mindspore.ops import TensorSummary + >>> from mindarmour.fuzz_testing import NeuronCoverage + >>> + >>> class Net(Cell): + >>> def __init__(self): + >>> super(Net, self).__init__() + >>> self._relu = nn.ReLU() + >>> self.summary = TensorSummary() + >>> + >>> def construct(self, inputs): + >>> self.summary('input', inputs) + >>> out = self._relu(inputs) + >>> self.summary('1', out) + >>> return out + >>> + >>> context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + >>> # load network + >>> net = Net() + >>> model = Model(net) + >>> + >>> # initialize fuzz test with training dataset + >>> training_data = (np.random.random((10000, 10))*20).astype(np.float32) + >>> + >>> # fuzz test with original test data + >>> # get test data + >>> test_data = (np.random.random((2000, 10))*20).astype(np.float32) + >>> test_labels = np.random.randint(0, 10, 2000).astype(np.int32) + >>> + >>> nc = NeuronCoverage(model, threshold=0.1) + >>> nc_metric = nc.get_metrics(test_data) """ def __init__(self, model, threshold=0.1, incremental=False, batch_size=32): super(NeuronCoverage, self).__init__(model, incremental, batch_size) threshold = check_param_type('threshold', threshold, float) self.threshold = check_value_positive('threshold', threshold) - def get_metrics(self, dataset): """ Get the metric of neuron coverage: the proportion of activated neurons to total neurons in the network. @@ -170,10 +205,6 @@ class NeuronCoverage(CoverageMetrics): Returns: float, the metric of 'neuron coverage'. - - Examples: - >>> nc = NeuronCoverage(model, threshold=0.1) - >>> nc_metrics = nc.get_metrics(test_data) """ dataset = check_numpy_param('dataset', dataset) batches = math.ceil(dataset.shape[0] / self.batch_size) @@ -203,6 +234,43 @@ class TopKNeuronCoverage(CoverageMetrics): top_k (int): Neuron is activated when its output has the top k largest value in that hidden layers. Default: 3. incremental (bool): Metrics will be calculate in incremental way or not. Default: False. batch_size (int): The number of samples in a fuzz test batch. Default: 32. + + Examples: + >>> import numpy as np + >>> from mindspore import nn + >>> from mindspore.nn import Cell + >>> from mindspore.train import Model + >>> from mindspore import context + >>> from mindspore.ops import TensorSummary + >>> from mindarmour.fuzz_testing import TopKNeuronCoverage + >>> + >>> class Net(Cell): + >>> def __init__(self): + >>> super(Net, self).__init__() + >>> self._relu = nn.ReLU() + >>> self.summary = TensorSummary() + >>> + >>> def construct(self, inputs): + >>> self.summary('input', inputs) + >>> out = self._relu(inputs) + >>> self.summary('1', out) + >>> return out + >>> + >>> context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + >>> # load network + >>> net = Net() + >>> model = Model(net) + >>> + >>> # initialize fuzz test with training dataset + >>> training_data = (np.random.random((10000, 10))*20).astype(np.float32) + >>> + >>> # fuzz test with original test data + >>> # get test data + >>> test_data = (np.random.random((2000, 10))*20).astype(np.float32) + >>> test_labels = np.random.randint(0, 10, 2000).astype(np.int32) + >>> + >>> tknc = TopKNeuronCoverage(model, top_k=3) + >>> tknc_metrics = tknc.get_metrics(test_data) """ def __init__(self, model, top_k=3, incremental=False, batch_size=32): super(TopKNeuronCoverage, self).__init__(model, incremental=incremental, batch_size=batch_size) @@ -217,10 +285,6 @@ class TopKNeuronCoverage(CoverageMetrics): Returns: float, the metrics of 'top k neuron coverage'. - - Examples: - >>> tknc = TopKNeuronCoverage(model, top_k=3) - >>> metrics = tknc.get_metrics(test_data) """ dataset = check_numpy_param('dataset', dataset) batches = math.ceil(dataset.shape[0] / self.batch_size) @@ -252,6 +316,43 @@ class SuperNeuronActivateCoverage(CoverageMetrics): train_dataset (numpy.ndarray): Training dataset used for determine the neurons' output boundaries. incremental (bool): Metrics will be calculate in incremental way or not. Default: False. batch_size (int): The number of samples in a fuzz test batch. Default: 32. + + Examples: + >>> import numpy as np + >>> from mindspore import nn + >>> from mindspore.nn import Cell + >>> from mindspore.train import Model + >>> from mindspore import context + >>> from mindspore.ops import TensorSummary + >>> from mindarmour.fuzz_testing import SuperNeuronActivateCoverage + >>> + >>> class Net(Cell): + >>> def __init__(self): + >>> super(Net, self).__init__() + >>> self._relu = nn.ReLU() + >>> self.summary = TensorSummary() + >>> + >>> def construct(self, inputs): + >>> self.summary('input', inputs) + >>> out = self._relu(inputs) + >>> self.summary('1', out) + >>> return out + >>> + >>> context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + >>> # load network + >>> net = Net() + >>> model = Model(net) + >>> + >>> # initialize fuzz test with training dataset + >>> training_data = (np.random.random((10000, 10))*20).astype(np.float32) + >>> + >>> # fuzz test with original test data + >>> # get test data + >>> test_data = (np.random.random((2000, 10))*20).astype(np.float32) + >>> test_labels = np.random.randint(0, 10, 2000).astype(np.int32) + >>> + >>> snac = SuperNeuronActivateCoverage(model, training_data) + >>> snac_metrics = snac.get_metrics(test_data) """ def __init__(self, model, train_dataset, incremental=False, batch_size=32): super(SuperNeuronActivateCoverage, self).__init__(model, incremental=incremental, batch_size=batch_size) @@ -267,10 +368,6 @@ class SuperNeuronActivateCoverage(CoverageMetrics): Returns: float, the metric of 'strong neuron activation coverage'. - - Examples: - >>> snac = SuperNeuronActivateCoverage(model, train_dataset) - >>> metrics = snac.get_metrics(test_data) """ dataset = check_numpy_param('dataset', dataset) if not self.incremental or not self._activate_table: @@ -303,6 +400,43 @@ class NeuronBoundsCoverage(SuperNeuronActivateCoverage): train_dataset (numpy.ndarray): Training dataset used for determine the neurons' output boundaries. incremental (bool): Metrics will be calculate in incremental way or not. Default: False. batch_size (int): The number of samples in a fuzz test batch. Default: 32. + + Examples: + >>> import numpy as np + >>> from mindspore import nn + >>> from mindspore.nn import Cell + >>> from mindspore.train import Model + >>> from mindspore import context + >>> from mindspore.ops import TensorSummary + >>> from mindarmour.fuzz_testing import NeuronBoundsCoverage + >>> + >>> class Net(Cell): + >>> def __init__(self): + >>> super(Net, self).__init__() + >>> self._relu = nn.ReLU() + >>> self.summary = TensorSummary() + >>> + >>> def construct(self, inputs): + >>> self.summary('input', inputs) + >>> out = self._relu(inputs) + >>> self.summary('1', out) + >>> return out + >>> + >>> context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + >>> # load network + >>> net = Net() + >>> model = Model(net) + >>> + >>> # initialize fuzz test with training dataset + >>> training_data = (np.random.random((10000, 10))*20).astype(np.float32) + >>> + >>> # fuzz test with original test data + >>> # get test data + >>> test_data = (np.random.random((2000, 10))*20).astype(np.float32) + >>> test_labels = np.random.randint(0, 10, 2000).astype(np.int32) + >>> + >>> nbc = NeuronBoundsCoverage(model, training_data) + >>> nbc_metrics = nbc.get_metrics(test_data) """ def __init__(self, model, train_dataset, incremental=False, batch_size=32): @@ -317,10 +451,6 @@ class NeuronBoundsCoverage(SuperNeuronActivateCoverage): Returns: float, the metric of 'neuron boundary coverage'. - - Examples: - >>> nbc = NeuronBoundsCoverage(model, train_dataset) - >>> metrics = nbc.get_metrics(test_data) """ dataset = check_numpy_param('dataset', dataset) if not self.incremental or not self._activate_table: @@ -353,6 +483,43 @@ class KMultisectionNeuronCoverage(SuperNeuronActivateCoverage): segmented_num (int): The number of segmented sections of neurons' output intervals. Default: 100. incremental (bool): Metrics will be calculate in incremental way or not. Default: False. batch_size (int): The number of samples in a fuzz test batch. Default: 32. + + Examples: + >>> import numpy as np + >>> from mindspore import nn + >>> from mindspore.nn import Cell + >>> from mindspore.train import Model + >>> from mindspore import context + >>> from mindspore.ops import TensorSummary + >>> from mindarmour.fuzz_testing import KMultisectionNeuronCoverage + >>> + >>> class Net(Cell): + >>> def __init__(self): + >>> super(Net, self).__init__() + >>> self._relu = nn.ReLU() + >>> self.summary = TensorSummary() + >>> + >>> def construct(self, inputs): + >>> self.summary('input', inputs) + >>> out = self._relu(inputs) + >>> self.summary('1', out) + >>> return out + >>> + >>> context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + >>> # load network + >>> net = Net() + >>> model = Model(net) + >>> + >>> # initialize fuzz test with training dataset + >>> training_data = (np.random.random((10000, 10))*20).astype(np.float32) + >>> + >>> # fuzz test with original test data + >>> # get test data + >>> test_data = (np.random.random((2000, 10))*20).astype(np.float32) + >>> test_labels = np.random.randint(0, 10, 2000).astype(np.int32) + >>> + >>> kmnc = KMultisectionNeuronCoverage(model, training_data, segmented_num=100) + >>> kmnc_metrics = kmnc.get_metrics(test_data) """ def __init__(self, model, train_dataset, segmented_num=100, incremental=False, batch_size=32): @@ -381,10 +548,6 @@ class KMultisectionNeuronCoverage(SuperNeuronActivateCoverage): Returns: float, the metric of 'k-multisection neuron coverage'. - - Examples: - >>> kmnc = KMultisectionNeuronCoverage(model, train_dataset, segmented_num=100) - >>> metrics = kmnc.get_metrics(test_data) """ dataset = check_numpy_param('dataset', dataset) diff --git a/mindarmour/natural_robustness/__init__.py b/mindarmour/natural_robustness/__init__.py new file mode 100644 index 0000000..90cd7d0 --- /dev/null +++ b/mindarmour/natural_robustness/__init__.py @@ -0,0 +1,37 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This package include methods to generate natural perturbation samples. +""" + +from .transformation import Translate, Scale, Shear, Rotate, Perspective, Curve +from .blur import GaussianBlur, MotionBlur, GradientBlur +from .luminance import Contrast, GradientLuminance +from .corruption import UniformNoise, GaussianNoise, SaltAndPepperNoise, NaturalNoise + +__all__ = ['Translate', + 'Scale', + 'Shear', + 'Rotate', + 'Perspective', + 'Curve', + 'GaussianBlur', + 'MotionBlur', + 'GradientBlur', + 'Contrast', + 'GradientLuminance', + 'UniformNoise', + 'GaussianNoise', + 'SaltAndPepperNoise', + 'NaturalNoise'] diff --git a/mindarmour/natural_robustness/blur.py b/mindarmour/natural_robustness/blur.py new file mode 100644 index 0000000..e06efa0 --- /dev/null +++ b/mindarmour/natural_robustness/blur.py @@ -0,0 +1,193 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Image Blur +""" + +import numpy as np +import cv2 + +from mindarmour.natural_robustness.natural_perturb import _NaturalPerturb +from mindarmour.utils._check_param import check_param_multi_types, check_int_positive, check_param_type +from mindarmour.utils.logger import LogUtil + +LOGGER = LogUtil.get_instance() +TAG = 'Image Blur' + + +class GaussianBlur(_NaturalPerturb): + """ + Blurs the image using Gaussian blur filter. + + Args: + ksize (int): Size of gaussian kernel, this value must be non-negnative. + auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image. + + Example: + >>> img = cv2.imread('1.png') + >>> img = np.array(img) + >>> ksize = 5 + >>> trans = GaussianBlur(ksize) + >>> dst = trans(img) + """ + + def __init__(self, ksize=2, auto_param=False): + super(GaussianBlur, self).__init__() + ksize = check_int_positive('ksize', ksize) + if auto_param: + ksize = 2 * np.random.randint(0, 5) + 1 + else: + ksize = 2 * ksize + 1 + self.ksize = (ksize, ksize) + + def __call__(self, image): + """ + Transform the image. + + Args: + image (numpy.ndarray): Original image to be transformed. + + Returns: + numpy.ndarray, transformed image. + """ + ori_dtype = image.dtype + _, chw, normalized, gray3dim, image = self._check(image) + new_img = cv2.GaussianBlur(image, self.ksize, 0) + new_img = self._original_format(new_img, chw, normalized, gray3dim) + return new_img.astype(ori_dtype) + + +class MotionBlur(_NaturalPerturb): + """ + Motion blur for a given image. + + Args: + degree (int): Degree of blur. This value must be positive. Suggested value range in [1, 15]. + angle: (union[float, int]): Direction of motion blur. Angle=0 means up and down motion blur. Angle is + counterclockwise. + auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image. + + Example: + >>> img = cv2.imread('1.png') + >>> img = np.array(img) + >>> angle = 0 + >>> degree = 5 + >>> trans = MotionBlur(degree=degree, angle=angle) + >>> new_img = trans(img) + """ + + def __init__(self, degree=5, angle=45, auto_param=False): + super(MotionBlur, self).__init__() + self.degree = check_int_positive('degree', degree) + self.degree = check_param_multi_types('degree', degree, [float, int]) + auto_param = check_param_type('auto_param', auto_param, bool) + if auto_param: + self.degree = np.random.randint(1, 5) + self.angle = np.random.uniform(0, 360) + else: + self.angle = angle - 45 + + def __call__(self, image): + """ + Motion blur for a given image. + + Args: + image (numpy.ndarray): Original image. + + Returns: + numpy.ndarray, image after motion blur. + """ + ori_dtype = image.dtype + _, chw, normalized, gray3dim, image = self._check(image) + matrix = cv2.getRotationMatrix2D((self.degree / 2, self.degree / 2), self.angle, 1) + motion_blur_kernel = np.diag(np.ones(self.degree)) + motion_blur_kernel = cv2.warpAffine(motion_blur_kernel, matrix, (self.degree, self.degree)) + motion_blur_kernel = motion_blur_kernel / self.degree + blurred = cv2.filter2D(image, -1, motion_blur_kernel) + # convert to uint8 + cv2.normalize(blurred, blurred, 0, 255, cv2.NORM_MINMAX) + blurred = self._original_format(blurred, chw, normalized, gray3dim) + + return blurred.astype(ori_dtype) + + +class GradientBlur(_NaturalPerturb): + """ + Gradient blur. + + Args: + point (union[tuple, list]): 2D coordinate of the Blur center point. + kernel_num (int): Number of blur kernels. Suggested value range in [1, 8]. + center (bool): Blurred or clear at the center of a specified point. + auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image. + + Example: + >>> img = cv2.imread('xx.png') + >>> img = np.array(img) + >>> number = 5 + >>> h, w = img.shape[:2] + >>> point = (int(h / 5), int(w / 5)) + >>> center = True + >>> trans = GradientBlur(point, number, center) + >>> new_img = trans(img) + """ + + def __init__(self, point, kernel_num=3, center=True, auto_param=False): + super(GradientBlur).__init__() + point = check_param_multi_types('point', point, [list, tuple]) + self.auto_param = check_param_type('auto_param', auto_param, bool) + self.point = tuple(point) + self.kernel_num = check_int_positive('kernel_num', kernel_num) + self.center = check_param_type('center', center, bool) + + def _auto_param(self, h, w): + self.point = (int(np.random.uniform(0, h)), int(np.random.uniform(0, w))) + self.kernel_num = np.random.randint(1, 6) + self.center = np.random.choice([True, False]) + + def __call__(self, image): + """ + Args: + image (numpy.ndarray): Original image. + + Returns: + numpy.ndarray, gradient blurred image. + """ + ori_dtype = image.dtype + _, chw, normalized, gray3dim, image = self._check(image) + w, h = image.shape[:2] + if self.auto_param: + self._auto_param(h, w) + mask = np.zeros(image.shape, dtype=np.uint8) + masks = [] + radius = max(w - self.point[0], self.point[0], h - self.point[1], self.point[1]) + radius = int(radius / self.kernel_num) + for i in range(self.kernel_num): + circle = cv2.circle(mask.copy(), self.point, radius * (1 + i), (1, 1, 1), -1) + masks.append(circle) + blurs = [] + for i in range(3, 3 + 2 * self.kernel_num, 2): + ksize = (i, i) + blur = cv2.GaussianBlur(image, ksize, 0) + blurs.append(blur) + + dst = image.copy() + if self.center: + for i in range(self.kernel_num): + dst = masks[i] * dst + (1 - masks[i]) * blurs[i] + else: + for i in range(self.kernel_num - 1, -1, -1): + dst = masks[i] * blurs[self.kernel_num - 1 - i] + (1 - masks[i]) * dst + dst = self._original_format(dst, chw, normalized, gray3dim) + return dst.astype(ori_dtype) diff --git a/mindarmour/natural_robustness/corruption.py b/mindarmour/natural_robustness/corruption.py new file mode 100644 index 0000000..6a563fc --- /dev/null +++ b/mindarmour/natural_robustness/corruption.py @@ -0,0 +1,251 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Image corruption. +""" +import math +import numpy as np +import cv2 + +from mindarmour.natural_robustness.natural_perturb import _NaturalPerturb +from mindarmour.utils._check_param import check_param_multi_types, check_param_type +from mindarmour.utils.logger import LogUtil + +LOGGER = LogUtil.get_instance() +TAG = 'Image corruption' + + +class UniformNoise(_NaturalPerturb): + """ + Add uniform noise of an image. + + Args: + factor (float): Noise density, the proportion of noise points per unit pixel area. Suggested value range in + [0.001, 0.15]. + auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image. + + Example: + >>> img = cv2.imread('1.png') + >>> img = np.array(img) + >>> factor = 0.1 + >>> trans = UniformNoise(factor) + >>> dst = trans(img) + """ + + def __init__(self, factor=0.1, auto_param=False): + super(UniformNoise, self).__init__() + self.factor = check_param_multi_types('factor', factor, [int, float]) + check_param_type('auto_param', auto_param, bool) + if auto_param: + self.factor = np.random.uniform(0, 0.15) + + def __call__(self, image): + """ + Transform the image. + + Args: + image (numpy.ndarray): Original image to be transformed. + + Returns: + numpy.ndarray, transformed image. + """ + ori_dtype = image.dtype + _, chw, normalized, gray3dim, image = self._check(image) + low, high = (0, 255) + weight = self.factor * (high - low) + noise = np.random.uniform(-weight, weight, size=image.shape) + trans_image = np.clip(image + noise, low, high) + trans_image = self._original_format(trans_image, chw, normalized, gray3dim) + + return trans_image.astype(ori_dtype) + + +class GaussianNoise(_NaturalPerturb): + """ + Add gaussian noise of an image. + + Args: + factor (float): Noise density, the proportion of noise points per unit pixel area. Suggested value range in + [0.001, 0.15]. + auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image. + + Example: + >>> img = cv2.imread('1.png') + >>> img = np.array(img) + >>> factor = 0.1 + >>> trans = GaussianNoise(factor) + >>> dst = trans(img) + """ + + def __init__(self, factor=0.1, auto_param=False): + super(GaussianNoise, self).__init__() + self.factor = check_param_multi_types('factor', factor, [int, float]) + check_param_type('auto_param', auto_param, bool) + if auto_param: + self.factor = np.random.uniform(0, 0.15) + + def __call__(self, image): + """ + Transform the image. + + Args: + image (numpy.ndarray): Original image to be transformed. + + Returns: + numpy.ndarray, transformed image. + """ + ori_dtype = image.dtype + low, high = (0, 255) + _, chw, normalized, gray3dim, image = self._check(image) + std = self.factor / math.sqrt(3) * (high - low) + noise = np.random.normal(scale=std, size=image.shape) + trans_image = np.clip(image + noise, low, high) + trans_image = self._original_format(trans_image, chw, normalized, gray3dim) + return trans_image.astype(ori_dtype) + + +class SaltAndPepperNoise(_NaturalPerturb): + """ + Add salt and pepper noise of an image. + + Args: + factor (float): Noise density, the proportion of noise points per unit pixel area. Suggested value range in + [0.001, 0.15]. + auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image. + + Example: + >>> img = cv2.imread('1.png') + >>> img = np.array(img) + >>> factor = 0.1 + >>> trans = SaltAndPepperNoise(factor) + >>> dst = trans(img) + """ + + def __init__(self, factor=0, auto_param=False): + super(SaltAndPepperNoise, self).__init__() + self.factor = check_param_multi_types('factor', factor, [int, float]) + check_param_type('auto_param', auto_param, bool) + if auto_param: + self.factor = np.random.uniform(0, 0.15) + + def __call__(self, image): + """ + Transform the image. + + Args: + image (numpy.ndarray): Original image to be transformed. + + Returns: + numpy.ndarray, transformed image. + """ + ori_dtype = image.dtype + _, chw, normalized, gray3dim, image = self._check(image) + low, high = (0, 255) + noise = np.random.uniform(low=-1, high=1, size=(image.shape[0], image.shape[1])) + trans_image = np.copy(image) + threshold = 1 - self.factor + trans_image[noise < -threshold] = low + trans_image[noise > threshold] = high + trans_image = self._original_format(trans_image, chw, normalized, gray3dim) + return trans_image.astype(ori_dtype) + + +class NaturalNoise(_NaturalPerturb): + """ + Add natural noise to an image. + + Args: + ratio (float): Noise density, the proportion of noise blocks per unit pixel area. Suggested value range in + [0.00001, 0.001]. + k_x_range (union[list, tuple]): Value range of the noise block length. + k_y_range (union[list, tuple]): Value range of the noise block width. + auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image. + + Examples: + >>> img = cv2.imread('xx.png') + >>> img = np.array(img) + >>> ratio = 0.0002 + >>> k_x_range = (1, 5) + >>> k_y_range = (3, 25) + >>> trans = NaturalNoise(ratio, k_x_range, k_y_range) + >>> new_img = trans(img) + """ + + def __init__(self, ratio=0.0002, k_x_range=(1, 5), k_y_range=(3, 25), auto_param=False): + super(NaturalNoise).__init__() + self.ratio = check_param_type('ratio', ratio, float) + k_x_range = check_param_multi_types('k_x_range', k_x_range, [list, tuple]) + k_y_range = check_param_multi_types('k_y_range', k_y_range, [list, tuple]) + self.k_x_range = tuple(k_x_range) + self.k_y_range = tuple(k_y_range) + self.auto_param = check_param_type('auto_param', auto_param, bool) + + def __call__(self, image): + """ + Add natural noise to given image. + + Args: + image (numpy.ndarray): Original image. + + Returns: + numpy.ndarray, image with natural noise. + """ + ori_dtype = image.dtype + _, chw, normalized, gray3dim, image = self._check(image) + randon_range = 100 + w, h = image.shape[:2] + channel = len(np.shape(image)) + + if self.auto_param: + self.ratio = np.random.uniform(0, 0.001) + self.k_x_range = (1, 0.1 * w) + self.k_y_range = (1, 0.1 * h) + + for _ in range(5): + if channel == 3: + noise = np.ones((w, h, 3), dtype=np.uint8) * 255 + dst = np.ones((w, h, 3), dtype=np.uint8) * 255 + else: + noise = np.ones((w, h), dtype=np.uint8) * 255 + dst = np.ones((w, h), dtype=np.uint8) * 255 + + rate = self.ratio / 5 + mask = np.random.uniform(size=(w, h)) < rate + noise[mask] = np.random.randint(0, randon_range) + + k_x, k_y = np.random.randint(*self.k_x_range), np.random.randint(*self.k_y_range) + kernel = np.ones((k_x, k_y), np.uint8) + erode = cv2.erode(noise, kernel, iterations=1) + dst = erode * (erode < randon_range) + dst * (1 - erode < randon_range) + # Add black point + for _ in range(np.random.randint(math.ceil(k_x * k_y / 2))): + x = np.random.randint(-k_x, k_x) + y = np.random.randint(-k_y, k_y) + matrix = np.array([[1, 0, y], [0, 1, x]], dtype=np.float) + affine = cv2.warpAffine(noise, matrix, (h, w)) + dst = affine * (affine < randon_range) + dst * (1 - affine < randon_range) + # Add white point + for _ in range(int(k_x * k_y / 2)): + x = np.random.randint(-k_x / 2 - 1, k_x / 2 + 1) + y = np.random.randint(-k_y / 2 - 1, k_y / 2 + 1) + matrix = np.array([[1, 0, y], [0, 1, x]], dtype=np.float) + affine = cv2.warpAffine(noise, matrix, (h, w)) + white = affine < randon_range + dst[white] = 255 + + mask = dst < randon_range + dst = image * (1 - mask) + dst * mask + dst = self._original_format(dst, chw, normalized, gray3dim) + + return dst.astype(ori_dtype) diff --git a/mindarmour/natural_robustness/luminance.py b/mindarmour/natural_robustness/luminance.py new file mode 100644 index 0000000..2311e23 --- /dev/null +++ b/mindarmour/natural_robustness/luminance.py @@ -0,0 +1,287 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Image luminance. +""" +import math +import numpy as np +import cv2 + +from mindarmour.natural_robustness.natural_perturb import _NaturalPerturb +from mindarmour.utils._check_param import check_param_multi_types, check_param_in_range, check_param_type, \ + check_value_non_negative +from mindarmour.utils.logger import LogUtil + +LOGGER = LogUtil.get_instance() +TAG = 'Image Luminance' + + +class Contrast(_NaturalPerturb): + """ + Contrast of an image. + + Args: + alpha (Union[float, int]): Control the contrast of an image. :math:`out_image = in_image*alpha+beta`. + Suggested value range in [0.2, 2]. + beta (Union[float, int]): Delta added to alpha. Default: 0. + auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image. + + Example: + >>> img = cv2.imread('1.png') + >>> img = np.array(img) + >>> alpha = 0.1 + >>> beta = 1 + >>> trans = Contrast(alpha, beta) + >>> dst = trans(img) + """ + + def __init__(self, alpha=1, beta=0, auto_param=False): + super(Contrast, self).__init__() + self.alpha = check_param_multi_types('factor', alpha, [int, float]) + self.beta = check_param_multi_types('factor', beta, [int, float]) + auto_param = check_param_type('auto_param', auto_param, bool) + if auto_param: + self.alpha = np.random.uniform(0.2, 2) + self.beta = np.random.uniform(-20, 20) + + def __call__(self, image): + """ + Transform the image. + + Args: + image (numpy.ndarray): Original image to be transformed. + + Returns: + numpy.ndarray, transformed image. + """ + ori_dtype = image.dtype + _, chw, normalized, gray3dim, image = self._check(image) + dst = cv2.convertScaleAbs(image, alpha=self.alpha, beta=self.beta) + dst = self._original_format(dst, chw, normalized, gray3dim) + return dst.astype(ori_dtype) + + +def _circle_gradient_mask(img_src, color_start, color_end, scope=0.5, point=None): + """ + Generate circle gradient mask. + + Args: + img_src (numpy.ndarray): Source image. + color_start (union([tuple, list])): Color of circle gradient center. + color_end (union([tuple, list])): Color of circle gradient edge. + scope (float): Range of the gradient. A larger value indicates a larger gradient range. + point (union([tuple, list]): Gradient center point. + + Returns: + numpy.ndarray, gradients mask. + """ + if not isinstance(img_src, np.ndarray): + raise TypeError('`src` must be numpy.ndarray type, but got {0}.'.format(type(img_src))) + + shape = img_src.shape + height, width = shape[:2] + rgb = False + if len(shape) == 3: + rgb = True + if point is None: + point = (height // 2, width // 2) + x, y = point + + # upper left + bound_upper_left = math.ceil(math.sqrt(x ** 2 + y ** 2)) + # upper right + bound_upper_right = math.ceil(math.sqrt(height ** 2 + (width - y) ** 2)) + # lower left + bound_lower_left = math.ceil(math.sqrt((height - x) ** 2 + y ** 2)) + # lower right + bound_lower_right = math.ceil(math.sqrt((height - x) ** 2 + (width - y) ** 2)) + + radius = max(bound_lower_left, bound_lower_right, bound_upper_left, bound_upper_right) * scope + + img_grad = np.ones_like(img_src, dtype=np.uint8) * max(color_end) + # opencv use BGR format + grad_b = float(color_end[0] - color_start[0]) / radius + grad_g = float(color_end[1] - color_start[1]) / radius + grad_r = float(color_end[2] - color_start[2]) / radius + + for i in range(height): + for j in range(width): + distance = math.ceil(math.sqrt((x - i) ** 2 + (y - j) ** 2)) + if distance >= radius: + continue + if rgb: + img_grad[i, j, 0] = color_start[0] + distance * grad_b + img_grad[i, j, 1] = color_start[1] + distance * grad_g + img_grad[i, j, 2] = color_start[2] + distance * grad_r + else: + img_grad[i, j] = color_start[0] + distance * grad_b + + return img_grad.astype(np.uint8) + + +def _line_gradient_mask(image, start_pos=None, start_color=(0, 0, 0), end_color=(255, 255, 255), mode='horizontal'): + """ + Generate liner gradient mask. + + Args: + image (numpy.ndarray): Original image. + start_pos (union[tuple, list]): 2D coordinate of gradient center. + start_color (union([tuple, list])): Color of circle gradient center. + end_color (union([tuple, list])): Color of circle gradient edge. + mode (str): Direction of gradient. Optional value is 'vertical' or 'horizontal'. + + Returns: + numpy.ndarray, gradients mask. + """ + shape = image.shape + h, w = shape[:2] + rgb = False + if len(shape) == 3: + rgb = True + if start_pos is None: + start_pos = 0.5 + else: + if mode == 'horizontal': + if start_pos[0] > h: + start_pos = 1 + else: + start_pos = start_pos[0] / h + else: + if start_pos[1] > w: + start_pos = 1 + else: + start_pos = start_pos[1] / w + start_color = np.array(start_color) + end_color = np.array(end_color) + if mode == 'horizontal': + w_l = int(w * start_pos) + w_r = w - w_l + if w_l > w_r: + r_end_color = (end_color - start_color) / start_pos * (1 - start_pos) + start_color + left = np.linspace(end_color, start_color, w_l) + right = np.linspace(start_color, r_end_color, w_r) + else: + l_end_color = (end_color - start_color) / (1 - start_pos) * start_pos + start_color + left = np.linspace(l_end_color, start_color, w_l) + right = np.linspace(start_color, end_color, w_r) + line = np.concatenate((left, right), axis=0) + mask = np.reshape(np.tile(line, (h, 1)), (h, w, 3)) + else: + # 'vertical' + h_t = int(h * start_pos) + h_b = h - h_t + if h_t > h_b: + b_end_color = (end_color - start_color) / start_pos * (1 - start_pos) + start_color + top = np.linspace(end_color, start_color, h_t) + bottom = np.linspace(start_color, b_end_color, h_b) + else: + t_end_color = (end_color - start_color) / (1 - start_pos) * start_pos + start_color + top = np.linspace(t_end_color, start_color, h_t) + bottom = np.linspace(start_color, end_color, h_b) + line = np.concatenate((top, bottom), axis=0) + mask = np.reshape(np.tile(line, (w, 1)), (w, h, 3)) + mask = np.transpose(mask, [1, 0, 2]) + if not rgb: + mask = mask[:, :, 0] + return mask.astype(np.uint8) + + +class GradientLuminance(_NaturalPerturb): + """ + Gradient adjusts the luminance of picture. + + Args: + color_start (union[tuple, list]): Color of gradient center. Default:(0, 0, 0). + color_end (union[tuple, list]): Color of gradient edge. Default:(255, 255, 255). + start_point (union[tuple, list]): 2D coordinate of gradient center. + scope (float): Range of the gradient. A larger value indicates a larger gradient range. Default: 0.3. + pattern (str): Dark or light, this value must be in ['light', 'dark']. + bright_rate (float): Control brightness of . A larger value indicates a larger gradient range. If parameter + 'pattern' is 'light', Suggested value range in [0.1, 0.7], if parameter 'pattern' is 'dark', Suggested value + range in [0.1, 0.9]. + mode (str): Gradient mode, value must be in ['circle', 'horizontal', 'vertical']. + auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image. + + Examples: + >>> img = cv2.imread('x.png') + >>> height, width = img.shape[:2] + >>> point = (height // 4, width // 2) + >>> start = (255, 255, 255) + >>> end = (0, 0, 0) + >>> scope = 0.3 + >>> pattern='light' + >>> bright_rate = 0.3 + >>> trans = GradientLuminance(start, end, point, scope, pattern, bright_rate, mode='circle') + >>> img_new = trans(img) + """ + + def __init__(self, color_start=(0, 0, 0), color_end=(255, 255, 255), start_point=(10, 10), scope=0.5, + pattern='light', bright_rate=0.3, mode='circle', auto_param=False): + super(GradientLuminance, self).__init__() + self.color_start = check_param_multi_types('color_start', color_start, [list, tuple]) + self.color_end = check_param_multi_types('color_end', color_end, [list, tuple]) + self.start_point = check_param_multi_types('start_point', start_point, [list, tuple]) + self.scope = check_value_non_negative('scope', scope) + self.bright_rate = check_param_type('bright_rate', bright_rate, float) + self.bright_rate = check_param_in_range('bright_rate', bright_rate, 0, 1) + self.auto_param = check_param_type('auto_param', auto_param, bool) + + if pattern in ['light', 'dark']: + self.pattern = pattern + else: + msg = "Value of param pattern must be in ['light', 'dark']" + LOGGER.error(TAG, msg) + raise ValueError(msg) + if mode in ['circle', 'horizontal', 'vertical']: + self.mode = mode + else: + msg = "Value of param mode must be in ['circle', 'horizontal', 'vertical']" + LOGGER.error(TAG, msg) + raise ValueError(msg) + + def _set_auto_param(self, w, h): + self.color_start = (np.random.uniform(0, 255),) * 3 + self.color_end = (np.random.uniform(0, 255),) * 3 + self.start_point = (np.random.uniform(0, w), np.random.uniform(0, h)) + self.scope = np.random.uniform(0, 1) + self.bright_rate = np.random.uniform(0.1, 0.9) + self.pattern = np.random.choice(['light', 'dark']) + self.mode = np.random.choice(['circle', 'horizontal', 'vertical']) + + def __call__(self, image): + """ + Gradient adjusts the luminance of picture. + + Args: + image (numpy.ndarray): Original image. + + Returns: + numpy.ndarray, image with perlin noise. + """ + ori_dtype = image.dtype + _, chw, normalized, gray3dim, image = self._check(image) + w, h = image.shape[:2] + if self.auto_param: + self._set_auto_param(w, h) + if self.mode == 'circle': + mask = _circle_gradient_mask(image, self.color_start, self.color_end, self.scope, self.start_point) + else: + mask = _line_gradient_mask(image, self.start_point, self.color_start, self.color_end, mode=self.mode) + + if self.pattern == 'light': + img_new = cv2.addWeighted(image, 1, mask, self.bright_rate, 0.0) + else: + img_new = cv2.addWeighted(image, self.bright_rate, mask, 1 - self.bright_rate, 0.0) + img_new = self._original_format(img_new, chw, normalized, gray3dim) + return img_new.astype(ori_dtype) diff --git a/mindarmour/natural_robustness/natural_perturb.py b/mindarmour/natural_robustness/natural_perturb.py new file mode 100644 index 0000000..db645dc --- /dev/null +++ b/mindarmour/natural_robustness/natural_perturb.py @@ -0,0 +1,159 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Base class for image natural perturbation. +""" +import numpy as np + +from mindspore.dataset.vision.py_transforms_util import is_numpy, hwc_to_chw +from mindarmour.utils.logger import LogUtil + +LOGGER = LogUtil.get_instance() +TAG = 'Image Transformation' + + +def _chw_to_hwc(img): + """ + Transpose the input image; shape (C, H, W) to shape (H, W, C). + + Args: + img (numpy.ndarray): Image to be converted. + + Returns: + img (numpy.ndarray), Converted image. + """ + if is_numpy(img): + return img.transpose(1, 2, 0).copy() + raise TypeError('img should be numpy.ndarray. Got {}'.format(type(img))) + + +def _is_hwc(img): + """ + Check if the input image is shape (H, W, C). + + Args: + img (numpy.ndarray): Image to be checked. + + Returns: + Bool, True if input is shape (H, W, C). + """ + if is_numpy(img): + img_shape = np.shape(img) + if img_shape[2] == 3 and img_shape[1] > 3 and img_shape[0] > 3: + return True + return False + raise TypeError('img should be numpy.ndarray. Got {}'.format(type(img))) + + +def _is_chw(img): + """ + Check if the input image is shape (H, W, C). + + Args: + img (numpy.ndarray): Image to be checked. + + Returns: + Bool, True if input is shape (H, W, C). + """ + if is_numpy(img): + img_shape = np.shape(img) + if img_shape[0] == 3 and img_shape[1] > 3 and img_shape[2] > 3: + return True + return False + raise TypeError('img should be numpy.ndarray. Got {}'.format(type(img))) + + +def _is_rgb(img): + """ + Check if the input image is RGB. + + Args: + img (numpy.ndarray): Image to be checked. + + Returns: + Bool, True if input is RGB. + """ + if is_numpy(img): + img_shape = np.shape(img) + if len(np.shape(img)) == 3 and (img_shape[0] == 3 or img_shape[2] == 3): + return True + return False + raise TypeError('img should be numpy.ndarray. Got {}'.format(type(img))) + + +def _is_normalized(img): + """ + Check if the input image is normalized between 0 to 1. + + Args: + img (numpy.ndarray): Image to be checked. + + Returns: + Bool, True if input is normalized between 0 to 1. + """ + if is_numpy(img): + minimal = np.min(img) + maximum = np.max(img) + if minimal >= 0 and maximum <= 1: + return True + return False + raise TypeError('img should be Numpy array. Got {}'.format(type(img))) + + +class _NaturalPerturb: + """ + The abstract base class for all image natural perturbation classes. + """ + + def __init__(self): + pass + + def _check(self, image): + """ Check image format. If input image is RGB and its shape + is (C, H, W), it will be transposed to (H, W, C). If the value + of the image is not normalized , it will be rescaled between 0 to 255.""" + rgb = _is_rgb(image) + chw = False + gray3dim = False + normalized = _is_normalized(image) + if rgb: + chw = _is_chw(image) + if chw: + image = _chw_to_hwc(image) + else: + image = image + else: + if len(np.shape(image)) == 3: + gray3dim = True + image = image[0] + else: + image = image + if normalized: + image = image * 255 + return rgb, chw, normalized, gray3dim, np.uint8(image) + + def _original_format(self, image, chw, normalized, gray3dim): + """ Return image with original format. """ + if not is_numpy(image): + image = np.array(image) + if chw: + image = hwc_to_chw(image) + if normalized: + image = image / 255 + if gray3dim: + image = np.expand_dims(image, 0) + return image + + def __call__(self, image): + pass diff --git a/mindarmour/natural_robustness/transformation.py b/mindarmour/natural_robustness/transformation.py new file mode 100644 index 0000000..589b7db --- /dev/null +++ b/mindarmour/natural_robustness/transformation.py @@ -0,0 +1,365 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Image transformation. +""" +import math +import numpy as np +import cv2 + +from mindarmour.natural_robustness.natural_perturb import _NaturalPerturb +from mindarmour.utils._check_param import check_param_multi_types, check_param_type, check_value_non_negative +from mindarmour.utils.logger import LogUtil + +LOGGER = LogUtil.get_instance() +TAG = 'Image Transformation' + + +class Translate(_NaturalPerturb): + """ + Translate an image. + + Args: + x_bias (Union[int, float]): X-direction translation, x = x + x_bias*image_length. Suggested value range + in [-0.1, 0.1]. + y_bias (Union[int, float]): Y-direction translation, y = y + y_bias*image_wide. Suggested value range + in [-0.1, 0.1]. + auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image. + + Example: + >>> img = cv2.imread('1.png') + >>> img = np.array(img) + >>> x_bias = 0.1 + >>> y_bias = 0.1 + >>> trans = Translate(x_bias, y_bias) + >>> dst = trans(img) + """ + + def __init__(self, x_bias=0, y_bias=0, auto_param=False): + super(Translate, self).__init__() + self.x_bias = check_param_multi_types('x_bias', x_bias, [int, float]) + self.y_bias = check_param_multi_types('y_bias', y_bias, [int, float]) + if auto_param: + self.x_bias = np.random.uniform(-0.1, 0.1) + self.y_bias = np.random.uniform(-0.1, 0.1) + + def __call__(self, image): + """ + Transform the image. + + Args: + image (numpy.ndarray): Original image to be transformed. + + Returns: + numpy.ndarray, transformed image. + """ + ori_dtype = image.dtype + _, chw, normalized, gray3dim, image = self._check(image) + h, w = image.shape[:2] + matrix = np.array([[1, 0, self.x_bias * w], [0, 1, self.y_bias * h]], dtype=np.float) + new_img = cv2.warpAffine(image, matrix, (w, h)) + new_img = self._original_format(new_img, chw, normalized, gray3dim) + return new_img.astype(ori_dtype) + + +class Scale(_NaturalPerturb): + """ + Scale an image in the middle. + + Args: + factor_x (Union[float, int]): Rescale in X-direction, x=factor_x*x. Suggested value range in [0.5, 1] and + abs(factor_y - factor_x) < 0.5. + factor_y (Union[float, int]): Rescale in Y-direction, y=factor_y*y. Suggested value range in [0.5, 1] and + abs(factor_y - factor_x) < 0.5. + auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image. + + Example: + >>> img = cv2.imread('1.png') + >>> img = np.array(img) + >>> factor_x = 0.7 + >>> factor_y = 0.6 + >>> trans = Scale(factor_x, factor_y) + >>> dst = trans(img) + """ + + def __init__(self, factor_x=1, factor_y=1, auto_param=False): + super(Scale, self).__init__() + self.factor_x = check_param_multi_types('factor_x', factor_x, [int, float]) + self.factor_y = check_param_multi_types('factor_y', factor_y, [int, float]) + auto_param = check_param_type('auto_param', auto_param, bool) + if auto_param: + self.factor_x = np.random.uniform(0.5, 1) + self.factor_y = np.random.uniform(0.5, 1) + + def __call__(self, image): + """ + Transform the image. + + Args: + image (numpy.ndarray): Original image to be transformed. + + Returns: + numpy.ndarray, transformed image. + """ + ori_dtype = image.dtype + _, chw, normalized, gray3dim, image = self._check(image) + h, w = image.shape[:2] + matrix = np.array([[self.factor_x, 0, 0], [0, self.factor_y, 0]], dtype=np.float) + new_img = cv2.warpAffine(image, matrix, (w, h)) + new_img = self._original_format(new_img, chw, normalized, gray3dim) + return new_img.astype(ori_dtype) + + +class Shear(_NaturalPerturb): + """ + Shear an image, for each pixel (x, y) in the sheared image, the new value is taken from a position + (x+factor_x*y, factor_y*x+y) in the origin image. Then the sheared image will be rescaled to fit original size. + + Args: + factor (Union[float, int]): Shear rate in shear direction. Suggested value range in [0.05, 0.5]. + direction (str): Direction of deformation. Optional value is 'vertical' or 'horizontal'. + auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image. + + Example: + >>> img = cv2.imread('1.png') + >>> img = np.array(img) + >>> factor = 0.2 + >>> trans = Shear(factor, direction='horizontal') + >>> dst = trans(img) + """ + + def __init__(self, factor=0.2, direction='horizontal', auto_param=False): + super(Shear, self).__init__() + self.factor = check_param_multi_types('factor', factor, [int, float]) + if direction not in ['horizontal', 'vertical']: + msg = "'direction must be in ['horizontal', 'vertical'], but got {}".format(direction) + raise ValueError(msg) + self.direction = direction + auto_param = check_param_type('auto_params', auto_param, bool) + if auto_param: + self.factor = np.random.uniform(0.05, 0.5) + self.direction = np.random.choice(['horizontal', 'vertical']) + + def __call__(self, image): + """ + Transform the image. + + Args: + image (numpy.ndarray): Original image to be transformed. + + Returns: + numpy.ndarray, transformed image. + """ + ori_dtype = image.dtype + _, chw, normalized, gray3dim, image = self._check(image) + h, w = image.shape[:2] + if self.direction == 'horizontal': + matrix = np.array([[1, self.factor, 0], [0, 1, 0]], dtype=np.float) + nw = int(w + self.factor * h) + nh = h + else: + matrix = np.array([[1, 0, 0], [self.factor, 1, 0]], dtype=np.float) + nw = w + nh = int(h + self.factor * w) + new_img = cv2.warpAffine(image, matrix, (nw, nh)) + new_img = cv2.resize(new_img, (w, h)) + new_img = self._original_format(new_img, chw, normalized, gray3dim) + return new_img.astype(ori_dtype) + + +class Rotate(_NaturalPerturb): + """ + Rotate an image of counter clockwise around its center. + + Args: + angle (Union[float, int]): Degrees of counter clockwise. Suggested value range in [-60, 60]. + auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image. + + Example: + >>> img = cv2.imread('1.png') + >>> img = np.array(img) + >>> angle = 20 + >>> trans = Rotate(angle) + >>> dst = trans(img) + """ + + def __init__(self, angle=20, auto_param=False): + super(Rotate, self).__init__() + self.angle = check_param_multi_types('angle', angle, [int, float]) + auto_param = check_param_type('auto_param', auto_param, bool) + if auto_param: + self.angle = np.random.uniform(0, 360) + + def __call__(self, image): + """ + Transform the image. + + Args: + image (numpy.ndarray): Original image to be transformed. + + Returns: + numpy.ndarray, rotated image. + """ + ori_dtype = image.dtype + _, chw, normalized, gray3dim, image = self._check(image) + h, w = image.shape[:2] + center = (w // 2, h // 2) + matrix = cv2.getRotationMatrix2D(center, -self.angle, 1.0) + cos = np.abs(matrix[0, 0]) + sin = np.abs(matrix[0, 1]) + + # Calculate new edge after rotated + nw = int((h * sin) + (w * cos)) + nh = int((h * cos) + (w * sin)) + # Adjust move distance of rotate matrix. + matrix[0, 2] += (nw / 2) - center[0] + matrix[1, 2] += (nh / 2) - center[1] + rotate = cv2.warpAffine(image, matrix, (nw, nh)) + rotate = cv2.resize(rotate, (w, h)) + new_img = self._original_format(rotate, chw, normalized, gray3dim) + return new_img.astype(ori_dtype) + + +class Perspective(_NaturalPerturb): + """ + Perform perspective transformation on a given picture. + + Args: + ori_pos (list): Four points in original image. + dst_pos (list): The point coordinates of the 4 points in ori_pos after perspective transformation. + auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image. + + Example: + >>> img = cv2.imread('1.png') + >>> img = np.array(img) + >>> ori_pos = [[0, 0], [0, 800], [800, 0], [800, 800]] + >>> dst_pos = [[50, 0], [0, 800], [780, 0], [800, 800]] + >>> trans = Perspective(ori_pos, dst_pos) + >>> dst = trans(img) + """ + + def __init__(self, ori_pos, dst_pos, auto_param=False): + super(Perspective, self).__init__() + ori_pos = check_param_type('ori_pos', ori_pos, list) + dst_pos = check_param_type('dst_pos', dst_pos, list) + self.ori_pos = np.float32(ori_pos) + self.dst_pos = np.float32(dst_pos) + self.auto_param = check_param_type('auto_param', auto_param, bool) + + def _set_auto_param(self, w, h): + self.ori_pos = [[h * 0.25, w * 0.25], [h * 0.25, w * 0.75], [h * 0.75, w * 0.25], [h * 0.75, w * 0.75]] + self.dst_pos = [[np.random.uniform(0, h * 0.5), np.random.uniform(0, w * 0.5)], + [np.random.uniform(0, h * 0.5), np.random.uniform(w * 0.5, w)], + [np.random.uniform(h * 0.5, h), np.random.uniform(0, w * 0.5)], + [np.random.uniform(h * 0.5, h), np.random.uniform(w * 0.5, w)]] + self.ori_pos = np.float32(self.ori_pos) + self.dst_pos = np.float32(self.dst_pos) + + def __call__(self, image): + """ + Transform the image. + + Args: + image (numpy.ndarray): Original image to be transformed. + + Returns: + numpy.ndarray, transformed image. + """ + ori_dtype = image.dtype + _, chw, normalized, gray3dim, image = self._check(image) + h, w = image.shape[:2] + if self.auto_param: + self._set_auto_param(w, h) + matrix = cv2.getPerspectiveTransform(self.ori_pos, self.dst_pos) + new_img = cv2.warpPerspective(image, matrix, (w, h)) + new_img = self._original_format(new_img, chw, normalized, gray3dim) + return new_img.astype(ori_dtype) + + +class Curve(_NaturalPerturb): + """ + Curve picture using sin method. + + Args: + curves (union[float, int]): Divide width to curves of `2*math.pi`, which means how many curve cycles. Suggested + value range in [0.1. 5]. + depth (union[float, int]): Amplitude of sin method. Suggested value not exceed 1/10 of the length of the + picture. + mode (str): Direction of deformation. Optional value is 'vertical' or 'horizontal'. + auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image. + + Examples: + >>> img = cv2.imread('x.png') + >>> curves =1 + >>> depth = 10 + >>> trans = Curve(curves, depth, mode='vertical') + >>> img_new = trans(img) + """ + + def __init__(self, curves=3, depth=10, mode='vertical', auto_param=False): + super(Curve).__init__() + self.curves = check_value_non_negative('curves', curves) + self.depth = check_value_non_negative('depth', depth) + if mode in ['vertical', 'horizontal']: + self.mode = mode + else: + msg = "Value of param mode must be in ['vertical', 'horizontal']" + LOGGER.error(TAG, msg) + raise ValueError(msg) + self.auto_param = check_param_type('auto_param', auto_param, bool) + + def _set_auto_params(self, height, width): + if self.auto_param: + self.curves = np.random.uniform(1, 5) + self.mode = np.random.choice(['vertical', 'horizontal']) + if self.mode == 'vertical': + self.depth = np.random.uniform(1, 0.1 * width) + else: + self.depth = np.random.uniform(1, 0.1 * height) + + def __call__(self, image): + """ + Curve picture using sin method. + + Args: + image (numpy.ndarray): Original image. + + Returns: + numpy.ndarray, curved image. + """ + ori_dtype = image.dtype + _, chw, normalized, gray3dim, image = self._check(image) + shape = image.shape + height, width = shape[:2] + if self.mode == 'vertical': + if len(shape) == 3: + image = np.transpose(image, [1, 0, 2]) + else: + image = np.transpose(image, [1, 0]) + src_x = np.zeros((height, width), np.float32) + src_y = np.zeros((height, width), np.float32) + + for y in range(height): + for x in range(width): + src_x[y, x] = x + src_y[y, x] = y + self.depth * math.sin(x / (width / self.curves / 2 / math.pi)) + img_new = cv2.remap(image, src_x, src_y, cv2.INTER_LINEAR) + + if self.mode == 'vertical': + if len(shape) == 3: + img_new = np.transpose(img_new, [1, 0, 2]) + else: + img_new = np.transpose(image, [1, 0]) + new_img = self._original_format(img_new, chw, normalized, gray3dim) + return new_img.astype(ori_dtype) diff --git a/mindarmour/reliability/model_fault_injection/__init__.py b/mindarmour/reliability/model_fault_injection/__init__.py index f2c861a..36f43bc 100644 --- a/mindarmour/reliability/model_fault_injection/__init__.py +++ b/mindarmour/reliability/model_fault_injection/__init__.py @@ -1,4 +1,5 @@ # Copyright 2021 Huawei Technologies Co., Ltd +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/mindarmour/utils/_check_param.py b/mindarmour/utils/_check_param.py index e3cd17b..1407a95 100644 --- a/mindarmour/utils/_check_param.py +++ b/mindarmour/utils/_check_param.py @@ -183,7 +183,7 @@ def check_pair_numpy_param(inputs_name, inputs, labels_name, labels): def check_equal_length(para_name1, value1, para_name2, value2): """Check weather the two parameters have equal length.""" if len(value1) != len(value2): - msg = 'The dimension of {0} must equal to the {1}, but got {0} is {2}, {1} is {3}'\ + msg = 'The dimension of {0} must equal to the {1}, but got {0} is {2}, {1} is {3}' \ .format(para_name1, para_name2, len(value1), len(value2)) LOGGER.error(TAG, msg) raise ValueError(msg) @@ -193,7 +193,7 @@ def check_equal_length(para_name1, value1, para_name2, value2): def check_equal_shape(para_name1, value1, para_name2, value2): """Check weather the two parameters have equal shape.""" if value1.shape != value2.shape: - msg = 'The shape of {0} must equal to the {1}, but got {0} is {2}, {1} is {3}'.\ + msg = 'The shape of {0} must equal to the {1}, but got {0} is {2}, {1} is {3}'. \ format(para_name1, para_name2, value1.shape, value2.shape) LOGGER.error(TAG, msg) raise ValueError(msg) @@ -204,7 +204,7 @@ def check_norm_level(norm_level): """Check norm_level of regularization.""" if not isinstance(norm_level, (int, str)): msg = 'Type of norm_level must be in [int, str], but got {}'.format(type(norm_level)) - accept_norm = [1, 2, '1', '2', 'l1', 'l2', 'inf', 'linf', np.inf] + accept_norm = [1, 2, '1', '2', 'l1', 'l2', 'inf', 'linf', 'np.inf', np.inf] if norm_level not in accept_norm: msg = 'norm_level must be in {}, but got {}'.format(accept_norm, norm_level) LOGGER.error(TAG, msg) @@ -224,8 +224,7 @@ def normalize_value(value, norm_level): numpy.ndarray, normalized value. Raises: - NotImplementedError: If norm_level is not in [1, 2 , np.inf, '1', '2', - 'inf', 'l1', 'l2'] + NotImplementedError: If norm_level is not in [1, 2 , np.inf, '1', '2', 'inf', 'l1', 'l2'] """ norm_level = check_norm_level(norm_level) ori_shape = value.shape @@ -237,7 +236,7 @@ def normalize_value(value, norm_level): elif norm_level in (2, '2', 'l2'): norm = np.linalg.norm(value_reshape, ord=2, axis=1, keepdims=True) + avoid_zero_div norm_value = value_reshape / norm - elif norm_level in (np.inf, 'inf'): + elif norm_level in (np.inf, 'inf', 'np.inf', 'linf'): norm = np.max(abs(value_reshape), axis=1, keepdims=True) + avoid_zero_div norm_value = value_reshape / norm else: diff --git a/tests/ut/python/fuzzing/test_coverage_metrics.py b/tests/ut/python/fuzzing/test_coverage_metrics.py index 61f34b4..d05ef59 100644 --- a/tests/ut/python/fuzzing/test_coverage_metrics.py +++ b/tests/ut/python/fuzzing/test_coverage_metrics.py @@ -75,11 +75,11 @@ def test_lenet_mnist_coverage_cpu(): model = Model(net) # initialize fuzz test with training dataset - training_data = (np.random.random((10000, 10))*20).astype(np.float32) + training_data = (np.random.random((10000, 10)) * 20).astype(np.float32) # fuzz test with original test data # get test data - test_data = (np.random.random((2000, 10))*20).astype(np.float32) + test_data = (np.random.random((2000, 10)) * 20).astype(np.float32) test_labels = np.random.randint(0, 10, 2000).astype(np.int32) nc = NeuronCoverage(model, threshold=0.1) @@ -118,6 +118,7 @@ def test_lenet_mnist_coverage_cpu(): print('NC of adv data is: ', nc_metric) print('TKNC of adv data is: ', tknc_metrics) + @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training @@ -130,11 +131,11 @@ def test_lenet_mnist_coverage_ascend(): model = Model(net) # initialize fuzz test with training dataset - training_data = (np.random.random((10000, 10))*20).astype(np.float32) + training_data = (np.random.random((10000, 10)) * 20).astype(np.float32) # fuzz test with original test data # get test data - test_data = (np.random.random((2000, 10))*20).astype(np.float32) + test_data = (np.random.random((2000, 10)) * 20).astype(np.float32) nc = NeuronCoverage(model, threshold=0.1) nc_metric = nc.get_metrics(test_data) diff --git a/tests/ut/python/fuzzing/test_fuzzer.py b/tests/ut/python/fuzzing/test_fuzzer.py index 1d585d4..34be0c5 100644 --- a/tests/ut/python/fuzzing/test_fuzzer.py +++ b/tests/ut/python/fuzzing/test_fuzzer.py @@ -99,15 +99,17 @@ def test_fuzzing_ascend(): model = Model(net) batch_size = 8 num_classe = 10 - mutate_config = [{'method': 'Blur', - 'params': {'auto_param': [True]}}, + mutate_config = [{'method': 'GaussianBlur', + 'params': {'ksize': [1, 2, 3, 5], + 'auto_param': [True, False]}}, + {'method': 'UniformNoise', + 'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}}, {'method': 'Contrast', - 'params': {'factor': [2, 1]}}, - {'method': 'Translate', - 'params': {'x_bias': [0.1, 0.3], 'y_bias': [0.2]}}, + 'params': {'alpha': [0.5, 1, 1.5], 'beta': [-10, 0, 10], 'auto_param': [False, True]}}, + {'method': 'Rotate', + 'params': {'angle': [20, 90], 'auto_param': [False, True]}}, {'method': 'FGSM', - 'params': {'eps': [0.1, 0.2, 0.3], 'alpha': [0.1]}} - ] + 'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1], 'bounds': [(0, 1)]}}] train_images = np.random.rand(32, 1, 32, 32).astype(np.float32) # fuzz test with original test data @@ -142,15 +144,17 @@ def test_fuzzing_cpu(): model = Model(net) batch_size = 8 num_classe = 10 - mutate_config = [{'method': 'Blur', - 'params': {'auto_param': [True]}}, + mutate_config = [{'method': 'GaussianBlur', + 'params': {'ksize': [1, 2, 3, 5], + 'auto_param': [True, False]}}, + {'method': 'UniformNoise', + 'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}}, {'method': 'Contrast', - 'params': {'factor': [2, 1]}}, - {'method': 'Translate', - 'params': {'x_bias': [0.1, 0.3], 'y_bias': [0.2]}}, + 'params': {'alpha': [0.5, 1, 1.5], 'beta': [-10, 0, 10], 'auto_param': [False, True]}}, + {'method': 'Rotate', + 'params': {'angle': [20, 90], 'auto_param': [False, True]}}, {'method': 'FGSM', - 'params': {'eps': [0.1, 0.2, 0.3], 'alpha': [0.1]}} - ] + 'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1], 'bounds': [(0, 1)]}}] # initialize fuzz test with training dataset train_images = np.random.rand(32, 1, 32, 32).astype(np.float32) diff --git a/tests/ut/python/fuzzing/test_image_transform.py b/tests/ut/python/fuzzing/test_image_transform.py deleted file mode 100644 index 9360746..0000000 --- a/tests/ut/python/fuzzing/test_image_transform.py +++ /dev/null @@ -1,126 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Image transform test. -""" -import numpy as np -import pytest - -from mindarmour.utils.logger import LogUtil -from mindarmour.fuzz_testing.image_transform import Contrast, Brightness, \ - Blur, Noise, Translate, Scale, Shear, Rotate - -LOGGER = LogUtil.get_instance() -TAG = 'Image transform test' -LOGGER.set_level('INFO') - - -@pytest.mark.level0 -@pytest.mark.platform_x86_cpu -@pytest.mark.env_onecard -@pytest.mark.component_mindarmour -def test_contrast(): - image = (np.random.rand(32, 32)).astype(np.float32) - trans = Contrast() - trans.set_params(auto_param=True) - _ = trans.transform(image) - - -@pytest.mark.level0 -@pytest.mark.platform_x86_cpu -@pytest.mark.env_onecard -@pytest.mark.component_mindarmour -def test_brightness(): - image = (np.random.rand(32, 32)).astype(np.float32) - trans = Brightness() - trans.set_params(auto_param=True) - _ = trans.transform(image) - - -@pytest.mark.level0 -@pytest.mark.platform_x86_cpu -@pytest.mark.platform_x86_ascend_training -@pytest.mark.platform_arm_ascend_training -@pytest.mark.env_onecard -@pytest.mark.component_mindarmour -def test_blur(): - image = (np.random.rand(32, 32)).astype(np.float32) - trans = Blur() - trans.set_params(auto_param=True) - _ = trans.transform(image) - - -@pytest.mark.level0 -@pytest.mark.platform_x86_cpu -@pytest.mark.platform_x86_ascend_training -@pytest.mark.platform_arm_ascend_training -@pytest.mark.env_onecard -@pytest.mark.component_mindarmour -def test_noise(): - image = (np.random.rand(32, 32)).astype(np.float32) - trans = Noise() - trans.set_params(auto_param=True) - _ = trans.transform(image) - - -@pytest.mark.level0 -@pytest.mark.platform_x86_cpu -@pytest.mark.platform_x86_ascend_training -@pytest.mark.platform_arm_ascend_training -@pytest.mark.env_onecard -@pytest.mark.component_mindarmour -def test_translate(): - image = (np.random.rand(32, 32)).astype(np.float32) - trans = Translate() - trans.set_params(auto_param=True) - _ = trans.transform(image) - - -@pytest.mark.level0 -@pytest.mark.platform_x86_cpu -@pytest.mark.platform_x86_ascend_training -@pytest.mark.platform_arm_ascend_training -@pytest.mark.env_onecard -@pytest.mark.component_mindarmour -def test_shear(): - image = (np.random.rand(32, 32)).astype(np.float32) - trans = Shear() - trans.set_params(auto_param=True) - _ = trans.transform(image) - - -@pytest.mark.level0 -@pytest.mark.platform_x86_cpu -@pytest.mark.platform_x86_ascend_training -@pytest.mark.platform_arm_ascend_training -@pytest.mark.env_onecard -@pytest.mark.component_mindarmour -def test_scale(): - image = (np.random.rand(32, 32)).astype(np.float32) - trans = Scale() - trans.set_params(auto_param=True) - _ = trans.transform(image) - - -@pytest.mark.level0 -@pytest.mark.platform_x86_cpu -@pytest.mark.platform_x86_ascend_training -@pytest.mark.platform_arm_ascend_training -@pytest.mark.env_onecard -@pytest.mark.component_mindarmour -def test_rotate(): - image = (np.random.rand(32, 32)).astype(np.float32) - trans = Rotate() - trans.set_params(auto_param=True) - _ = trans.transform(image) diff --git a/tests/ut/python/natural_robustness/test_natural_robustness.py b/tests/ut/python/natural_robustness/test_natural_robustness.py new file mode 100644 index 0000000..1f4d5ff --- /dev/null +++ b/tests/ut/python/natural_robustness/test_natural_robustness.py @@ -0,0 +1,577 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Example for natural robustness methods.""" + +import pytest +import numpy as np +from mindspore import context + +from mindarmour.natural_robustness import Translate, Curve, Perspective, Scale, Shear, Rotate, SaltAndPepperNoise, \ + NaturalNoise, GaussianNoise, UniformNoise, MotionBlur, GaussianBlur, GradientBlur, Contrast, GradientLuminance + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_card +@pytest.mark.component_mindarmour +def test_perspective(): + """ + Feature: Test image perspective. + Description: Image will be transform for given perspective projection. + Expectation: success. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + image = np.random.random((32, 32, 3)) + ori_pos = [[0, 0], [0, 800], [800, 0], [800, 800]] + dst_pos = [[50, 0], [0, 800], [780, 0], [800, 800]] + trans = Perspective(ori_pos, dst_pos) + dst = trans(image) + print(dst) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_card +@pytest.mark.component_mindarmour +def test_uniform_noise(): + """ + Feature: Test image uniform noise. + Description: Add uniform image in image. + Expectation: success. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + image = np.random.random((32, 32, 3)) + trans = UniformNoise(factor=0.1) + dst = trans(image) + print(dst) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_card +@pytest.mark.component_mindarmour +def test_gaussian_noise(): + """ + Feature: Test image gaussian noise. + Description: Add gaussian image in image. + Expectation: success. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + image = np.random.random((32, 32, 3)) + trans = GaussianNoise(factor=0.1) + dst = trans(image) + print(dst) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_card +@pytest.mark.component_mindarmour +def test_contrast(): + """ + Feature: Test image contrast. + Description: Adjust image contrast. + Expectation: success. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + image = np.random.random((32, 32, 3)) + trans = Contrast(alpha=0.3, beta=0) + dst = trans(image) + print(dst) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_card +@pytest.mark.component_mindarmour +def test_gaussian_blur(): + """ + Feature: Test image gaussian blur. + Description: Add gaussian blur to image. + Expectation: success. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + image = np.random.random((32, 32, 3)) + trans = GaussianBlur(ksize=5) + dst = trans(image) + print(dst) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_card +@pytest.mark.component_mindarmour +def test_salt_and_pepper_noise(): + """ + Feature: Test image salt and pepper noise. + Description: Add salt and pepper to image. + Expectation: success. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + image = np.random.random((32, 32, 3)) + trans = SaltAndPepperNoise(factor=0.01) + dst = trans(image) + print(dst) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_card +@pytest.mark.component_mindarmour +def test_translate(): + """ + Feature: Test image translate. + Description: Translate an image. + Expectation: success. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + image = np.random.random((32, 32, 3)) + trans = Translate(x_bias=0.1, y_bias=0.1) + dst = trans(image) + print(dst) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_card +@pytest.mark.component_mindarmour +def test_scale(): + """ + Feature: Test image scale. + Description: Scale an image. + Expectation: success. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + image = np.random.random((32, 32, 3)) + trans = Scale(factor_x=0.7, factor_y=0.7) + dst = trans(image) + print(dst) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_card +@pytest.mark.component_mindarmour +def test_shear(): + """ + Feature: Test image shear. + Description: Shear an image. + Expectation: success. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + image = np.random.random((32, 32, 3)) + trans = Shear(factor=0.2) + dst = trans(image) + print(dst) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_card +@pytest.mark.component_mindarmour +def test_rotate(): + """ + Feature: Test image rotate. + Description: Rotate an image. + Expectation: success. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + image = np.random.random((32, 32, 3)) + trans = Rotate(angle=20) + dst = trans(image) + print(dst) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_card +@pytest.mark.component_mindarmour +def test_curve(): + """ + Feature: Test image curve. + Description: Transform an image with curve. + Expectation: success. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + image = np.random.random((32, 32, 3)) + trans = Curve(curves=1.5, depth=1.5, mode='horizontal') + dst = trans(image) + print(dst) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_card +@pytest.mark.component_mindarmour +def test_natural_noise(): + """ + Feature: Test natural noise. + Description: Add natural noise to an. + Expectation: success. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + image = np.random.random((32, 32, 3)) + trans = NaturalNoise(ratio=0.0001, k_x_range=(1, 30), k_y_range=(1, 10), auto_param=True) + dst = trans(image) + print(dst) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_card +@pytest.mark.component_mindarmour +def test_gradient_luminance(): + """ + Feature: Test gradient luminance. + Description: Adjust image luminance. + Expectation: success. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + image = np.random.random((32, 32, 3)) + height, width = image.shape[:2] + point = (height // 4, width // 2) + start = (255, 255, 255) + end = (0, 0, 0) + scope = 0.3 + bright_rate = 0.4 + trans = GradientLuminance(start, end, start_point=point, scope=scope, pattern='dark', bright_rate=bright_rate, + mode='horizontal') + dst = trans(image) + print(dst) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_card +@pytest.mark.component_mindarmour +def test_motion_blur(): + """ + Feature: Test motion blur. + Description: Add motion blur to an image. + Expectation: success. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + image = np.random.random((32, 32, 3)) + angle = -10.5 + i = 3 + trans = MotionBlur(degree=i, angle=angle) + dst = trans(image) + print(dst) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_card +@pytest.mark.component_mindarmour +def test_gradient_blur(): + """ + Feature: Test gradient blur. + Description: Add gradient blur to an image. + Expectation: success. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + image = np.random.random((32, 32, 3)) + number = 10 + h, w = image.shape[:2] + point = (int(h / 5), int(w / 5)) + center = False + trans = GradientBlur(point, number, center) + dst = trans(image) + print(dst) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_card +@pytest.mark.component_mindarmour +def test_perspective_ascend(): + """ + Feature: Test image perspective. + Description: Image will be transform for given perspective projection. + Expectation: success. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + image = np.random.random((32, 32, 3)) + ori_pos = [[0, 0], [0, 800], [800, 0], [800, 800]] + dst_pos = [[50, 0], [0, 800], [780, 0], [800, 800]] + trans = Perspective(ori_pos, dst_pos) + dst = trans(image) + print(dst) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_card +@pytest.mark.component_mindarmour +def test_uniform_noise_ascend(): + """ + Feature: Test image uniform noise. + Description: Add uniform image in image. + Expectation: success. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + image = np.random.random((32, 32, 3)) + trans = UniformNoise(factor=0.1) + dst = trans(image) + print(dst) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_card +@pytest.mark.component_mindarmour +def test_gaussian_noise_ascend(): + """ + Feature: Test image gaussian noise. + Description: Add gaussian image in image. + Expectation: success. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + image = np.random.random((32, 32, 3)) + trans = GaussianNoise(factor=0.1) + dst = trans(image) + print(dst) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_card +@pytest.mark.component_mindarmour +def test_contrast_ascend(): + """ + Feature: Test image contrast. + Description: Adjust image contrast. + Expectation: success. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + image = np.random.random((32, 32, 3)) + trans = Contrast(alpha=0.3, beta=0) + dst = trans(image) + print(dst) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_card +@pytest.mark.component_mindarmour +def test_gaussian_blur_ascend(): + """ + Feature: Test image gaussian blur. + Description: Add gaussian blur to image. + Expectation: success. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + image = np.random.random((32, 32, 3)) + trans = GaussianBlur(ksize=5) + dst = trans(image) + print(dst) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_card +@pytest.mark.component_mindarmour +def test_salt_and_pepper_noise_ascend(): + """ + Feature: Test image salt and pepper noise. + Description: Add salt and pepper to image. + Expectation: success. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + image = np.random.random((32, 32, 3)) + trans = SaltAndPepperNoise(factor=0.01) + dst = trans(image) + print(dst) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_card +@pytest.mark.component_mindarmour +def test_translate_ascend(): + """ + Feature: Test image translate. + Description: Translate an image. + Expectation: success. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + image = np.random.random((32, 32, 3)) + trans = Translate(x_bias=0.1, y_bias=0.1) + dst = trans(image) + print(dst) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_card +@pytest.mark.component_ascend_mindarmour +def test_scale_ascend(): + """ + Feature: Test image scale. + Description: Scale an image. + Expectation: success. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + image = np.random.random((32, 32, 3)) + trans = Scale(factor_x=0.7, factor_y=0.7) + dst = trans(image) + print(dst) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_card +@pytest.mark.component_mindarmour +def test_shear_ascend(): + """ + Feature: Test image shear. + Description: Shear an image. + Expectation: success. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + image = np.random.random((32, 32, 3)) + trans = Shear(factor=0.2) + dst = trans(image) + print(dst) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_card +@pytest.mark.component_mindarmour +def test_rotate_ascend(): + """ + Feature: Test image rotate. + Description: Rotate an image. + Expectation: success. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + image = np.random.random((32, 32, 3)) + trans = Rotate(angle=20) + dst = trans(image) + print(dst) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_card +@pytest.mark.component_mindarmour +def test_curve_ascend(): + """ + Feature: Test image curve. + Description: Transform an image with curve. + Expectation: success. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + image = np.random.random((32, 32, 3)) + trans = Curve(curves=1.5, depth=1.5, mode='horizontal') + dst = trans(image) + print(dst) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_card +@pytest.mark.component_mindarmour +def test_natural_noise_ascend(): + """ + Feature: Test natural noise. + Description: Add natural noise to an. + Expectation: success. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + image = np.random.random((32, 32, 3)) + trans = NaturalNoise(ratio=0.0001, k_x_range=(1, 30), k_y_range=(1, 10), auto_param=True) + dst = trans(image) + print(dst) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_card +@pytest.mark.component_mindarmour +def test_gradient_luminance_ascend(): + """ + Feature: Test gradient luminance. + Description: Adjust image luminance. + Expectation: success. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + image = np.random.random((32, 32, 3)) + height, width = image.shape[:2] + point = (height // 4, width // 2) + start = (255, 255, 255) + end = (0, 0, 0) + scope = 0.3 + bright_rate = 0.4 + trans = GradientLuminance(start, end, start_point=point, scope=scope, pattern='dark', bright_rate=bright_rate, + mode='horizontal') + dst = trans(image) + print(dst) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_card +@pytest.mark.component_mindarmour +def test_motion_blur_ascend(): + """ + Feature: Test motion blur. + Description: Add motion blur to an image. + Expectation: success. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + image = np.random.random((32, 32, 3)) + angle = -10.5 + i = 3 + trans = MotionBlur(degree=i, angle=angle) + dst = trans(image) + print(dst) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_card +@pytest.mark.component_mindarmour +def test_gradient_blur_ascend(): + """ + Feature: Test gradient blur. + Description: Add gradient blur to an image. + Expectation: success. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + image = np.random.random((32, 32, 3)) + number = 10 + h, w = image.shape[:2] + point = (int(h / 5), int(w / 5)) + center = False + trans = GradientBlur(point, number, center) + dst = trans(image) + print(dst)