Browse Source

real world robustness samples

pull/318/head
ZhidanLiu 3 years ago
parent
commit
14c0edc4c2
26 changed files with 3003 additions and 901 deletions
  1. +0
    -1
      .gitignore
  2. +72
    -67
      examples/ai_fuzzer/fuzz_testing_and_model_enhense.py
  3. +48
    -19
      examples/ai_fuzzer/lenet5_mnist_fuzzing.py
  4. +176
    -0
      examples/natural_robustness/natural_robustness_example.py
  5. +206
    -0
      examples/natural_robustness/serving/README.md
  6. +41
    -0
      examples/natural_robustness/serving/client/perturb_config.py
  7. +61
    -0
      examples/natural_robustness/serving/client/serving_client.py
  8. +58
    -0
      examples/natural_robustness/serving/server/export_model/add_model.py
  9. +109
    -0
      examples/natural_robustness/serving/server/perturbation/servable_config.py
  10. +35
    -0
      examples/natural_robustness/serving/server/serving_server.py
  11. +4
    -4
      mindarmour/adv_robustness/attacks/iterative_gradient_method.py
  12. +111
    -30
      mindarmour/fuzz_testing/fuzzing.py
  13. +0
    -609
      mindarmour/fuzz_testing/image_transform.py
  14. +184
    -21
      mindarmour/fuzz_testing/model_coverage_metrics.py
  15. +37
    -0
      mindarmour/natural_robustness/__init__.py
  16. +193
    -0
      mindarmour/natural_robustness/blur.py
  17. +251
    -0
      mindarmour/natural_robustness/corruption.py
  18. +287
    -0
      mindarmour/natural_robustness/luminance.py
  19. +159
    -0
      mindarmour/natural_robustness/natural_perturb.py
  20. +365
    -0
      mindarmour/natural_robustness/transformation.py
  21. +1
    -0
      mindarmour/reliability/model_fault_injection/__init__.py
  22. +5
    -6
      mindarmour/utils/_check_param.py
  23. +5
    -4
      tests/ut/python/fuzzing/test_coverage_metrics.py
  24. +18
    -14
      tests/ut/python/fuzzing/test_fuzzer.py
  25. +0
    -126
      tests/ut/python/fuzzing/test_image_transform.py
  26. +577
    -0
      tests/ut/python/natural_robustness/test_natural_robustness.py

+ 0
- 1
.gitignore View File

@@ -19,7 +19,6 @@ example/mnist_demo/model/
example/cifar_demo/model/
example/dog_cat_demo/model/
mindarmour.egg-info/
*model/
*MNIST/
*out.data/
*defensed_model/


+ 72
- 67
examples/ai_fuzzer/fuzz_testing_and_model_enhense.py View File

@@ -27,7 +27,7 @@ from mindspore.nn.optim.momentum import Momentum

from mindarmour.adv_robustness.defenses import AdversarialDefense
from mindarmour.fuzz_testing import Fuzzer
from mindarmour.fuzz_testing import ModelCoverageMetrics
from mindarmour.fuzz_testing import KMultisectionNeuronCoverage
from mindarmour.utils.logger import LogUtil

from examples.common.dataset.data_processing import generate_mnist_dataset
@@ -38,33 +38,66 @@ TAG = 'Fuzz_testing and enhance model'
LOGGER.set_level('INFO')


def split_dataset(image, label, proportion):
"""
Split the generated fuzz data into train and test set.
"""
indices = np.arange(len(image))
random.shuffle(indices)
train_length = int(len(image) * proportion)
train_image = [image[i] for i in indices[:train_length]]
train_label = [label[i] for i in indices[:train_length]]
test_image = [image[i] for i in indices[:train_length]]
test_label = [label[i] for i in indices[:train_length]]
return train_image, train_label, test_image, test_label


def example_lenet_mnist_fuzzing():
"""
An example of fuzz testing and then enhance the non-robustness model.
"""
# upload trained network
ckpt_path = '../common/networks/lenet5/trained_ckpt_file/lenet_m1-10_1250.ckpt'
ckpt_path = '../common/networks/lenet5/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
net = LeNet5()
load_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, load_dict)
model = Model(net)
mutate_config = [{'method': 'Blur',
'params': {'auto_param': [True]}},
{'method': 'Contrast',
'params': {'auto_param': [True]}},
{'method': 'Translate',
'params': {'auto_param': [True]}},
{'method': 'Brightness',
'params': {'auto_param': [True]}},
{'method': 'Noise',
'params': {'auto_param': [True]}},
{'method': 'Scale',
'params': {'auto_param': [True]}},
{'method': 'Shear',
'params': {'auto_param': [True]}},
{'method': 'FGSM',
'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1]}}
]
mutate_config = [
{'method': 'GaussianBlur',
'params': {'ksize': [1, 2, 3, 5], 'auto_param': [True, False]}},
{'method': 'MotionBlur',
'params': {'degree': [1, 2, 5], 'angle': [45, 10, 100, 140, 210, 270, 300], 'auto_param': [True]}},
{'method': 'GradientBlur',
'params': {'point': [[10, 10]], 'auto_param': [True]}},
{'method': 'UniformNoise',
'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}},
{'method': 'GaussianNoise',
'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}},
{'method': 'SaltAndPepperNoise',
'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}},
{'method': 'NaturalNoise',
'params': {'ratio': [0.1], 'k_x_range': [(1, 3), (1, 5)], 'k_y_range': [(1, 5)], 'auto_param': [False, True]}},
{'method': 'Contrast',
'params': {'alpha': [0.5, 1, 1.5], 'beta': [-10, 0, 10], 'auto_param': [False, True]}},
{'method': 'GradientLuminance',
'params': {'color_start': [(0, 0, 0)], 'color_end': [(255, 255, 255)], 'start_point': [(10, 10)],
'scope': [0.5], 'pattern': ['light'], 'bright_rate': [0.3], 'mode': ['circle'],
'auto_param': [False, True]}},
{'method': 'Translate',
'params': {'x_bias': [0, 0.05, -0.05], 'y_bias': [0, -0.05, 0.05], 'auto_param': [False, True]}},
{'method': 'Scale',
'params': {'factor_x': [1, 0.9], 'factor_y': [1, 0.9], 'auto_param': [False, True]}},
{'method': 'Shear',
'params': {'factor': [0.2, 0.1], 'direction': ['horizontal', 'vertical'], 'auto_param': [False, True]}},
{'method': 'Rotate',
'params': {'angle': [20, 90], 'auto_param': [False, True]}},
{'method': 'Perspective',
'params': {'ori_pos': [[[0, 0], [0, 800], [800, 0], [800, 800]]],
'dst_pos': [[[50, 0], [0, 800], [780, 0], [800, 800]]], 'auto_param': [False, True]}},
{'method': 'Curve',
'params': {'curves': [5], 'depth': [2], 'mode': ['vertical'], 'auto_param': [False, True]}},
{'method': 'FGSM',
'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1], 'bounds': [(0, 1)]}}]

# get training data
data_list = "../common/dataset/MNIST/train"
@@ -75,49 +108,36 @@ def example_lenet_mnist_fuzzing():
images = data[0].astype(np.float32)
train_images.append(images)
train_images = np.concatenate(train_images, axis=0)
neuron_num = 10
segmented_num = 1000

# initialize fuzz test with training dataset
model_coverage_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images)
segmented_num = 100

# fuzz test with original test data
# get test data
data_list = "../common/dataset/MNIST/test"
batch_size = 32
init_samples = 5000
max_iters = 50000
batch_size = batch_size
init_samples = 50
max_iters = 500
mutate_num_per_seed = 10
ds = generate_mnist_dataset(data_list, batch_size, num_samples=init_samples,
sparse=False)
ds = generate_mnist_dataset(data_list, batch_size=batch_size, num_samples=init_samples, sparse=False)
test_images = []
test_labels = []
for data in ds.create_tuple_iterator(output_numpy=True):
images = data[0].astype(np.float32)
labels = data[1]
test_images.append(images)
test_labels.append(labels)
test_images.append(data[0].astype(np.float32))
test_labels.append(data[1])
test_images = np.concatenate(test_images, axis=0)
test_labels = np.concatenate(test_labels, axis=0)
initial_seeds = []

coverage = KMultisectionNeuronCoverage(model, train_images, segmented_num=segmented_num, incremental=True)
kmnc = coverage.get_metrics(test_images[:100])
print('kmnc: ', kmnc)

# make initial seeds
initial_seeds = []
for img, label in zip(test_images, test_labels):
initial_seeds.append([img, label])

model_coverage_test.calculate_coverage(
np.array(test_images[:100]).astype(np.float32))
LOGGER.info(TAG, 'KMNC of test dataset before fuzzing is : %s',
model_coverage_test.get_kmnc())
LOGGER.info(TAG, 'NBC of test dataset before fuzzing is : %s',
model_coverage_test.get_nbc())
LOGGER.info(TAG, 'SNAC of test dataset before fuzzing is : %s',
model_coverage_test.get_snac())

model_fuzz_test = Fuzzer(model, train_images, 10, 1000)
model_fuzz_test = Fuzzer(model)
gen_samples, gt, _, _, metrics = model_fuzz_test.fuzzing(mutate_config,
initial_seeds,
eval_metrics='auto',
initial_seeds, coverage,
evaluate=True,
max_iters=max_iters,
mutate_num_per_seed=mutate_num_per_seed)

@@ -125,24 +145,10 @@ def example_lenet_mnist_fuzzing():
for key in metrics:
LOGGER.info(TAG, key + ': %s', metrics[key])

def split_dataset(image, label, proportion):
"""
Split the generated fuzz data into train and test set.
"""
indices = np.arange(len(image))
random.shuffle(indices)
train_length = int(len(image) * proportion)
train_image = [image[i] for i in indices[:train_length]]
train_label = [label[i] for i in indices[:train_length]]
test_image = [image[i] for i in indices[:train_length]]
test_label = [label[i] for i in indices[:train_length]]
return train_image, train_label, test_image, test_label

train_image, train_label, test_image, test_label = split_dataset(
gen_samples, gt, 0.7)
train_image, train_label, test_image, test_label = split_dataset(gen_samples, gt, 0.7)

# load model B and test it on the test set
ckpt_path = '../common/networks/lenet5/trained_ckpt_file/lenet_m2-10_1250.ckpt'
ckpt_path = '../common/networks/lenet5/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
net = LeNet5()
load_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, load_dict)
@@ -154,12 +160,11 @@ def example_lenet_mnist_fuzzing():
# enhense model robustness
lr = 0.001
momentum = 0.9
loss_fn = SoftmaxCrossEntropyWithLogits(Sparse=True)
loss_fn = SoftmaxCrossEntropyWithLogits(sparse=True)
optimizer = Momentum(net.trainable_params(), lr, momentum)

adv_defense = AdversarialDefense(net, loss_fn, optimizer)
adv_defense.batch_defense(np.array(train_image).astype(np.float32),
np.argmax(train_label, axis=1).astype(np.int32))
adv_defense.batch_defense(np.array(train_image).astype(np.float32), np.argmax(train_label, axis=1).astype(np.int32))
preds_en = net(Tensor(test_image, dtype=mindspore.float32)).asnumpy()
acc_en = np.sum(np.argmax(preds_en, axis=1) == np.argmax(test_label, axis=1)) / len(test_label)
print('Accuracy of enhensed model on test set is ', acc_en)
@@ -167,5 +172,5 @@ def example_lenet_mnist_fuzzing():

if __name__ == '__main__':
# device_target can be "CPU", "GPU" or "Ascend"
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
example_lenet_mnist_fuzzing()

+ 48
- 19
examples/ai_fuzzer/lenet5_mnist_fuzzing.py View File

@@ -35,24 +35,50 @@ def test_lenet_mnist_fuzzing():
load_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, load_dict)
model = Model(net)
mutate_config = [{'method': 'Blur',
'params': {'radius': [0.1, 0.2, 0.3],
'auto_param': [True, False]}},
{'method': 'Contrast',
'params': {'auto_param': [True]}},
{'method': 'Translate',
'params': {'auto_param': [True]}},
{'method': 'Brightness',
'params': {'auto_param': [True]}},
{'method': 'Noise',
'params': {'auto_param': [True]}},
{'method': 'Scale',
'params': {'auto_param': [True]}},
{'method': 'Shear',
'params': {'auto_param': [True]}},
{'method': 'FGSM',
'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1], 'bounds': [(0, 1)]}}
]
mutate_config = [
{'method': 'GaussianBlur',
'params': {'ksize': [1, 2, 3, 5],
'auto_param': [True, False]}},
{'method': 'MotionBlur',
'params': {'degree': [1, 2, 5], 'angle': [45, 10, 100, 140, 210, 270, 300], 'auto_param': [True]}},
{'method': 'GradientBlur',
'params': {'point': [[10, 10]], 'auto_param': [True]}},
{'method': 'UniformNoise',
'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}},
{'method': 'GaussianNoise',
'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}},
{'method': 'SaltAndPepperNoise',
'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}},
{'method': 'NaturalNoise',
'params': {'ratio': [0.1, 0.2, 0.3], 'k_x_range': [(1, 3), (1, 5)], 'k_y_range': [(1, 5)],
'auto_param': [False, True]}},
{'method': 'Contrast',
'params': {'alpha': [0.5, 1, 1.5], 'beta': [-10, 0, 10], 'auto_param': [False, True]}},
{'method': 'GradientLuminance',
'params': {'color_start': [(0, 0, 0)], 'color_end': [(255, 255, 255)], 'start_point': [(10, 10)],
'scope': [0.5], 'pattern': ['light'], 'bright_rate': [0.3], 'mode': ['circle'],
'auto_param': [False, True]}},
{'method': 'Translate',
'params': {'x_bias': [0, 0.05, -0.05], 'y_bias': [0, -0.05, 0.05], 'auto_param': [False, True]}},
{'method': 'Scale',
'params': {'factor_x': [1, 0.9], 'factor_y': [1, 0.9], 'auto_param': [False, True]}},
{'method': 'Shear',
'params': {'factor': [0.2, 0.1], 'direction': ['horizontal', 'vertical'], 'auto_param': [False, True]}},
{'method': 'Rotate',
'params': {'angle': [20, 90], 'auto_param': [False, True]}},
{'method': 'Perspective',
'params': {'ori_pos': [[[0, 0], [0, 800], [800, 0], [800, 800]]],
'dst_pos': [[[50, 0], [0, 800], [780, 0], [800, 800]]], 'auto_param': [False, True]}},
{'method': 'Curve',
'params': {'curves': [5], 'depth': [2], 'mode': ['vertical'], 'auto_param': [False, True]}},
{'method': 'FGSM',
'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1], 'bounds': [(0, 1)]}},
{'method': 'PGD',
'params': {'eps': [0.1, 0.2, 0.4], 'eps_iter': [0.05, 0.1], 'nb_iter': [1, 3]}},
{'method': 'MDIIM',
'params': {'eps': [0.1, 0.2, 0.4], 'prob': [0.5, 0.1],
'norm_level': [1, 2, '1', '2', 'l1', 'l2', 'inf', 'np.inf', 'linf']}}
]

# get training data
data_list = "../common/dataset/MNIST/train"
@@ -88,7 +114,10 @@ def test_lenet_mnist_fuzzing():
print('KMNC of initial seeds is: ', kmnc)
initial_seeds = initial_seeds[:100]
model_fuzz_test = Fuzzer(model)
_, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds, coverage, evaluate=True, max_iters=10,
_, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config,
initial_seeds, coverage,
evaluate=True,
max_iters=10,
mutate_num_per_seed=20)

if metrics:


+ 176
- 0
examples/natural_robustness/natural_robustness_example.py View File

@@ -0,0 +1,176 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Example for natural robustness methods."""
import numpy as np
import cv2
from mindarmour.natural_robustness import Translate, Curve, Perspective, Scale, Shear, Rotate, SaltAndPepperNoise, \
NaturalNoise, GaussianNoise, UniformNoise, MotionBlur, GaussianBlur, GradientBlur, Contrast, GradientLuminance
def test_perspective(image):
"""Test perspective."""
ori_pos = [[0, 0], [0, 800], [800, 0], [800, 800]]
dst_pos = [[50, 0], [0, 800], [780, 0], [800, 800]]
trans = Perspective(ori_pos, dst_pos)
dst = trans(image)
cv2.imshow('dst', dst)
cv2.waitKey()
def test_uniform_noise(image):
"""Test uniform noise."""
trans = UniformNoise(factor=0.1)
dst = trans(image)
cv2.imshow('dst', dst)
cv2.waitKey()
def test_gaussian_noise(image):
"""Test gaussian noise."""
trans = GaussianNoise(factor=0.1)
dst = trans(image)
cv2.imshow('dst', dst)
cv2.waitKey()
def test_contrast(image):
"""Test contrast."""
trans = Contrast(alpha=0.3, beta=0)
dst = trans(image)
cv2.imshow('dst', dst)
cv2.waitKey()
def test_gaussian_blur(image):
"""Test gaussian blur."""
trans = GaussianBlur(ksize=5)
dst = trans(image)
cv2.imshow('dst', dst)
cv2.waitKey()
def test_salt_and_pepper_noise(image):
"""Test salt and pepper noise."""
trans = SaltAndPepperNoise(factor=0.01)
dst = trans(image)
cv2.imshow('dst', dst)
cv2.waitKey()
def test_translate(image):
"""Test translate."""
trans = Translate(x_bias=0.1, y_bias=0.1)
dst = trans(image)
cv2.imshow('dst', dst)
cv2.waitKey()
def test_scale(image):
"""Test scale."""
trans = Scale(factor_x=0.7, factor_y=0.7)
dst = trans(image)
cv2.imshow('dst', dst)
cv2.waitKey()
def test_shear(image):
"""Test shear."""
trans = Shear(factor=0.2)
dst = trans(image)
cv2.imshow('dst', dst)
cv2.waitKey()
def test_rotate(image):
"""Test rotate."""
trans = Rotate(angle=20)
dst = trans(image)
cv2.imshow('dst', dst)
cv2.waitKey()
def test_curve(image):
"""Test curve."""
trans = Curve(curves=1.5, depth=1.5, mode='horizontal')
dst = trans(image)
cv2.imshow('dst', dst)
cv2.waitKey()
def test_natural_noise(image):
"""Test natural noise."""
trans = NaturalNoise(ratio=0.0001, k_x_range=(1, 30), k_y_range=(1, 10), auto_param=True)
dst = trans(image)
cv2.imshow('dst', dst)
cv2.waitKey()
def test_gradient_luminance(image):
"""Test gradient luminance."""
height, width = image.shape[:2]
point = (height // 4, width // 2)
start = (255, 255, 255)
end = (0, 0, 0)
scope = 0.3
bright_rate = 0.4
trans = GradientLuminance(start, end, start_point=point, scope=scope, pattern='dark', bright_rate=bright_rate,
mode='horizontal')
dst = trans(image)
cv2.imshow('dst', dst)
cv2.waitKey()
def test_motion_blur(image):
"""Test motion blur."""
angle = -10.5
i = 3
trans = MotionBlur(degree=i, angle=angle)
dst = trans(image)
cv2.imshow('dst', dst)
cv2.waitKey()
def test_gradient_blur(image):
"""Test gradient blur."""
number = 10
h, w = image.shape[:2]
point = (int(h / 5), int(w / 5))
center = False
trans = GradientBlur(point, number, center)
dst = trans(image)
cv2.imshow('dst', dst)
cv2.waitKey()
if __name__ == '__main__':
img = cv2.imread('1.jpeg')
img = np.array(img)
test_uniform_noise(img)
test_gaussian_noise(img)
test_motion_blur(img)
test_gradient_blur(img)
test_gradient_luminance(img)
test_natural_noise(img)
test_curve(img)
test_rotate(img)
test_shear(img)
test_scale(img)
test_translate(img)
test_salt_and_pepper_noise(img)
test_gaussian_blur(img)
test_constract(img)
test_perspective(img)

+ 206
- 0
examples/natural_robustness/serving/README.md View File

@@ -0,0 +1,206 @@
# 自然扰动样本生成serving
提供自然扰动样本生成在线服务。客户端传入图片和扰动参数,服务端返回扰动后的图片数据。
## 环境准备
硬件环境:Ascend 910,GPU
操作系统:Linux-x86_64
软件环境:
1. python 3.7.5或python 3.9.0
2. 安装MindSpore 1.5.0可以参考[MindSpore安装页面](https://www.mindspore.cn/install)
3. 安装MindSpore Serving 1.5.0可以参考[MindSpore Serving 安装页面](https://www.mindspore.cn/serving/docs/zh-CN/r1.5/serving_install.html)
4. 安装serving分支的MindArmour:
- 从Gitee下载源码
`git clone https://gitee.com/mindspore/mindarmour.git`
- 编译并安装MindArmour
`python setup.py install`
### 文件结构说明
```bash
serving
├── server
│ ├── serving_server.py # 启动serving服务脚本
│ ├── export_model
│ │ └── add_model.py # 生成模型文件脚本
│ └── perturbation
│ └── serverable_config.py # 服务端接收客户端数据后的处理脚本
└── client
├── serving_client.py # 启动客户端脚本
└── perturb_config.py # 扰动方法配置文件
```
## 脚本说明及使用
### 导出模型
在`server/export_model`目录下,使用[add_model.py](https://gitee.com/mindspore/serving/blob/r1.5/example/tensor_add/export_model/add_model.py),构造了一个只有Add算子的tensor加法网络。使用命令
```bash
python add_model.py
```
在`perturbation`模型文件夹下生成`tensor_add.mindir`模型文件。
该服务实际上并没有使用到模型,但目前版本的serving需要有一个模型,serving升级后这部分会删除。
### 部署Serving推理服务
1. #### `servable_config.py`说明。
```python
···
# 客户端可以请求的方法,包含4个返回值:"results", "file_names", "file_length", "names_dict"
@register.register_method(output_names=["results", "file_names", "file_length", "names_dict"])
def natural_perturbation(img, perturb_config, methods_number, outputs_number):
"""method natural_perturbation data flow definition, only preprocessing and call model"""
res = register.add_stage(perturb, img, perturb_config, methods_number, outputs_number, outputs_count=4)
return res
```
方法`natural_perturbation`为对外提供服务的接口。
**输入:**
- img:输入为图片,格式为bytes。
- perturb_config:扰动配置项,具体配置参考`perturb_config.py`。
- methods_number:每次扰动随机从配置项中选择方法的个数。
- outputs_number:对于每张图片,生成的扰动图片数量。
**输出**res中包含4个参数:
- results:拼接后的图像bytes;
- file_names:图像名,格式为`xxx.png`,其中‘xxx’为A-Za-z中随机选择20个字符构成的字符串。
- file_length:每张图片的bytes长度。
- names_dict: 图片名和图片使用扰动方法构成的字典。格式为:
```bash
{
picture1.png: [[method1, parameters of method1], [method2, parameters of method2], ...]],
picture2.png: [[method3, parameters of method3], [method4, parameters of method4], ...]],
...
}
```
2. #### 启动server。
```python
···
def start():
servable_dir = os.path.dirname(os.path.realpath(sys.argv[0]))
# 服务配置
servable_config = server.ServableStartConfig(servable_directory=servable_dir, servable_name="perturbation", device_ids=(0, 1), num_parallel_workers=4)
# 启动服务
server.start_servables(servable_configs=servable_config)
# 启动启动gRPC服务,用于客户端和服务端之间通信
server.start_grpc_server(address="0.0.0.0:5500", max_msg_mb_size=200) # ip和最大的传输数据量,单位MB
# 启动启动Restful服务,用于客户端和服务端之间通信
server.start_restful_server(address="0.0.0.0:5500")
```
gRPC传输性能更好,Restful更适合用于web服务,根据需要选择。
执行命令`python serverong_server.py`启动服务。
当服务端打印日志`Serving RESTful server start success, listening on 0.0.0.0:5500`时,表示Serving RESTful服务启动成功,推理模型已成功加载。
### 客户端进行推理
1. 在`perturb_config.py`中设置扰动方法及参数。下面是个例子:
```python
PerturbConfig = [{"method": "Contrast", "params": {"alpha": 1.5, "beta": 0}},
{"method": "GaussianBlur", "params": {"ksize": 5}},
{"method": "SaltAndPepperNoise", "params": {"factor": 0.05}},
{"method": "Translate", "params": {"x_bias": 0.1, "y_bias": -0.2}},
{"method": "Scale", "params": {"factor_x": 0.7, "factor_y": 0.7}},
{"method": "Shear", "params": {"factor": 2, "director": "horizontal"}},
{"method": "Rotate", "params": {"angle": 40}},
{"method": "MotionBlur", "params": {"degree": 5, "angle": 45}},
{"method": "GradientBlur", "params": {"point": [50, 100], "kernel_num": 3, "center": True}},
{"method": "GradientLuminance",
"params": {"color_start": [255, 255, 255],
"color_end": [0, 0, 0],
"start_point": [100, 150], "scope": 0.3,
"bright_rate": 0.3, "pattern": "light",
"mode": "circle"}},
{"method": "Curve", "params": {"curves": 5, "depth": 10,
"mode": "vertical"}},
{"method": "Perspective",
"params": {"ori_pos": [[0, 0], [0, 800], [800, 0], [800, 800]],
"dst_pos": [[50, 0], [0, 800], [780, 0], [800, 800]]}},
]
```
其中`method`为扰动方法名,`params`为对应方法的参数。可用的扰动方法及对应参数可在`mindarmour/natural_robustness/natural_noise.py`中查询。
2. 在`serving_client.py`中写客户端的处理脚本,包含输入输出的处理、服务端的调用,可以参考下面的例子。
```python
···
def perturb(perturb_config):
"""invoke servable perturbation method natural_perturbation"""
# 请求的服务端ip及端口、请求的服务名、请求的方法名
client = Client("10.175.122.87:5500", "perturbation", "natural_perturbation")
# 输入数据
instances = []
img_path = '/root/mindarmour/example/adversarial/test_data/1.png'
result_path = '/root/mindarmour/example/adv/result/'
methods_number = 2
outputs_number = 3
img = cv2.imread(img_path)
img = cv2.imencode('.png', img)[1].tobytes() # 图片传输用bytes格式,不支持numpy.ndarray格式
perturb_config = json.dumps(perturb_config) # 配置方法转成json格式
instances.append({"img": img, 'perturb_config': perturb_config, "methods_number": methods_number,
"outputs_number": outputs_number}) # instances中可添加多个输入
# 请求服务,返回结果
result = client.infer(instances)
# 对服务请求得到的结果进行处理,将返回的图片字节流存成图片
file_names = result[0]['file_names'].split(';')
length = result[0]['file_length'].tolist()
before = 0
for name, leng in zip(file_names, length):
res_img = result[0]['results']
res_img = res_img[before:before + leng]
before = before + leng
print('name: ', name)
image = Image.open(BytesIO(res_img))
image.save(os.path.join(result_path, name))
names_dict = result[0]['names_dict']
with open('names_dict.json', 'w') as file:
file.write(names_dict)
```
启动client前,需将服务端的IP地址改成部署server的IP地址,图片路径、结果存储路基替换成用户数据路径。
目前serving数据传输支持的数据类型包括:python的int、float、bool、str、bytes,numpy number, numpy array object。
输入命令`python serving_client.py`开启客户端,如果对应目录下生成扰动样本图片则说明serving服务正确执行。
### 其他
在`serving_logs`目录下可以查看运行日志,辅助debug。

+ 41
- 0
examples/natural_robustness/serving/client/perturb_config.py View File

@@ -0,0 +1,41 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Configuration of natural robustness methods for server.
"""
perturb_configs = [{"method": "Contrast", "params": {"alpha": 1.5, "beta": 0}},
{"method": "GaussianBlur", "params": {"ksize": 5}},
{"method": "SaltAndPepperNoise", "params": {"factor": 0.05}},
{"method": "Translate", "params": {"x_bias": 0.1, "y_bias": -0.2}},
{"method": "Scale", "params": {"factor_x": 0.7, "factor_y": 0.7}},
{"method": "Shear", "params": {"factor": 2, "direction": "horizontal"}},
{"method": "Rotate", "params": {"angle": 40}},
{"method": "MotionBlur", "params": {"degree": 5, "angle": 45}},
{"method": "GradientBlur", "params": {"point": [50, 100], "kernel_num": 3, "center": True}},
{"method": "GradientLuminance", "params": {"color_start": [255, 255, 255], "color_end": [0, 0, 0],
"start_point": [100, 150], "scope": 0.3,
"bright_rate": 0.3, "pattern": "light",
"mode": "circle"}},
{"method": "GradientLuminance", "params": {"color_start": [255, 255, 255],
"color_end": [0, 0, 0], "start_point": [150, 200],
"scope": 0.3, "pattern": "light", "mode": "horizontal"}},
{"method": "GradientLuminance", "params": {"color_start": [255, 255, 255], "color_end": [0, 0, 0],
"start_point": [150, 200], "scope": 0.3,
"pattern": "light", "mode": "vertical"}},
{"method": "Curve", "params": {"curves": 10, "depth": 10, "mode": "vertical"}},
{"method": "Perspective", "params": {"ori_pos": [[0, 0], [0, 800], [800, 0], [800, 800]],
"dst_pos": [[50, 0], [0, 800], [780, 0], [800, 800]]}},
]

+ 61
- 0
examples/natural_robustness/serving/client/serving_client.py View File

@@ -0,0 +1,61 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""The client of example add."""
import os
import json
from io import BytesIO

import cv2
from PIL import Image
from mindspore_serving.client import Client

from perturb_config import perturb_configs


def perturb(perturb_config):
"""Invoke servable perturbation method natural_perturbation"""
client = Client("0.0.0.0:5500", "perturbation", "natural_perturbation")
instances = []
img_path = 'test_data/1.png'
result_path = 'result/'
if not os.path.exists(result_path):
os.mkdir(result_path)
methods_number = 2
outputs_number = 10
img = cv2.imread(img_path)
img = cv2.imencode('.png', img)[1].tobytes()
perturb_config = json.dumps(perturb_config)
instances.append({"img": img, 'perturb_config': perturb_config, "methods_number": methods_number,
"outputs_number": outputs_number})

result = client.infer(instances)

file_names = result[0]['file_names'].split(';')
length = result[0]['file_length'].tolist()
before = 0
for name, leng in zip(file_names, length):
res_img = result[0]['results']
res_img = res_img[before:before + leng]
before = before + leng
print('name: ', name)
image = Image.open(BytesIO(res_img))
image.save(os.path.join(result_path, name))
names_dict = result[0]['names_dict']
with open('names_dict.json', 'w') as file:
file.write(names_dict)


if __name__ == '__main__':
perturb(perturb_configs)

+ 58
- 0
examples/natural_robustness/serving/server/export_model/add_model.py View File

@@ -0,0 +1,58 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""add model generator"""
import os
from shutil import copyfile
import numpy as np

import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore as ms

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


class Net(nn.Cell):
"""Define Net of add"""

def __init__(self):
super(Net, self).__init__()
self.add = ops.Add()

def construct(self, x_, y_):
"""construct add net"""
return self.add(x_, y_)


def export_net():
"""Export add net of 2x2 + 2x2, and copy output model `tensor_add.mindir` to directory ../add/1"""
x = np.ones([2, 2]).astype(np.float32)
y = np.ones([2, 2]).astype(np.float32)
add = Net()
ms.export(add, ms.Tensor(x), ms.Tensor(y), file_name='tensor_add', file_format='MINDIR')
dst_dir = '../perturbation/1'
try:
os.mkdir(dst_dir)
except OSError:
pass

dst_file = os.path.join(dst_dir, 'tensor_add.mindir')
copyfile('tensor_add.mindir', dst_file)
print("copy tensor_add.mindir to " + dst_dir + " success")


if __name__ == "__main__":
export_net()

+ 109
- 0
examples/natural_robustness/serving/server/perturbation/servable_config.py View File

@@ -0,0 +1,109 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""perturbation servable config"""
import json
import copy
import random
from io import BytesIO
import cv2
import numpy as np
from PIL import Image
from mindspore_serving.server import register
from mindarmour.natural_robustness import Contrast, GaussianBlur, SaltAndPepperNoise, Scale, Shear, \
Translate, Rotate, MotionBlur, GradientBlur, GradientLuminance, NaturalNoise, Curve, Perspective


CHARACTERS = [chr(i) for i in range(65, 91)]+[chr(j) for j in range(97, 123)]

methods_dict = {'Contrast': Contrast,
'GaussianBlur': GaussianBlur,
'SaltAndPepperNoise': SaltAndPepperNoise,
'Translate': Translate,
'Scale': Scale,
'Shear': Shear,
'Rotate': Rotate,
'MotionBlur': MotionBlur,
'GradientBlur': GradientBlur,
'GradientLuminance': GradientLuminance,
'NaturalNoise': NaturalNoise,
'Curve': Curve,
'Perspective': Perspective}


def check_inputs(img, perturb_config, methods_number, outputs_number):
"""Check inputs."""
if not np.any(img):
raise ValueError("img cannot be empty.")
img = Image.open(BytesIO(img))
img = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)

config = json.loads(perturb_config)
if not config:
raise ValueError("perturb_config cannot be empty.")
for item in config:
if item['method'] not in methods_dict.keys():
raise ValueError("{} is not a valid method.".format(item['method']))

methods_number = int(methods_number)
if methods_number < 1:
raise ValueError("methods_number must more than 0.")
outputs_number = int(outputs_number)
if outputs_number < 1:
raise ValueError("outputs_number must more than 0.")

return img, config, methods_number, outputs_number


def perturb(img, perturb_config, methods_number, outputs_number):
"""Perturb given image."""
img, config, methods_number, outputs_number = check_inputs(img, perturb_config, methods_number, outputs_number)
res_img_bytes = b''
file_names = []
file_length = []
names_dict = {}
for _ in range(outputs_number):
dst = copy.deepcopy(img)
used_methods = []
for _ in range(methods_number):
item = np.random.choice(config)
method_name = item['method']
method = methods_dict[method_name]
params = item['params']
dst = method(**params)(img)
method_params = params

used_methods.append([method_name, method_params])
name = ''.join(random.sample(CHARACTERS, 20))
name += '.png'
file_names.append(name)
names_dict[name] = used_methods

res_img = cv2.imencode('.png', dst)[1].tobytes()
res_img_bytes += res_img
file_length.append(len(res_img))

names_dict = json.dumps(names_dict)

return res_img_bytes, ';'.join(file_names), file_length, names_dict


model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False)


@register.register_method(output_names=["results", "file_names", "file_length", "names_dict"])
def natural_perturbation(img, perturb_config, methods_number, outputs_number):
"""method natural_perturbation data flow definition, only preprocessing and call model"""
res = register.add_stage(perturb, img, perturb_config, methods_number, outputs_number, outputs_count=4)
return res

+ 35
- 0
examples/natural_robustness/serving/server/serving_server.py View File

@@ -0,0 +1,35 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""The server of example perturbation"""

import os
import sys
from mindspore_serving import server


def start():
"""Start server."""
servable_dir = os.path.dirname(os.path.realpath(sys.argv[0]))

servable_config = server.ServableStartConfig(servable_directory=servable_dir, servable_name="perturbation",
device_ids=(0, 1), num_parallel_workers=4)
server.start_servables(servable_configs=servable_config)

server.start_grpc_server(address="0.0.0.0:5500", max_msg_mb_size=200)
# server.start_restful_server(address="0.0.0.0:5500")


if __name__ == "__main__":
start()

+ 4
- 4
mindarmour/adv_robustness/attacks/iterative_gradient_method.py View File

@@ -92,7 +92,7 @@ def _projection(values, eps, norm_level):
return proj_flat.reshape(values.shape)
if norm_level in (2, '2'):
return eps*normalize_value(values, norm_level)
if norm_level in (np.inf, 'inf'):
if norm_level in (np.inf, 'inf', 'linf', 'np.inf'):
return eps*np.sign(values)
msg = 'Values of `norm_level` different from 1, 2 and `np.inf` are ' \
'currently not supported.'
@@ -277,7 +277,7 @@ class MomentumIterativeMethod(IterativeGradientMethod):
nb_iter (int): Number of iteration. Default: 5.
decay_factor (float): Decay factor in iterations. Default: 1.0.
norm_level (Union[int, numpy.inf]): Order of the norm. Possible values:
np.inf, 1 or 2. Default: 'inf'.
1, 2, '1', '2', 'l1', 'l2', 'inf', 'np.inf', np.inf and 'linf'. Default: 'inf'.
loss_fn (Loss): Loss function for optimization. If None, the input network \
is already equipped with loss function. Default: None.

@@ -423,7 +423,7 @@ class ProjectedGradientDescent(BasicIterativeMethod):
attack. Default: False.
nb_iter (int): Number of iteration. Default: 5.
norm_level (Union[int, numpy.inf]): Order of the norm. Possible values:
np.inf, 1 or 2. Default: 'inf'.
1, 2, '1', '2', 'l1', 'l2', 'inf', 'np.inf', np.inf and 'linf'. Default: 'inf'.
loss_fn (Loss): Loss function for optimization. If None, the input network \
is already equipped with loss function. Default: None.

@@ -577,7 +577,7 @@ class MomentumDiverseInputIterativeMethod(MomentumIterativeMethod):
is_targeted (bool): If True, targeted attack. If False, untargeted
attack. Default: False.
norm_level (Union[int, numpy.inf]): Order of the norm. Possible values:
np.inf, 1 or 2. Default: 'l1'.
1, 2, '1', '2', 'l1', 'l2', 'inf', 'np.inf', np.inf and 'linf'. Default: 'l1'.
prob (float): Transformation probability. Default: 0.5.
loss_fn (Loss): Loss function for optimization. If None, the input network \
is already equipped with loss function. Default: None.


+ 111
- 30
mindarmour/fuzz_testing/fuzzing.py View File

@@ -24,10 +24,10 @@ from mindspore import nn
from mindarmour.utils._check_param import check_model, check_numpy_param, check_param_multi_types, check_norm_level, \
check_param_in_range, check_param_type, check_int_positive, check_param_bounds
from mindarmour.utils.logger import LogUtil
from ..adv_robustness.attacks import FastGradientSignMethod, \
from mindarmour.adv_robustness.attacks import FastGradientSignMethod, \
MomentumDiverseInputIterativeMethod, ProjectedGradientDescent
from .image_transform import Contrast, Brightness, Blur, \
Noise, Translate, Scale, Shear, Rotate
from mindarmour.natural_robustness import GaussianBlur, MotionBlur, GradientBlur, UniformNoise, GaussianNoise, \
SaltAndPepperNoise, NaturalNoise, Contrast, GradientLuminance, Translate, Scale, Shear, Rotate, Perspective, Curve
from .model_coverage_metrics import CoverageMetrics, KMultisectionNeuronCoverage

LOGGER = LogUtil.get_instance()
@@ -104,17 +104,79 @@ class Fuzzer:
target_model (Model): Target fuzz model.

Examples:
>>> import numpy as np
>>> from mindspore import context
>>> from mindspore import nn
>>> from mindspore.common.initializer import TruncatedNormal
>>> from mindspore.ops import operations as P
>>> from mindspore.train import Model
>>> from mindspore.ops import TensorSummary
>>> from mindarmour.fuzz_testing import Fuzzer
>>> from mindarmour.fuzz_testing import KMultisectionNeuronCoverage
>>>
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self.conv1 = nn.Conv2d(1, 6, 5, padding=0, weight_init=TruncatedNormal(0.02), pad_mode="valid")
>>> self.conv2 = nn.Conv2d(6, 16, 5, padding=0, weight_init=TruncatedNormal(0.02), pad_mode="valid")
>>> self.fc1 = nn.Dense(16 * 5 * 5, 120, TruncatedNormal(0.02), TruncatedNormal(0.02))
>>> self.fc2 = nn.Dense(120, 84, TruncatedNormal(0.02), TruncatedNormal(0.02))
>>> self.fc3 = nn.Dense(84, 10, TruncatedNormal(0.02), TruncatedNormal(0.02))
>>> self.relu = nn.ReLU()
>>> self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
>>> self.reshape = P.Reshape()
>>> self.summary = TensorSummary()
>>>
>>> def construct(self, x):
>>> x = self.conv1(x)
>>> x = self.relu(x)
>>> self.summary('conv1', x)
>>> x = self.max_pool2d(x)
>>> x = self.conv2(x)
>>> x = self.relu(x)
>>> self.summary('conv2', x)
>>> x = self.max_pool2d(x)
>>> x = self.reshape(x, (-1, 16 * 5 * 5))
>>> x = self.fc1(x)
>>> x = self.relu(x)
>>> self.summary('fc1', x)
>>> x = self.fc2(x)
>>> x = self.relu(x)
>>> self.summary('fc2', x)
>>> x = self.fc3(x)
>>> self.summary('fc3', x)
>>> return x
>>>
>>> net = Net()
>>> model = Model(net)
>>> mutate_config = [{'method': 'Blur',
... 'params': {'auto_param': [True]}},
>>> mutate_config = [{'method': 'GaussianBlur',
... 'params': {'ksize': [1, 2, 3, 5], 'auto_param': [True, False]}},
... {'method': 'MotionBlur',
... 'params': {'degree': [1, 2, 5], 'angle': [45, 10, 100, 140, 210, 270, 300],
... 'auto_param': [True]}},
... {'method': 'UniformNoise',
... 'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}},
... {'method': 'GaussianNoise',
... 'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}},
... {'method': 'Contrast',
... 'params': {'factor': [2]}},
... {'method': 'Translate',
... 'params': {'x_bias': [0.1, 0.2], 'y_bias': [0.2]}},
... 'params': {'alpha': [0.5, 1, 1.5], 'beta': [-10, 0, 10], 'auto_param': [False, True]}},
... {'method': 'Rotate',
... 'params': {'angle': [20, 90], 'auto_param': [False, True]}},
... {'method': 'FGSM',
... 'params': {'eps': [0.1, 0.2, 0.3], 'alpha': [0.1]}}]
>>> nc = KMultisectionNeuronCoverage(model, train_images, segmented_num=100)
... 'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1], 'bounds': [(0, 1)]}}]
>>> batch_size = 8
>>> num_classe = 10
>>> train_images = np.random.rand(32, 1, 32, 32).astype(np.float32)
>>> test_images = np.random.rand(batch_size, 1, 32, 32).astype(np.float32)
>>> test_labels = np.random.randint(num_classe, size=batch_size).astype(np.int32)
>>> test_labels = (np.eye(num_classe)[test_labels]).astype(np.float32)
>>> initial_seeds = []
>>> # make initial seeds
>>> for img, label in zip(test_images, test_labels):
>>> initial_seeds.append([img, label])

>>> initial_seeds = initial_seeds[:10]
>>> nc = KMultisectionNeuronCoverage(model, train_images, segmented_num=100, incremental=True)
>>> model_fuzz_test = Fuzzer(model)
>>> samples, gt_labels, preds, strategies, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds,
... nc, max_iters=100)
@@ -125,18 +187,26 @@ class Fuzzer:

# Allowed mutate strategies so far.
self._strategies = {'Contrast': Contrast,
'Brightness': Brightness,
'Blur': Blur,
'Noise': Noise,
'GradientLuminance': GradientLuminance,
'GaussianBlur': GaussianBlur,
'MotionBlur': MotionBlur,
'GradientBlur': GradientBlur,
'UniformNoise': UniformNoise,
'GaussianNoise': GaussianNoise,
'SaltAndPepperNoise': SaltAndPepperNoise,
'NaturalNoise': NaturalNoise,
'Translate': Translate,
'Scale': Scale,
'Shear': Shear,
'Rotate': Rotate,
'Perspective': Perspective,
'Curve': Curve,
'FGSM': FastGradientSignMethod,
'PGD': ProjectedGradientDescent,
'MDIIM': MomentumDiverseInputIterativeMethod}
self._affine_trans_list = ['Translate', 'Scale', 'Shear', 'Rotate']
self._pixel_value_trans_list = ['Contrast', 'Brightness', 'Blur', 'Noise']
self._affine_trans_list = ['Translate', 'Scale', 'Shear', 'Rotate', 'Perspective', 'Curve']
self._pixel_value_trans_list = ['Contrast', 'GradientLuminance', 'GaussianBlur', 'MotionBlur', 'GradientBlur',
'UniformNoise', 'GaussianNoise', 'SaltAndPepperNoise', 'NaturalNoise']
self._attacks_list = ['FGSM', 'PGD', 'MDIIM']
self._attack_param_checklists = {
'FGSM': {'eps': {'dtype': [float], 'range': [0, 1]},
@@ -144,10 +214,11 @@ class Fuzzer:
'bounds': {'dtype': [tuple, list]}},
'PGD': {'eps': {'dtype': [float], 'range': [0, 1]},
'eps_iter': {'dtype': [float], 'range': [0, 1]},
'nb_iter': {'dtype': [int], 'range': [0, 100000]},
'nb_iter': {'dtype': [int]},
'bounds': {'dtype': [tuple, list]}},
'MDIIM': {'eps': {'dtype': [float], 'range': [0, 1]},
'norm_level': {'dtype': [str, int], 'range': [1, 2, '1', '2', 'l1', 'l2', 'inf', 'np.inf']},
'norm_level': {'dtype': [str, int],
'range': [1, 2, '1', '2', 'l1', 'l2', 'inf', 'linf', 'np.inf']},
'prob': {'dtype': [float], 'range': [0, 1]},
'bounds': {'dtype': [tuple, list]}}}

@@ -157,18 +228,26 @@ class Fuzzer:

Args:
mutate_config (list): Mutate configs. The format is
[{'method': 'Blur',
'params': {'radius': [0.1, 0.2], 'auto_param': [True, False]}},
{'method': 'Contrast',
'params': {'factor': [1, 1.5, 2]}},
{'method': 'FGSM',
'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1]}},
...].
[{'method': 'GaussianBlur',
'params': {'ksize': [1, 2, 3, 5], 'auto_param': [True, False]}},
{'method': 'UniformNoise',
'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}},
{'method': 'GaussianNoise',
'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}},
{'method': 'Contrast',
'params': {'alpha': [0.5, 1, 1.5], 'beta': [-10, 0, 10], 'auto_param': [False, True]}},
{'method': 'Rotate',
'params': {'angle': [20, 90], 'auto_param': [False, True]}},
{'method': 'FGSM',
'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1], 'bounds': [(0, 1)]}}]
...].
The supported methods list is in `self._strategies`, and the params of each method must within the
range of optional parameters. Supported methods are grouped in three types: Firstly, pixel value based
transform methods include: 'Contrast', 'Brightness', 'Blur' and 'Noise'. Secondly, affine transform
methods include: 'Translate', 'Scale', 'Shear' and 'Rotate'. Thirdly, attack methods include: 'FGSM',
'PGD' and 'MDIIM'. `mutate_config` must have method in the type of pixel value based transform methods.
'PGD' and 'MDIIM'. 'FGSM', 'PGD' and 'MDIIM'. are abbreviations of FastGradientSignMethod,
ProjectedGradientDescent and MomentumDiverseInputIterativeMethod.
`mutate_config` must have method in the type of pixel value based transform methods.
The way of setting parameters for first and second type methods can be seen in
'mindarmour/fuzz_testing/image_transform.py'. For third type methods, the optional parameters refer to
`self._attack_param_checklists`.
@@ -278,7 +357,6 @@ class Fuzzer:
if only_pixel_trans:
while strategy['method'] not in self._pixel_value_trans_list:
strategy = choice(mutate_config)
transform = mutates[strategy['method']]
params = strategy['params']
method = strategy['method']
selected_param = {}
@@ -290,9 +368,10 @@ class Fuzzer:
shear_keys = selected_param.keys()
if 'factor_x' in shear_keys and 'factor_y' in shear_keys:
selected_param[choice(['factor_x', 'factor_y'])] = 0
transform.set_params(**selected_param)
mutate_sample = transform.transform(seed[0])
transform = mutates[strategy['method']](**selected_param)
mutate_sample = transform(seed[0])
else:
transform = mutates[strategy['method']]
for param_name in selected_param:
transform.__setattr__('_' + str(param_name), selected_param[param_name])
mutate_sample = transform.generate(np.array([seed[0].astype(np.float32)]), np.array([seed[1]]))[0]
@@ -360,6 +439,8 @@ class Fuzzer:
_ = check_param_bounds('bounds', param_value)
elif param_name == 'norm_level':
_ = check_norm_level(param_value)
elif param_name == 'nb_iter':
_ = check_int_positive(param_name, param_value)
else:
allow_type = self._attack_param_checklists[method][param_name]['dtype']
allow_range = self._attack_param_checklists[method][param_name]['range']
@@ -372,7 +453,8 @@ class Fuzzer:
for mutate in mutate_config:
method = mutate['method']
if method not in self._attacks_list:
mutates[method] = self._strategies[method]()
# mutates[method] = self._strategies[method]()
mutates[method] = self._strategies[method]
else:
network = self._target_model._network
loss_fn = self._target_model._loss_fn
@@ -414,7 +496,6 @@ class Fuzzer:
else:
attack_success_rate = None
metrics_report['Attack_success_rate'] = attack_success_rate

metrics_report['Coverage_metrics'] = coverage.get_metrics(fuzz_samples)

return metrics_report

+ 0
- 609
mindarmour/fuzz_testing/image_transform.py View File

@@ -1,609 +0,0 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Image transform
"""
import numpy as np
from PIL import Image, ImageEnhance, ImageFilter

from mindspore.dataset.vision.py_transforms_util import is_numpy, \
to_pil, hwc_to_chw
from mindarmour.utils._check_param import check_param_multi_types, check_param_in_range, check_numpy_param
from mindarmour.utils.logger import LogUtil

LOGGER = LogUtil.get_instance()
TAG = 'Image Transformation'


def chw_to_hwc(img):
"""
Transpose the input image; shape (C, H, W) to shape (H, W, C).

Args:
img (numpy.ndarray): Image to be converted.

Returns:
img (numpy.ndarray), Converted image.
"""
if is_numpy(img):
return img.transpose(1, 2, 0).copy()
raise TypeError('img should be numpy.ndarray. Got {}'.format(type(img)))


def is_hwc(img):
"""
Check if the input image is shape (H, W, C).

Args:
img (numpy.ndarray): Image to be checked.

Returns:
Bool, True if input is shape (H, W, C).
"""
if is_numpy(img):
img_shape = np.shape(img)
if img_shape[2] == 3 and img_shape[1] > 3 and img_shape[0] > 3:
return True
return False
raise TypeError('img should be numpy.ndarray. Got {}'.format(type(img)))


def is_chw(img):
"""
Check if the input image is shape (H, W, C).

Args:
img (numpy.ndarray): Image to be checked.

Returns:
Bool, True if input is shape (H, W, C).
"""
if is_numpy(img):
img_shape = np.shape(img)
if img_shape[0] == 3 and img_shape[1] > 3 and img_shape[2] > 3:
return True
return False
raise TypeError('img should be numpy.ndarray. Got {}'.format(type(img)))


def is_rgb(img):
"""
Check if the input image is RGB.

Args:
img (numpy.ndarray): Image to be checked.

Returns:
Bool, True if input is RGB.
"""
if is_numpy(img):
img_shape = np.shape(img)
if len(np.shape(img)) == 3 and (img_shape[0] == 3 or img_shape[2] == 3):
return True
return False
raise TypeError('img should be numpy.ndarray. Got {}'.format(type(img)))


def is_normalized(img):
"""
Check if the input image is normalized between 0 to 1.

Args:
img (numpy.ndarray): Image to be checked.

Returns:
Bool, True if input is normalized between 0 to 1.
"""
if is_numpy(img):
minimal = np.min(img)
maximun = np.max(img)
if minimal >= 0 and maximun <= 1:
return True
return False
raise TypeError('img should be Numpy array. Got {}'.format(type(img)))


class ImageTransform:
"""
The abstract base class for all image transform classes.
"""

def __init__(self):
pass

def _check(self, image):
""" Check image format. If input image is RGB and its shape
is (C, H, W), it will be transposed to (H, W, C). If the value
of the image is not normalized , it will be normalized between 0 to 1."""
rgb = is_rgb(image)
chw = False
gray3dim = False
normalized = is_normalized(image)
if rgb:
chw = is_chw(image)
if chw:
image = chw_to_hwc(image)
else:
image = image
else:
if len(np.shape(image)) == 3:
gray3dim = True
image = image[0]
else:
image = image
if normalized:
image = image*255
return rgb, chw, normalized, gray3dim, np.uint8(image)

def _original_format(self, image, chw, normalized, gray3dim):
""" Return transformed image with original format. """
if not is_numpy(image):
image = np.array(image)
if chw:
image = hwc_to_chw(image)
if normalized:
image = image / 255
if gray3dim:
image = np.expand_dims(image, 0)
return image

def transform(self, image):
pass


class Contrast(ImageTransform):
"""
Contrast of an image.

Args:
factor (Union[float, int]): Control the contrast of an image. If 1.0,
gives the original image. If 0, gives a gray image. Default: 1.
"""

def __init__(self, factor=1):
super(Contrast, self).__init__()
self.set_params(factor)

def set_params(self, factor=1, auto_param=False):
"""
Set contrast parameters.

Args:
factor (Union[float, int]): Control the contrast of an image. If 1.0
gives the original image. If 0 gives a gray image. Default: 1.
auto_param (bool): True if auto generate parameters. Default: False.
"""
if auto_param:
self.factor = np.random.uniform(-5, 5)
else:
self.factor = check_param_multi_types('factor', factor, [int, float])

def transform(self, image):
"""
Transform the image.

Args:
image (numpy.ndarray): Original image to be transformed.

Returns:
numpy.ndarray, transformed image.
"""
image = check_numpy_param('image', image)
ori_dtype = image.dtype
_, chw, normalized, gray3dim, image = self._check(image)
image = to_pil(image)
img_contrast = ImageEnhance.Contrast(image)
trans_image = img_contrast.enhance(self.factor)
trans_image = self._original_format(trans_image, chw, normalized,
gray3dim)

return trans_image.astype(ori_dtype)


class Brightness(ImageTransform):
"""
Brightness of an image.

Args:
factor (Union[float, int]): Control the brightness of an image. If 1.0
gives the original image. If 0 gives a black image. Default: 1.
"""

def __init__(self, factor=1):
super(Brightness, self).__init__()
self.set_params(factor)

def set_params(self, factor=1, auto_param=False):
"""
Set brightness parameters.

Args:
factor (Union[float, int]): Control the brightness of an image. If 1
gives the original image. If 0 gives a black image. Default: 1.
auto_param (bool): True if auto generate parameters. Default: False.
"""
if auto_param:
self.factor = np.random.uniform(0, 5)
else:
self.factor = check_param_multi_types('factor', factor, [int, float])

def transform(self, image):
"""
Transform the image.

Args:
image (numpy.ndarray): Original image to be transformed.

Returns:
numpy.ndarray, transformed image.
"""
image = check_numpy_param('image', image)
ori_dtype = image.dtype
_, chw, normalized, gray3dim, image = self._check(image)
image = to_pil(image)
img_contrast = ImageEnhance.Brightness(image)
trans_image = img_contrast.enhance(self.factor)
trans_image = self._original_format(trans_image, chw, normalized,
gray3dim)
return trans_image.astype(ori_dtype)


class Blur(ImageTransform):
"""
Blurs the image using Gaussian blur filter.

Args:
radius(Union[float, int]): Blur radius, 0 means no blur. Default: 0.
"""

def __init__(self, radius=0):
super(Blur, self).__init__()
self.set_params(radius)

def set_params(self, radius=0, auto_param=False):
"""
Set blur parameters.

Args:
radius (Union[float, int]): Blur radius, 0 means no blur. Default: 0.
auto_param (bool): True if auto generate parameters. Default: False.
"""
if auto_param:
self.radius = np.random.uniform(-1.5, 1.5)
else:
self.radius = check_param_multi_types('radius', radius, [int, float])

def transform(self, image):
"""
Transform the image.

Args:
image (numpy.ndarray): Original image to be transformed.

Returns:
numpy.ndarray, transformed image.
"""
image = check_numpy_param('image', image)
ori_dtype = image.dtype
_, chw, normalized, gray3dim, image = self._check(image)
image = to_pil(image)
trans_image = image.filter(ImageFilter.GaussianBlur(radius=self.radius))
trans_image = self._original_format(trans_image, chw, normalized,
gray3dim)
return trans_image.astype(ori_dtype)


class Noise(ImageTransform):
"""
Add noise of an image.

Args:
factor (float): factor is the ratio of pixels to add noise.
If 0 gives the original image. Default 0.
"""

def __init__(self, factor=0):
super(Noise, self).__init__()
self.set_params(factor)

def set_params(self, factor=0, auto_param=False):
"""
Set noise parameters.

Args:
factor (Union[float, int]): factor is the ratio of pixels to
add noise. If 0 gives the original image. Default 0.
auto_param (bool): True if auto generate parameters. Default: False.
"""
if auto_param:
self.factor = np.random.uniform(0, 1)
else:
self.factor = check_param_multi_types('factor', factor, [int, float])

def transform(self, image):
"""
Transform the image.

Args:
image (numpy.ndarray): Original image to be transformed.

Returns:
numpy.ndarray, transformed image.
"""
image = check_numpy_param('image', image)
ori_dtype = image.dtype
_, chw, normalized, gray3dim, image = self._check(image)
noise = np.random.uniform(low=-1, high=1, size=np.shape(image))
trans_image = np.copy(image)
threshold = 1 - self.factor
trans_image[noise < -threshold] = 0
trans_image[noise > threshold] = 1
trans_image = self._original_format(trans_image, chw, normalized,
gray3dim)
return trans_image.astype(ori_dtype)


class Translate(ImageTransform):
"""
Translate an image.

Args:
x_bias (Union[int, float]): X-direction translation, x = x + x_bias*image_length.
Default: 0.
y_bias (Union[int, float]): Y-direction translation, y = y + y_bias*image_wide.
Default: 0.
"""

def __init__(self, x_bias=0, y_bias=0):
super(Translate, self).__init__()
self.set_params(x_bias, y_bias)

def set_params(self, x_bias=0, y_bias=0, auto_param=False):
"""
Set translate parameters.

Args:
x_bias (Union[float, int]): X-direction translation, and x_bias should be in range of (-1, 1). Default: 0.
y_bias (Union[float, int]): Y-direction translation, and y_bias should be in range of (-1, 1). Default: 0.
auto_param (bool): True if auto generate parameters. Default: False.
"""
x_bias = check_param_in_range('x_bias', x_bias, -1, 1)
y_bias = check_param_in_range('y_bias', y_bias, -1, 1)
self.auto_param = auto_param
if auto_param:
self.x_bias = np.random.uniform(-0.3, 0.3)
self.y_bias = np.random.uniform(-0.3, 0.3)
else:
self.x_bias = check_param_multi_types('x_bias', x_bias,
[int, float])
self.y_bias = check_param_multi_types('y_bias', y_bias,
[int, float])

def transform(self, image):
"""
Transform the image.

Args:
image(numpy.ndarray): Original image to be transformed.

Returns:
numpy.ndarray, transformed image.
"""
image = check_numpy_param('image', image)
ori_dtype = image.dtype
_, chw, normalized, gray3dim, image = self._check(image)
img = to_pil(image)
image_shape = np.shape(image)
self.x_bias = image_shape[1]*self.x_bias
self.y_bias = image_shape[0]*self.y_bias
trans_image = img.transform(img.size, Image.AFFINE,
(1, 0, self.x_bias, 0, 1, self.y_bias))
trans_image = self._original_format(trans_image, chw, normalized,
gray3dim)
return trans_image.astype(ori_dtype)


class Scale(ImageTransform):
"""
Scale an image in the middle.

Args:
factor_x (Union[float, int]): Rescale in X-direction, x=factor_x*x.
Default: 1.
factor_y (Union[float, int]): Rescale in Y-direction, y=factor_y*y.
Default: 1.
"""

def __init__(self, factor_x=1, factor_y=1):
super(Scale, self).__init__()
self.set_params(factor_x, factor_y)

def set_params(self, factor_x=1, factor_y=1, auto_param=False):

"""
Set scale parameters.

Args:
factor_x (Union[float, int]): Rescale in X-direction, x=factor_x*x.
Default: 1.
factor_y (Union[float, int]): Rescale in Y-direction, y=factor_y*y.
Default: 1.
auto_param (bool): True if auto generate parameters. Default: False.
"""
if auto_param:
self.factor_x = np.random.uniform(0.7, 3)
self.factor_y = np.random.uniform(0.7, 3)
else:
self.factor_x = check_param_multi_types('factor_x', factor_x,
[int, float])
self.factor_y = check_param_multi_types('factor_y', factor_y,
[int, float])

def transform(self, image):
"""
Transform the image.

Args:
image(numpy.ndarray): Original image to be transformed.

Returns:
numpy.ndarray, transformed image.
"""
image = check_numpy_param('image', image)
ori_dtype = image.dtype
rgb, chw, normalized, gray3dim, image = self._check(image)
if rgb:
h, w, _ = np.shape(image)
else:
h, w = np.shape(image)
move_x_centor = w / 2*(1 - self.factor_x)
move_y_centor = h / 2*(1 - self.factor_y)
img = to_pil(image)
trans_image = img.transform(img.size, Image.AFFINE,
(self.factor_x, 0, move_x_centor,
0, self.factor_y, move_y_centor))
trans_image = self._original_format(trans_image, chw, normalized,
gray3dim)
return trans_image.astype(ori_dtype)


class Shear(ImageTransform):
"""
Shear an image, for each pixel (x, y) in the sheared image, the new value is
taken from a position (x+factor_x*y, factor_y*x+y) in the origin image. Then
the sheared image will be rescaled to fit original size.

Args:
factor_x (Union[float, int]): Shear factor of horizontal direction.
Default: 0.
factor_y (Union[float, int]): Shear factor of vertical direction.
Default: 0.

"""

def __init__(self, factor_x=0, factor_y=0):
super(Shear, self).__init__()
self.set_params(factor_x, factor_y)

def set_params(self, factor_x=0, factor_y=0, auto_param=False):
"""
Set shear parameters.

Args:
factor_x (Union[float, int]): Shear factor of horizontal direction.
Default: 0.
factor_y (Union[float, int]): Shear factor of vertical direction.
Default: 0.
auto_param (bool): True if auto generate parameters. Default: False.
"""
if factor_x != 0 and factor_y != 0:
msg = 'At least one of factor_x and factor_y is zero.'
LOGGER.error(TAG, msg)
raise ValueError(msg)
if auto_param:
if np.random.uniform(-1, 1) > 0:
self.factor_x = np.random.uniform(-2, 2)
self.factor_y = 0
else:
self.factor_x = 0
self.factor_y = np.random.uniform(-2, 2)
else:
self.factor_x = check_param_multi_types('factor', factor_x,
[int, float])
self.factor_y = check_param_multi_types('factor', factor_y,
[int, float])

def transform(self, image):
"""
Transform the image.

Args:
image(numpy.ndarray): Original image to be transformed.

Returns:
numpy.ndarray, transformed image.
"""
image = check_numpy_param('image', image)
ori_dtype = image.dtype
rgb, chw, normalized, gray3dim, image = self._check(image)
img = to_pil(image)
if rgb:
h, w, _ = np.shape(image)
else:
h, w = np.shape(image)
if self.factor_x != 0:
boarder_x = [0, -w, -self.factor_x*h, -w - self.factor_x*h]
min_x = min(boarder_x)
max_x = max(boarder_x)
scale = (max_x - min_x) / w
move_x_cen = (w - scale*w - scale*h*self.factor_x) / 2
move_y_cen = h*(1 - scale) / 2
else:
boarder_y = [0, -h, -self.factor_y*w, -h - self.factor_y*w]
min_y = min(boarder_y)
max_y = max(boarder_y)
scale = (max_y - min_y) / h
move_y_cen = (h - scale*h - scale*w*self.factor_y) / 2
move_x_cen = w*(1 - scale) / 2
trans_image = img.transform(img.size, Image.AFFINE,
(scale, scale*self.factor_x, move_x_cen,
scale*self.factor_y, scale, move_y_cen))
trans_image = self._original_format(trans_image, chw, normalized,
gray3dim)
return trans_image.astype(ori_dtype)


class Rotate(ImageTransform):
"""
Rotate an image of degrees counter clockwise around its center.

Args:
angle(Union[float, int]): Degrees counter clockwise. Default: 0.
"""

def __init__(self, angle=0):
super(Rotate, self).__init__()
self.set_params(angle)

def set_params(self, angle=0, auto_param=False):
"""
Set rotate parameters.

Args:
angle(Union[float, int]): Degrees counter clockwise. Default: 0.
auto_param (bool): True if auto generate parameters. Default: False.
"""
if auto_param:
self.angle = np.random.uniform(0, 360)
else:
self.angle = check_param_multi_types('angle', angle, [int, float])

def transform(self, image):
"""
Transform the image.

Args:
image(numpy.ndarray): Original image to be transformed.

Returns:
numpy.ndarray, transformed image.
"""
image = check_numpy_param('image', image)
ori_dtype = image.dtype
_, chw, normalized, gray3dim, image = self._check(image)
img = to_pil(image)
trans_image = img.rotate(self.angle, expand=False)
trans_image = self._original_format(trans_image, chw, normalized,
gray3dim)
return trans_image.astype(ori_dtype)

+ 184
- 21
mindarmour/fuzz_testing/model_coverage_metrics.py View File

@@ -154,13 +154,48 @@ class NeuronCoverage(CoverageMetrics):
incremental (bool): Metrics will be calculate in incremental way or not. Default: False.
batch_size (int): The number of samples in a fuzz test batch. Default: 32.

Examples:
>>> import numpy as np
>>> from mindspore import nn
>>> from mindspore.nn import Cell
>>> from mindspore.train import Model
>>> from mindspore import context
>>> from mindspore.ops import TensorSummary
>>> from mindarmour.fuzz_testing import NeuronCoverage
>>>
>>> class Net(Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self._relu = nn.ReLU()
>>> self.summary = TensorSummary()
>>>
>>> def construct(self, inputs):
>>> self.summary('input', inputs)
>>> out = self._relu(inputs)
>>> self.summary('1', out)
>>> return out
>>>
>>> context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
>>> # load network
>>> net = Net()
>>> model = Model(net)
>>>
>>> # initialize fuzz test with training dataset
>>> training_data = (np.random.random((10000, 10))*20).astype(np.float32)
>>>
>>> # fuzz test with original test data
>>> # get test data
>>> test_data = (np.random.random((2000, 10))*20).astype(np.float32)
>>> test_labels = np.random.randint(0, 10, 2000).astype(np.int32)
>>>
>>> nc = NeuronCoverage(model, threshold=0.1)
>>> nc_metric = nc.get_metrics(test_data)
"""
def __init__(self, model, threshold=0.1, incremental=False, batch_size=32):
super(NeuronCoverage, self).__init__(model, incremental, batch_size)
threshold = check_param_type('threshold', threshold, float)
self.threshold = check_value_positive('threshold', threshold)


def get_metrics(self, dataset):
"""
Get the metric of neuron coverage: the proportion of activated neurons to total neurons in the network.
@@ -170,10 +205,6 @@ class NeuronCoverage(CoverageMetrics):

Returns:
float, the metric of 'neuron coverage'.

Examples:
>>> nc = NeuronCoverage(model, threshold=0.1)
>>> nc_metrics = nc.get_metrics(test_data)
"""
dataset = check_numpy_param('dataset', dataset)
batches = math.ceil(dataset.shape[0] / self.batch_size)
@@ -203,6 +234,43 @@ class TopKNeuronCoverage(CoverageMetrics):
top_k (int): Neuron is activated when its output has the top k largest value in that hidden layers. Default: 3.
incremental (bool): Metrics will be calculate in incremental way or not. Default: False.
batch_size (int): The number of samples in a fuzz test batch. Default: 32.

Examples:
>>> import numpy as np
>>> from mindspore import nn
>>> from mindspore.nn import Cell
>>> from mindspore.train import Model
>>> from mindspore import context
>>> from mindspore.ops import TensorSummary
>>> from mindarmour.fuzz_testing import TopKNeuronCoverage
>>>
>>> class Net(Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self._relu = nn.ReLU()
>>> self.summary = TensorSummary()
>>>
>>> def construct(self, inputs):
>>> self.summary('input', inputs)
>>> out = self._relu(inputs)
>>> self.summary('1', out)
>>> return out
>>>
>>> context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
>>> # load network
>>> net = Net()
>>> model = Model(net)
>>>
>>> # initialize fuzz test with training dataset
>>> training_data = (np.random.random((10000, 10))*20).astype(np.float32)
>>>
>>> # fuzz test with original test data
>>> # get test data
>>> test_data = (np.random.random((2000, 10))*20).astype(np.float32)
>>> test_labels = np.random.randint(0, 10, 2000).astype(np.int32)
>>>
>>> tknc = TopKNeuronCoverage(model, top_k=3)
>>> tknc_metrics = tknc.get_metrics(test_data)
"""
def __init__(self, model, top_k=3, incremental=False, batch_size=32):
super(TopKNeuronCoverage, self).__init__(model, incremental=incremental, batch_size=batch_size)
@@ -217,10 +285,6 @@ class TopKNeuronCoverage(CoverageMetrics):

Returns:
float, the metrics of 'top k neuron coverage'.

Examples:
>>> tknc = TopKNeuronCoverage(model, top_k=3)
>>> metrics = tknc.get_metrics(test_data)
"""
dataset = check_numpy_param('dataset', dataset)
batches = math.ceil(dataset.shape[0] / self.batch_size)
@@ -252,6 +316,43 @@ class SuperNeuronActivateCoverage(CoverageMetrics):
train_dataset (numpy.ndarray): Training dataset used for determine the neurons' output boundaries.
incremental (bool): Metrics will be calculate in incremental way or not. Default: False.
batch_size (int): The number of samples in a fuzz test batch. Default: 32.

Examples:
>>> import numpy as np
>>> from mindspore import nn
>>> from mindspore.nn import Cell
>>> from mindspore.train import Model
>>> from mindspore import context
>>> from mindspore.ops import TensorSummary
>>> from mindarmour.fuzz_testing import SuperNeuronActivateCoverage
>>>
>>> class Net(Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self._relu = nn.ReLU()
>>> self.summary = TensorSummary()
>>>
>>> def construct(self, inputs):
>>> self.summary('input', inputs)
>>> out = self._relu(inputs)
>>> self.summary('1', out)
>>> return out
>>>
>>> context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
>>> # load network
>>> net = Net()
>>> model = Model(net)
>>>
>>> # initialize fuzz test with training dataset
>>> training_data = (np.random.random((10000, 10))*20).astype(np.float32)
>>>
>>> # fuzz test with original test data
>>> # get test data
>>> test_data = (np.random.random((2000, 10))*20).astype(np.float32)
>>> test_labels = np.random.randint(0, 10, 2000).astype(np.int32)
>>>
>>> snac = SuperNeuronActivateCoverage(model, training_data)
>>> snac_metrics = snac.get_metrics(test_data)
"""
def __init__(self, model, train_dataset, incremental=False, batch_size=32):
super(SuperNeuronActivateCoverage, self).__init__(model, incremental=incremental, batch_size=batch_size)
@@ -267,10 +368,6 @@ class SuperNeuronActivateCoverage(CoverageMetrics):

Returns:
float, the metric of 'strong neuron activation coverage'.

Examples:
>>> snac = SuperNeuronActivateCoverage(model, train_dataset)
>>> metrics = snac.get_metrics(test_data)
"""
dataset = check_numpy_param('dataset', dataset)
if not self.incremental or not self._activate_table:
@@ -303,6 +400,43 @@ class NeuronBoundsCoverage(SuperNeuronActivateCoverage):
train_dataset (numpy.ndarray): Training dataset used for determine the neurons' output boundaries.
incremental (bool): Metrics will be calculate in incremental way or not. Default: False.
batch_size (int): The number of samples in a fuzz test batch. Default: 32.

Examples:
>>> import numpy as np
>>> from mindspore import nn
>>> from mindspore.nn import Cell
>>> from mindspore.train import Model
>>> from mindspore import context
>>> from mindspore.ops import TensorSummary
>>> from mindarmour.fuzz_testing import NeuronBoundsCoverage
>>>
>>> class Net(Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self._relu = nn.ReLU()
>>> self.summary = TensorSummary()
>>>
>>> def construct(self, inputs):
>>> self.summary('input', inputs)
>>> out = self._relu(inputs)
>>> self.summary('1', out)
>>> return out
>>>
>>> context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
>>> # load network
>>> net = Net()
>>> model = Model(net)
>>>
>>> # initialize fuzz test with training dataset
>>> training_data = (np.random.random((10000, 10))*20).astype(np.float32)
>>>
>>> # fuzz test with original test data
>>> # get test data
>>> test_data = (np.random.random((2000, 10))*20).astype(np.float32)
>>> test_labels = np.random.randint(0, 10, 2000).astype(np.int32)
>>>
>>> nbc = NeuronBoundsCoverage(model, training_data)
>>> nbc_metrics = nbc.get_metrics(test_data)
"""

def __init__(self, model, train_dataset, incremental=False, batch_size=32):
@@ -317,10 +451,6 @@ class NeuronBoundsCoverage(SuperNeuronActivateCoverage):

Returns:
float, the metric of 'neuron boundary coverage'.

Examples:
>>> nbc = NeuronBoundsCoverage(model, train_dataset)
>>> metrics = nbc.get_metrics(test_data)
"""
dataset = check_numpy_param('dataset', dataset)
if not self.incremental or not self._activate_table:
@@ -353,6 +483,43 @@ class KMultisectionNeuronCoverage(SuperNeuronActivateCoverage):
segmented_num (int): The number of segmented sections of neurons' output intervals. Default: 100.
incremental (bool): Metrics will be calculate in incremental way or not. Default: False.
batch_size (int): The number of samples in a fuzz test batch. Default: 32.

Examples:
>>> import numpy as np
>>> from mindspore import nn
>>> from mindspore.nn import Cell
>>> from mindspore.train import Model
>>> from mindspore import context
>>> from mindspore.ops import TensorSummary
>>> from mindarmour.fuzz_testing import KMultisectionNeuronCoverage
>>>
>>> class Net(Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self._relu = nn.ReLU()
>>> self.summary = TensorSummary()
>>>
>>> def construct(self, inputs):
>>> self.summary('input', inputs)
>>> out = self._relu(inputs)
>>> self.summary('1', out)
>>> return out
>>>
>>> context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
>>> # load network
>>> net = Net()
>>> model = Model(net)
>>>
>>> # initialize fuzz test with training dataset
>>> training_data = (np.random.random((10000, 10))*20).astype(np.float32)
>>>
>>> # fuzz test with original test data
>>> # get test data
>>> test_data = (np.random.random((2000, 10))*20).astype(np.float32)
>>> test_labels = np.random.randint(0, 10, 2000).astype(np.int32)
>>>
>>> kmnc = KMultisectionNeuronCoverage(model, training_data, segmented_num=100)
>>> kmnc_metrics = kmnc.get_metrics(test_data)
"""

def __init__(self, model, train_dataset, segmented_num=100, incremental=False, batch_size=32):
@@ -381,10 +548,6 @@ class KMultisectionNeuronCoverage(SuperNeuronActivateCoverage):

Returns:
float, the metric of 'k-multisection neuron coverage'.

Examples:
>>> kmnc = KMultisectionNeuronCoverage(model, train_dataset, segmented_num=100)
>>> metrics = kmnc.get_metrics(test_data)
"""

dataset = check_numpy_param('dataset', dataset)


+ 37
- 0
mindarmour/natural_robustness/__init__.py View File

@@ -0,0 +1,37 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This package include methods to generate natural perturbation samples.
"""

from .transformation import Translate, Scale, Shear, Rotate, Perspective, Curve
from .blur import GaussianBlur, MotionBlur, GradientBlur
from .luminance import Contrast, GradientLuminance
from .corruption import UniformNoise, GaussianNoise, SaltAndPepperNoise, NaturalNoise

__all__ = ['Translate',
'Scale',
'Shear',
'Rotate',
'Perspective',
'Curve',
'GaussianBlur',
'MotionBlur',
'GradientBlur',
'Contrast',
'GradientLuminance',
'UniformNoise',
'GaussianNoise',
'SaltAndPepperNoise',
'NaturalNoise']

+ 193
- 0
mindarmour/natural_robustness/blur.py View File

@@ -0,0 +1,193 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Image Blur
"""

import numpy as np
import cv2

from mindarmour.natural_robustness.natural_perturb import _NaturalPerturb
from mindarmour.utils._check_param import check_param_multi_types, check_int_positive, check_param_type
from mindarmour.utils.logger import LogUtil

LOGGER = LogUtil.get_instance()
TAG = 'Image Blur'


class GaussianBlur(_NaturalPerturb):
"""
Blurs the image using Gaussian blur filter.

Args:
ksize (int): Size of gaussian kernel, this value must be non-negnative.
auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image.

Example:
>>> img = cv2.imread('1.png')
>>> img = np.array(img)
>>> ksize = 5
>>> trans = GaussianBlur(ksize)
>>> dst = trans(img)
"""

def __init__(self, ksize=2, auto_param=False):
super(GaussianBlur, self).__init__()
ksize = check_int_positive('ksize', ksize)
if auto_param:
ksize = 2 * np.random.randint(0, 5) + 1
else:
ksize = 2 * ksize + 1
self.ksize = (ksize, ksize)

def __call__(self, image):
"""
Transform the image.

Args:
image (numpy.ndarray): Original image to be transformed.

Returns:
numpy.ndarray, transformed image.
"""
ori_dtype = image.dtype
_, chw, normalized, gray3dim, image = self._check(image)
new_img = cv2.GaussianBlur(image, self.ksize, 0)
new_img = self._original_format(new_img, chw, normalized, gray3dim)
return new_img.astype(ori_dtype)


class MotionBlur(_NaturalPerturb):
"""
Motion blur for a given image.

Args:
degree (int): Degree of blur. This value must be positive. Suggested value range in [1, 15].
angle: (union[float, int]): Direction of motion blur. Angle=0 means up and down motion blur. Angle is
counterclockwise.
auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image.

Example:
>>> img = cv2.imread('1.png')
>>> img = np.array(img)
>>> angle = 0
>>> degree = 5
>>> trans = MotionBlur(degree=degree, angle=angle)
>>> new_img = trans(img)
"""

def __init__(self, degree=5, angle=45, auto_param=False):
super(MotionBlur, self).__init__()
self.degree = check_int_positive('degree', degree)
self.degree = check_param_multi_types('degree', degree, [float, int])
auto_param = check_param_type('auto_param', auto_param, bool)
if auto_param:
self.degree = np.random.randint(1, 5)
self.angle = np.random.uniform(0, 360)
else:
self.angle = angle - 45

def __call__(self, image):
"""
Motion blur for a given image.

Args:
image (numpy.ndarray): Original image.

Returns:
numpy.ndarray, image after motion blur.
"""
ori_dtype = image.dtype
_, chw, normalized, gray3dim, image = self._check(image)
matrix = cv2.getRotationMatrix2D((self.degree / 2, self.degree / 2), self.angle, 1)
motion_blur_kernel = np.diag(np.ones(self.degree))
motion_blur_kernel = cv2.warpAffine(motion_blur_kernel, matrix, (self.degree, self.degree))
motion_blur_kernel = motion_blur_kernel / self.degree
blurred = cv2.filter2D(image, -1, motion_blur_kernel)
# convert to uint8
cv2.normalize(blurred, blurred, 0, 255, cv2.NORM_MINMAX)
blurred = self._original_format(blurred, chw, normalized, gray3dim)

return blurred.astype(ori_dtype)


class GradientBlur(_NaturalPerturb):
"""
Gradient blur.

Args:
point (union[tuple, list]): 2D coordinate of the Blur center point.
kernel_num (int): Number of blur kernels. Suggested value range in [1, 8].
center (bool): Blurred or clear at the center of a specified point.
auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image.

Example:
>>> img = cv2.imread('xx.png')
>>> img = np.array(img)
>>> number = 5
>>> h, w = img.shape[:2]
>>> point = (int(h / 5), int(w / 5))
>>> center = True
>>> trans = GradientBlur(point, number, center)
>>> new_img = trans(img)
"""

def __init__(self, point, kernel_num=3, center=True, auto_param=False):
super(GradientBlur).__init__()
point = check_param_multi_types('point', point, [list, tuple])
self.auto_param = check_param_type('auto_param', auto_param, bool)
self.point = tuple(point)
self.kernel_num = check_int_positive('kernel_num', kernel_num)
self.center = check_param_type('center', center, bool)

def _auto_param(self, h, w):
self.point = (int(np.random.uniform(0, h)), int(np.random.uniform(0, w)))
self.kernel_num = np.random.randint(1, 6)
self.center = np.random.choice([True, False])

def __call__(self, image):
"""
Args:
image (numpy.ndarray): Original image.

Returns:
numpy.ndarray, gradient blurred image.
"""
ori_dtype = image.dtype
_, chw, normalized, gray3dim, image = self._check(image)
w, h = image.shape[:2]
if self.auto_param:
self._auto_param(h, w)
mask = np.zeros(image.shape, dtype=np.uint8)
masks = []
radius = max(w - self.point[0], self.point[0], h - self.point[1], self.point[1])
radius = int(radius / self.kernel_num)
for i in range(self.kernel_num):
circle = cv2.circle(mask.copy(), self.point, radius * (1 + i), (1, 1, 1), -1)
masks.append(circle)
blurs = []
for i in range(3, 3 + 2 * self.kernel_num, 2):
ksize = (i, i)
blur = cv2.GaussianBlur(image, ksize, 0)
blurs.append(blur)

dst = image.copy()
if self.center:
for i in range(self.kernel_num):
dst = masks[i] * dst + (1 - masks[i]) * blurs[i]
else:
for i in range(self.kernel_num - 1, -1, -1):
dst = masks[i] * blurs[self.kernel_num - 1 - i] + (1 - masks[i]) * dst
dst = self._original_format(dst, chw, normalized, gray3dim)
return dst.astype(ori_dtype)

+ 251
- 0
mindarmour/natural_robustness/corruption.py View File

@@ -0,0 +1,251 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Image corruption.
"""
import math
import numpy as np
import cv2

from mindarmour.natural_robustness.natural_perturb import _NaturalPerturb
from mindarmour.utils._check_param import check_param_multi_types, check_param_type
from mindarmour.utils.logger import LogUtil

LOGGER = LogUtil.get_instance()
TAG = 'Image corruption'


class UniformNoise(_NaturalPerturb):
"""
Add uniform noise of an image.

Args:
factor (float): Noise density, the proportion of noise points per unit pixel area. Suggested value range in
[0.001, 0.15].
auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image.

Example:
>>> img = cv2.imread('1.png')
>>> img = np.array(img)
>>> factor = 0.1
>>> trans = UniformNoise(factor)
>>> dst = trans(img)
"""

def __init__(self, factor=0.1, auto_param=False):
super(UniformNoise, self).__init__()
self.factor = check_param_multi_types('factor', factor, [int, float])
check_param_type('auto_param', auto_param, bool)
if auto_param:
self.factor = np.random.uniform(0, 0.15)

def __call__(self, image):
"""
Transform the image.

Args:
image (numpy.ndarray): Original image to be transformed.

Returns:
numpy.ndarray, transformed image.
"""
ori_dtype = image.dtype
_, chw, normalized, gray3dim, image = self._check(image)
low, high = (0, 255)
weight = self.factor * (high - low)
noise = np.random.uniform(-weight, weight, size=image.shape)
trans_image = np.clip(image + noise, low, high)
trans_image = self._original_format(trans_image, chw, normalized, gray3dim)

return trans_image.astype(ori_dtype)


class GaussianNoise(_NaturalPerturb):
"""
Add gaussian noise of an image.

Args:
factor (float): Noise density, the proportion of noise points per unit pixel area. Suggested value range in
[0.001, 0.15].
auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image.

Example:
>>> img = cv2.imread('1.png')
>>> img = np.array(img)
>>> factor = 0.1
>>> trans = GaussianNoise(factor)
>>> dst = trans(img)
"""

def __init__(self, factor=0.1, auto_param=False):
super(GaussianNoise, self).__init__()
self.factor = check_param_multi_types('factor', factor, [int, float])
check_param_type('auto_param', auto_param, bool)
if auto_param:
self.factor = np.random.uniform(0, 0.15)

def __call__(self, image):
"""
Transform the image.

Args:
image (numpy.ndarray): Original image to be transformed.

Returns:
numpy.ndarray, transformed image.
"""
ori_dtype = image.dtype
low, high = (0, 255)
_, chw, normalized, gray3dim, image = self._check(image)
std = self.factor / math.sqrt(3) * (high - low)
noise = np.random.normal(scale=std, size=image.shape)
trans_image = np.clip(image + noise, low, high)
trans_image = self._original_format(trans_image, chw, normalized, gray3dim)
return trans_image.astype(ori_dtype)


class SaltAndPepperNoise(_NaturalPerturb):
"""
Add salt and pepper noise of an image.

Args:
factor (float): Noise density, the proportion of noise points per unit pixel area. Suggested value range in
[0.001, 0.15].
auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image.

Example:
>>> img = cv2.imread('1.png')
>>> img = np.array(img)
>>> factor = 0.1
>>> trans = SaltAndPepperNoise(factor)
>>> dst = trans(img)
"""

def __init__(self, factor=0, auto_param=False):
super(SaltAndPepperNoise, self).__init__()
self.factor = check_param_multi_types('factor', factor, [int, float])
check_param_type('auto_param', auto_param, bool)
if auto_param:
self.factor = np.random.uniform(0, 0.15)

def __call__(self, image):
"""
Transform the image.

Args:
image (numpy.ndarray): Original image to be transformed.

Returns:
numpy.ndarray, transformed image.
"""
ori_dtype = image.dtype
_, chw, normalized, gray3dim, image = self._check(image)
low, high = (0, 255)
noise = np.random.uniform(low=-1, high=1, size=(image.shape[0], image.shape[1]))
trans_image = np.copy(image)
threshold = 1 - self.factor
trans_image[noise < -threshold] = low
trans_image[noise > threshold] = high
trans_image = self._original_format(trans_image, chw, normalized, gray3dim)
return trans_image.astype(ori_dtype)


class NaturalNoise(_NaturalPerturb):
"""
Add natural noise to an image.

Args:
ratio (float): Noise density, the proportion of noise blocks per unit pixel area. Suggested value range in
[0.00001, 0.001].
k_x_range (union[list, tuple]): Value range of the noise block length.
k_y_range (union[list, tuple]): Value range of the noise block width.
auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image.

Examples:
>>> img = cv2.imread('xx.png')
>>> img = np.array(img)
>>> ratio = 0.0002
>>> k_x_range = (1, 5)
>>> k_y_range = (3, 25)
>>> trans = NaturalNoise(ratio, k_x_range, k_y_range)
>>> new_img = trans(img)
"""

def __init__(self, ratio=0.0002, k_x_range=(1, 5), k_y_range=(3, 25), auto_param=False):
super(NaturalNoise).__init__()
self.ratio = check_param_type('ratio', ratio, float)
k_x_range = check_param_multi_types('k_x_range', k_x_range, [list, tuple])
k_y_range = check_param_multi_types('k_y_range', k_y_range, [list, tuple])
self.k_x_range = tuple(k_x_range)
self.k_y_range = tuple(k_y_range)
self.auto_param = check_param_type('auto_param', auto_param, bool)

def __call__(self, image):
"""
Add natural noise to given image.

Args:
image (numpy.ndarray): Original image.

Returns:
numpy.ndarray, image with natural noise.
"""
ori_dtype = image.dtype
_, chw, normalized, gray3dim, image = self._check(image)
randon_range = 100
w, h = image.shape[:2]
channel = len(np.shape(image))

if self.auto_param:
self.ratio = np.random.uniform(0, 0.001)
self.k_x_range = (1, 0.1 * w)
self.k_y_range = (1, 0.1 * h)

for _ in range(5):
if channel == 3:
noise = np.ones((w, h, 3), dtype=np.uint8) * 255
dst = np.ones((w, h, 3), dtype=np.uint8) * 255
else:
noise = np.ones((w, h), dtype=np.uint8) * 255
dst = np.ones((w, h), dtype=np.uint8) * 255

rate = self.ratio / 5
mask = np.random.uniform(size=(w, h)) < rate
noise[mask] = np.random.randint(0, randon_range)

k_x, k_y = np.random.randint(*self.k_x_range), np.random.randint(*self.k_y_range)
kernel = np.ones((k_x, k_y), np.uint8)
erode = cv2.erode(noise, kernel, iterations=1)
dst = erode * (erode < randon_range) + dst * (1 - erode < randon_range)
# Add black point
for _ in range(np.random.randint(math.ceil(k_x * k_y / 2))):
x = np.random.randint(-k_x, k_x)
y = np.random.randint(-k_y, k_y)
matrix = np.array([[1, 0, y], [0, 1, x]], dtype=np.float)
affine = cv2.warpAffine(noise, matrix, (h, w))
dst = affine * (affine < randon_range) + dst * (1 - affine < randon_range)
# Add white point
for _ in range(int(k_x * k_y / 2)):
x = np.random.randint(-k_x / 2 - 1, k_x / 2 + 1)
y = np.random.randint(-k_y / 2 - 1, k_y / 2 + 1)
matrix = np.array([[1, 0, y], [0, 1, x]], dtype=np.float)
affine = cv2.warpAffine(noise, matrix, (h, w))
white = affine < randon_range
dst[white] = 255

mask = dst < randon_range
dst = image * (1 - mask) + dst * mask
dst = self._original_format(dst, chw, normalized, gray3dim)

return dst.astype(ori_dtype)

+ 287
- 0
mindarmour/natural_robustness/luminance.py View File

@@ -0,0 +1,287 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Image luminance.
"""
import math
import numpy as np
import cv2

from mindarmour.natural_robustness.natural_perturb import _NaturalPerturb
from mindarmour.utils._check_param import check_param_multi_types, check_param_in_range, check_param_type, \
check_value_non_negative
from mindarmour.utils.logger import LogUtil

LOGGER = LogUtil.get_instance()
TAG = 'Image Luminance'


class Contrast(_NaturalPerturb):
"""
Contrast of an image.

Args:
alpha (Union[float, int]): Control the contrast of an image. :math:`out_image = in_image*alpha+beta`.
Suggested value range in [0.2, 2].
beta (Union[float, int]): Delta added to alpha. Default: 0.
auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image.

Example:
>>> img = cv2.imread('1.png')
>>> img = np.array(img)
>>> alpha = 0.1
>>> beta = 1
>>> trans = Contrast(alpha, beta)
>>> dst = trans(img)
"""

def __init__(self, alpha=1, beta=0, auto_param=False):
super(Contrast, self).__init__()
self.alpha = check_param_multi_types('factor', alpha, [int, float])
self.beta = check_param_multi_types('factor', beta, [int, float])
auto_param = check_param_type('auto_param', auto_param, bool)
if auto_param:
self.alpha = np.random.uniform(0.2, 2)
self.beta = np.random.uniform(-20, 20)

def __call__(self, image):
"""
Transform the image.

Args:
image (numpy.ndarray): Original image to be transformed.

Returns:
numpy.ndarray, transformed image.
"""
ori_dtype = image.dtype
_, chw, normalized, gray3dim, image = self._check(image)
dst = cv2.convertScaleAbs(image, alpha=self.alpha, beta=self.beta)
dst = self._original_format(dst, chw, normalized, gray3dim)
return dst.astype(ori_dtype)


def _circle_gradient_mask(img_src, color_start, color_end, scope=0.5, point=None):
"""
Generate circle gradient mask.

Args:
img_src (numpy.ndarray): Source image.
color_start (union([tuple, list])): Color of circle gradient center.
color_end (union([tuple, list])): Color of circle gradient edge.
scope (float): Range of the gradient. A larger value indicates a larger gradient range.
point (union([tuple, list]): Gradient center point.

Returns:
numpy.ndarray, gradients mask.
"""
if not isinstance(img_src, np.ndarray):
raise TypeError('`src` must be numpy.ndarray type, but got {0}.'.format(type(img_src)))

shape = img_src.shape
height, width = shape[:2]
rgb = False
if len(shape) == 3:
rgb = True
if point is None:
point = (height // 2, width // 2)
x, y = point

# upper left
bound_upper_left = math.ceil(math.sqrt(x ** 2 + y ** 2))
# upper right
bound_upper_right = math.ceil(math.sqrt(height ** 2 + (width - y) ** 2))
# lower left
bound_lower_left = math.ceil(math.sqrt((height - x) ** 2 + y ** 2))
# lower right
bound_lower_right = math.ceil(math.sqrt((height - x) ** 2 + (width - y) ** 2))

radius = max(bound_lower_left, bound_lower_right, bound_upper_left, bound_upper_right) * scope

img_grad = np.ones_like(img_src, dtype=np.uint8) * max(color_end)
# opencv use BGR format
grad_b = float(color_end[0] - color_start[0]) / radius
grad_g = float(color_end[1] - color_start[1]) / radius
grad_r = float(color_end[2] - color_start[2]) / radius

for i in range(height):
for j in range(width):
distance = math.ceil(math.sqrt((x - i) ** 2 + (y - j) ** 2))
if distance >= radius:
continue
if rgb:
img_grad[i, j, 0] = color_start[0] + distance * grad_b
img_grad[i, j, 1] = color_start[1] + distance * grad_g
img_grad[i, j, 2] = color_start[2] + distance * grad_r
else:
img_grad[i, j] = color_start[0] + distance * grad_b

return img_grad.astype(np.uint8)


def _line_gradient_mask(image, start_pos=None, start_color=(0, 0, 0), end_color=(255, 255, 255), mode='horizontal'):
"""
Generate liner gradient mask.

Args:
image (numpy.ndarray): Original image.
start_pos (union[tuple, list]): 2D coordinate of gradient center.
start_color (union([tuple, list])): Color of circle gradient center.
end_color (union([tuple, list])): Color of circle gradient edge.
mode (str): Direction of gradient. Optional value is 'vertical' or 'horizontal'.

Returns:
numpy.ndarray, gradients mask.
"""
shape = image.shape
h, w = shape[:2]
rgb = False
if len(shape) == 3:
rgb = True
if start_pos is None:
start_pos = 0.5
else:
if mode == 'horizontal':
if start_pos[0] > h:
start_pos = 1
else:
start_pos = start_pos[0] / h
else:
if start_pos[1] > w:
start_pos = 1
else:
start_pos = start_pos[1] / w
start_color = np.array(start_color)
end_color = np.array(end_color)
if mode == 'horizontal':
w_l = int(w * start_pos)
w_r = w - w_l
if w_l > w_r:
r_end_color = (end_color - start_color) / start_pos * (1 - start_pos) + start_color
left = np.linspace(end_color, start_color, w_l)
right = np.linspace(start_color, r_end_color, w_r)
else:
l_end_color = (end_color - start_color) / (1 - start_pos) * start_pos + start_color
left = np.linspace(l_end_color, start_color, w_l)
right = np.linspace(start_color, end_color, w_r)
line = np.concatenate((left, right), axis=0)
mask = np.reshape(np.tile(line, (h, 1)), (h, w, 3))
else:
# 'vertical'
h_t = int(h * start_pos)
h_b = h - h_t
if h_t > h_b:
b_end_color = (end_color - start_color) / start_pos * (1 - start_pos) + start_color
top = np.linspace(end_color, start_color, h_t)
bottom = np.linspace(start_color, b_end_color, h_b)
else:
t_end_color = (end_color - start_color) / (1 - start_pos) * start_pos + start_color
top = np.linspace(t_end_color, start_color, h_t)
bottom = np.linspace(start_color, end_color, h_b)
line = np.concatenate((top, bottom), axis=0)
mask = np.reshape(np.tile(line, (w, 1)), (w, h, 3))
mask = np.transpose(mask, [1, 0, 2])
if not rgb:
mask = mask[:, :, 0]
return mask.astype(np.uint8)


class GradientLuminance(_NaturalPerturb):
"""
Gradient adjusts the luminance of picture.

Args:
color_start (union[tuple, list]): Color of gradient center. Default:(0, 0, 0).
color_end (union[tuple, list]): Color of gradient edge. Default:(255, 255, 255).
start_point (union[tuple, list]): 2D coordinate of gradient center.
scope (float): Range of the gradient. A larger value indicates a larger gradient range. Default: 0.3.
pattern (str): Dark or light, this value must be in ['light', 'dark'].
bright_rate (float): Control brightness of . A larger value indicates a larger gradient range. If parameter
'pattern' is 'light', Suggested value range in [0.1, 0.7], if parameter 'pattern' is 'dark', Suggested value
range in [0.1, 0.9].
mode (str): Gradient mode, value must be in ['circle', 'horizontal', 'vertical'].
auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image.

Examples:
>>> img = cv2.imread('x.png')
>>> height, width = img.shape[:2]
>>> point = (height // 4, width // 2)
>>> start = (255, 255, 255)
>>> end = (0, 0, 0)
>>> scope = 0.3
>>> pattern='light'
>>> bright_rate = 0.3
>>> trans = GradientLuminance(start, end, point, scope, pattern, bright_rate, mode='circle')
>>> img_new = trans(img)
"""

def __init__(self, color_start=(0, 0, 0), color_end=(255, 255, 255), start_point=(10, 10), scope=0.5,
pattern='light', bright_rate=0.3, mode='circle', auto_param=False):
super(GradientLuminance, self).__init__()
self.color_start = check_param_multi_types('color_start', color_start, [list, tuple])
self.color_end = check_param_multi_types('color_end', color_end, [list, tuple])
self.start_point = check_param_multi_types('start_point', start_point, [list, tuple])
self.scope = check_value_non_negative('scope', scope)
self.bright_rate = check_param_type('bright_rate', bright_rate, float)
self.bright_rate = check_param_in_range('bright_rate', bright_rate, 0, 1)
self.auto_param = check_param_type('auto_param', auto_param, bool)

if pattern in ['light', 'dark']:
self.pattern = pattern
else:
msg = "Value of param pattern must be in ['light', 'dark']"
LOGGER.error(TAG, msg)
raise ValueError(msg)
if mode in ['circle', 'horizontal', 'vertical']:
self.mode = mode
else:
msg = "Value of param mode must be in ['circle', 'horizontal', 'vertical']"
LOGGER.error(TAG, msg)
raise ValueError(msg)

def _set_auto_param(self, w, h):
self.color_start = (np.random.uniform(0, 255),) * 3
self.color_end = (np.random.uniform(0, 255),) * 3
self.start_point = (np.random.uniform(0, w), np.random.uniform(0, h))
self.scope = np.random.uniform(0, 1)
self.bright_rate = np.random.uniform(0.1, 0.9)
self.pattern = np.random.choice(['light', 'dark'])
self.mode = np.random.choice(['circle', 'horizontal', 'vertical'])

def __call__(self, image):
"""
Gradient adjusts the luminance of picture.

Args:
image (numpy.ndarray): Original image.

Returns:
numpy.ndarray, image with perlin noise.
"""
ori_dtype = image.dtype
_, chw, normalized, gray3dim, image = self._check(image)
w, h = image.shape[:2]
if self.auto_param:
self._set_auto_param(w, h)
if self.mode == 'circle':
mask = _circle_gradient_mask(image, self.color_start, self.color_end, self.scope, self.start_point)
else:
mask = _line_gradient_mask(image, self.start_point, self.color_start, self.color_end, mode=self.mode)

if self.pattern == 'light':
img_new = cv2.addWeighted(image, 1, mask, self.bright_rate, 0.0)
else:
img_new = cv2.addWeighted(image, self.bright_rate, mask, 1 - self.bright_rate, 0.0)
img_new = self._original_format(img_new, chw, normalized, gray3dim)
return img_new.astype(ori_dtype)

+ 159
- 0
mindarmour/natural_robustness/natural_perturb.py View File

@@ -0,0 +1,159 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Base class for image natural perturbation.
"""
import numpy as np

from mindspore.dataset.vision.py_transforms_util import is_numpy, hwc_to_chw
from mindarmour.utils.logger import LogUtil

LOGGER = LogUtil.get_instance()
TAG = 'Image Transformation'


def _chw_to_hwc(img):
"""
Transpose the input image; shape (C, H, W) to shape (H, W, C).

Args:
img (numpy.ndarray): Image to be converted.

Returns:
img (numpy.ndarray), Converted image.
"""
if is_numpy(img):
return img.transpose(1, 2, 0).copy()
raise TypeError('img should be numpy.ndarray. Got {}'.format(type(img)))


def _is_hwc(img):
"""
Check if the input image is shape (H, W, C).

Args:
img (numpy.ndarray): Image to be checked.

Returns:
Bool, True if input is shape (H, W, C).
"""
if is_numpy(img):
img_shape = np.shape(img)
if img_shape[2] == 3 and img_shape[1] > 3 and img_shape[0] > 3:
return True
return False
raise TypeError('img should be numpy.ndarray. Got {}'.format(type(img)))


def _is_chw(img):
"""
Check if the input image is shape (H, W, C).

Args:
img (numpy.ndarray): Image to be checked.

Returns:
Bool, True if input is shape (H, W, C).
"""
if is_numpy(img):
img_shape = np.shape(img)
if img_shape[0] == 3 and img_shape[1] > 3 and img_shape[2] > 3:
return True
return False
raise TypeError('img should be numpy.ndarray. Got {}'.format(type(img)))


def _is_rgb(img):
"""
Check if the input image is RGB.

Args:
img (numpy.ndarray): Image to be checked.

Returns:
Bool, True if input is RGB.
"""
if is_numpy(img):
img_shape = np.shape(img)
if len(np.shape(img)) == 3 and (img_shape[0] == 3 or img_shape[2] == 3):
return True
return False
raise TypeError('img should be numpy.ndarray. Got {}'.format(type(img)))


def _is_normalized(img):
"""
Check if the input image is normalized between 0 to 1.

Args:
img (numpy.ndarray): Image to be checked.

Returns:
Bool, True if input is normalized between 0 to 1.
"""
if is_numpy(img):
minimal = np.min(img)
maximum = np.max(img)
if minimal >= 0 and maximum <= 1:
return True
return False
raise TypeError('img should be Numpy array. Got {}'.format(type(img)))


class _NaturalPerturb:
"""
The abstract base class for all image natural perturbation classes.
"""

def __init__(self):
pass

def _check(self, image):
""" Check image format. If input image is RGB and its shape
is (C, H, W), it will be transposed to (H, W, C). If the value
of the image is not normalized , it will be rescaled between 0 to 255."""
rgb = _is_rgb(image)
chw = False
gray3dim = False
normalized = _is_normalized(image)
if rgb:
chw = _is_chw(image)
if chw:
image = _chw_to_hwc(image)
else:
image = image
else:
if len(np.shape(image)) == 3:
gray3dim = True
image = image[0]
else:
image = image
if normalized:
image = image * 255
return rgb, chw, normalized, gray3dim, np.uint8(image)

def _original_format(self, image, chw, normalized, gray3dim):
""" Return image with original format. """
if not is_numpy(image):
image = np.array(image)
if chw:
image = hwc_to_chw(image)
if normalized:
image = image / 255
if gray3dim:
image = np.expand_dims(image, 0)
return image

def __call__(self, image):
pass

+ 365
- 0
mindarmour/natural_robustness/transformation.py View File

@@ -0,0 +1,365 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Image transformation.
"""
import math
import numpy as np
import cv2

from mindarmour.natural_robustness.natural_perturb import _NaturalPerturb
from mindarmour.utils._check_param import check_param_multi_types, check_param_type, check_value_non_negative
from mindarmour.utils.logger import LogUtil

LOGGER = LogUtil.get_instance()
TAG = 'Image Transformation'


class Translate(_NaturalPerturb):
"""
Translate an image.

Args:
x_bias (Union[int, float]): X-direction translation, x = x + x_bias*image_length. Suggested value range
in [-0.1, 0.1].
y_bias (Union[int, float]): Y-direction translation, y = y + y_bias*image_wide. Suggested value range
in [-0.1, 0.1].
auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image.

Example:
>>> img = cv2.imread('1.png')
>>> img = np.array(img)
>>> x_bias = 0.1
>>> y_bias = 0.1
>>> trans = Translate(x_bias, y_bias)
>>> dst = trans(img)
"""

def __init__(self, x_bias=0, y_bias=0, auto_param=False):
super(Translate, self).__init__()
self.x_bias = check_param_multi_types('x_bias', x_bias, [int, float])
self.y_bias = check_param_multi_types('y_bias', y_bias, [int, float])
if auto_param:
self.x_bias = np.random.uniform(-0.1, 0.1)
self.y_bias = np.random.uniform(-0.1, 0.1)

def __call__(self, image):
"""
Transform the image.

Args:
image (numpy.ndarray): Original image to be transformed.

Returns:
numpy.ndarray, transformed image.
"""
ori_dtype = image.dtype
_, chw, normalized, gray3dim, image = self._check(image)
h, w = image.shape[:2]
matrix = np.array([[1, 0, self.x_bias * w], [0, 1, self.y_bias * h]], dtype=np.float)
new_img = cv2.warpAffine(image, matrix, (w, h))
new_img = self._original_format(new_img, chw, normalized, gray3dim)
return new_img.astype(ori_dtype)


class Scale(_NaturalPerturb):
"""
Scale an image in the middle.

Args:
factor_x (Union[float, int]): Rescale in X-direction, x=factor_x*x. Suggested value range in [0.5, 1] and
abs(factor_y - factor_x) < 0.5.
factor_y (Union[float, int]): Rescale in Y-direction, y=factor_y*y. Suggested value range in [0.5, 1] and
abs(factor_y - factor_x) < 0.5.
auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image.

Example:
>>> img = cv2.imread('1.png')
>>> img = np.array(img)
>>> factor_x = 0.7
>>> factor_y = 0.6
>>> trans = Scale(factor_x, factor_y)
>>> dst = trans(img)
"""

def __init__(self, factor_x=1, factor_y=1, auto_param=False):
super(Scale, self).__init__()
self.factor_x = check_param_multi_types('factor_x', factor_x, [int, float])
self.factor_y = check_param_multi_types('factor_y', factor_y, [int, float])
auto_param = check_param_type('auto_param', auto_param, bool)
if auto_param:
self.factor_x = np.random.uniform(0.5, 1)
self.factor_y = np.random.uniform(0.5, 1)

def __call__(self, image):
"""
Transform the image.

Args:
image (numpy.ndarray): Original image to be transformed.

Returns:
numpy.ndarray, transformed image.
"""
ori_dtype = image.dtype
_, chw, normalized, gray3dim, image = self._check(image)
h, w = image.shape[:2]
matrix = np.array([[self.factor_x, 0, 0], [0, self.factor_y, 0]], dtype=np.float)
new_img = cv2.warpAffine(image, matrix, (w, h))
new_img = self._original_format(new_img, chw, normalized, gray3dim)
return new_img.astype(ori_dtype)


class Shear(_NaturalPerturb):
"""
Shear an image, for each pixel (x, y) in the sheared image, the new value is taken from a position
(x+factor_x*y, factor_y*x+y) in the origin image. Then the sheared image will be rescaled to fit original size.

Args:
factor (Union[float, int]): Shear rate in shear direction. Suggested value range in [0.05, 0.5].
direction (str): Direction of deformation. Optional value is 'vertical' or 'horizontal'.
auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image.

Example:
>>> img = cv2.imread('1.png')
>>> img = np.array(img)
>>> factor = 0.2
>>> trans = Shear(factor, direction='horizontal')
>>> dst = trans(img)
"""

def __init__(self, factor=0.2, direction='horizontal', auto_param=False):
super(Shear, self).__init__()
self.factor = check_param_multi_types('factor', factor, [int, float])
if direction not in ['horizontal', 'vertical']:
msg = "'direction must be in ['horizontal', 'vertical'], but got {}".format(direction)
raise ValueError(msg)
self.direction = direction
auto_param = check_param_type('auto_params', auto_param, bool)
if auto_param:
self.factor = np.random.uniform(0.05, 0.5)
self.direction = np.random.choice(['horizontal', 'vertical'])

def __call__(self, image):
"""
Transform the image.

Args:
image (numpy.ndarray): Original image to be transformed.

Returns:
numpy.ndarray, transformed image.
"""
ori_dtype = image.dtype
_, chw, normalized, gray3dim, image = self._check(image)
h, w = image.shape[:2]
if self.direction == 'horizontal':
matrix = np.array([[1, self.factor, 0], [0, 1, 0]], dtype=np.float)
nw = int(w + self.factor * h)
nh = h
else:
matrix = np.array([[1, 0, 0], [self.factor, 1, 0]], dtype=np.float)
nw = w
nh = int(h + self.factor * w)
new_img = cv2.warpAffine(image, matrix, (nw, nh))
new_img = cv2.resize(new_img, (w, h))
new_img = self._original_format(new_img, chw, normalized, gray3dim)
return new_img.astype(ori_dtype)


class Rotate(_NaturalPerturb):
"""
Rotate an image of counter clockwise around its center.

Args:
angle (Union[float, int]): Degrees of counter clockwise. Suggested value range in [-60, 60].
auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image.

Example:
>>> img = cv2.imread('1.png')
>>> img = np.array(img)
>>> angle = 20
>>> trans = Rotate(angle)
>>> dst = trans(img)
"""

def __init__(self, angle=20, auto_param=False):
super(Rotate, self).__init__()
self.angle = check_param_multi_types('angle', angle, [int, float])
auto_param = check_param_type('auto_param', auto_param, bool)
if auto_param:
self.angle = np.random.uniform(0, 360)

def __call__(self, image):
"""
Transform the image.

Args:
image (numpy.ndarray): Original image to be transformed.

Returns:
numpy.ndarray, rotated image.
"""
ori_dtype = image.dtype
_, chw, normalized, gray3dim, image = self._check(image)
h, w = image.shape[:2]
center = (w // 2, h // 2)
matrix = cv2.getRotationMatrix2D(center, -self.angle, 1.0)
cos = np.abs(matrix[0, 0])
sin = np.abs(matrix[0, 1])

# Calculate new edge after rotated
nw = int((h * sin) + (w * cos))
nh = int((h * cos) + (w * sin))
# Adjust move distance of rotate matrix.
matrix[0, 2] += (nw / 2) - center[0]
matrix[1, 2] += (nh / 2) - center[1]
rotate = cv2.warpAffine(image, matrix, (nw, nh))
rotate = cv2.resize(rotate, (w, h))
new_img = self._original_format(rotate, chw, normalized, gray3dim)
return new_img.astype(ori_dtype)


class Perspective(_NaturalPerturb):
"""
Perform perspective transformation on a given picture.

Args:
ori_pos (list): Four points in original image.
dst_pos (list): The point coordinates of the 4 points in ori_pos after perspective transformation.
auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image.

Example:
>>> img = cv2.imread('1.png')
>>> img = np.array(img)
>>> ori_pos = [[0, 0], [0, 800], [800, 0], [800, 800]]
>>> dst_pos = [[50, 0], [0, 800], [780, 0], [800, 800]]
>>> trans = Perspective(ori_pos, dst_pos)
>>> dst = trans(img)
"""

def __init__(self, ori_pos, dst_pos, auto_param=False):
super(Perspective, self).__init__()
ori_pos = check_param_type('ori_pos', ori_pos, list)
dst_pos = check_param_type('dst_pos', dst_pos, list)
self.ori_pos = np.float32(ori_pos)
self.dst_pos = np.float32(dst_pos)
self.auto_param = check_param_type('auto_param', auto_param, bool)

def _set_auto_param(self, w, h):
self.ori_pos = [[h * 0.25, w * 0.25], [h * 0.25, w * 0.75], [h * 0.75, w * 0.25], [h * 0.75, w * 0.75]]
self.dst_pos = [[np.random.uniform(0, h * 0.5), np.random.uniform(0, w * 0.5)],
[np.random.uniform(0, h * 0.5), np.random.uniform(w * 0.5, w)],
[np.random.uniform(h * 0.5, h), np.random.uniform(0, w * 0.5)],
[np.random.uniform(h * 0.5, h), np.random.uniform(w * 0.5, w)]]
self.ori_pos = np.float32(self.ori_pos)
self.dst_pos = np.float32(self.dst_pos)

def __call__(self, image):
"""
Transform the image.

Args:
image (numpy.ndarray): Original image to be transformed.

Returns:
numpy.ndarray, transformed image.
"""
ori_dtype = image.dtype
_, chw, normalized, gray3dim, image = self._check(image)
h, w = image.shape[:2]
if self.auto_param:
self._set_auto_param(w, h)
matrix = cv2.getPerspectiveTransform(self.ori_pos, self.dst_pos)
new_img = cv2.warpPerspective(image, matrix, (w, h))
new_img = self._original_format(new_img, chw, normalized, gray3dim)
return new_img.astype(ori_dtype)


class Curve(_NaturalPerturb):
"""
Curve picture using sin method.

Args:
curves (union[float, int]): Divide width to curves of `2*math.pi`, which means how many curve cycles. Suggested
value range in [0.1. 5].
depth (union[float, int]): Amplitude of sin method. Suggested value not exceed 1/10 of the length of the
picture.
mode (str): Direction of deformation. Optional value is 'vertical' or 'horizontal'.
auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image.

Examples:
>>> img = cv2.imread('x.png')
>>> curves =1
>>> depth = 10
>>> trans = Curve(curves, depth, mode='vertical')
>>> img_new = trans(img)
"""

def __init__(self, curves=3, depth=10, mode='vertical', auto_param=False):
super(Curve).__init__()
self.curves = check_value_non_negative('curves', curves)
self.depth = check_value_non_negative('depth', depth)
if mode in ['vertical', 'horizontal']:
self.mode = mode
else:
msg = "Value of param mode must be in ['vertical', 'horizontal']"
LOGGER.error(TAG, msg)
raise ValueError(msg)
self.auto_param = check_param_type('auto_param', auto_param, bool)

def _set_auto_params(self, height, width):
if self.auto_param:
self.curves = np.random.uniform(1, 5)
self.mode = np.random.choice(['vertical', 'horizontal'])
if self.mode == 'vertical':
self.depth = np.random.uniform(1, 0.1 * width)
else:
self.depth = np.random.uniform(1, 0.1 * height)

def __call__(self, image):
"""
Curve picture using sin method.

Args:
image (numpy.ndarray): Original image.

Returns:
numpy.ndarray, curved image.
"""
ori_dtype = image.dtype
_, chw, normalized, gray3dim, image = self._check(image)
shape = image.shape
height, width = shape[:2]
if self.mode == 'vertical':
if len(shape) == 3:
image = np.transpose(image, [1, 0, 2])
else:
image = np.transpose(image, [1, 0])
src_x = np.zeros((height, width), np.float32)
src_y = np.zeros((height, width), np.float32)

for y in range(height):
for x in range(width):
src_x[y, x] = x
src_y[y, x] = y + self.depth * math.sin(x / (width / self.curves / 2 / math.pi))
img_new = cv2.remap(image, src_x, src_y, cv2.INTER_LINEAR)

if self.mode == 'vertical':
if len(shape) == 3:
img_new = np.transpose(img_new, [1, 0, 2])
else:
img_new = np.transpose(image, [1, 0])
new_img = self._original_format(img_new, chw, normalized, gray3dim)
return new_img.astype(ori_dtype)

+ 1
- 0
mindarmour/reliability/model_fault_injection/__init__.py View File

@@ -1,4 +1,5 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at


+ 5
- 6
mindarmour/utils/_check_param.py View File

@@ -183,7 +183,7 @@ def check_pair_numpy_param(inputs_name, inputs, labels_name, labels):
def check_equal_length(para_name1, value1, para_name2, value2):
"""Check weather the two parameters have equal length."""
if len(value1) != len(value2):
msg = 'The dimension of {0} must equal to the {1}, but got {0} is {2}, {1} is {3}'\
msg = 'The dimension of {0} must equal to the {1}, but got {0} is {2}, {1} is {3}' \
.format(para_name1, para_name2, len(value1), len(value2))
LOGGER.error(TAG, msg)
raise ValueError(msg)
@@ -193,7 +193,7 @@ def check_equal_length(para_name1, value1, para_name2, value2):
def check_equal_shape(para_name1, value1, para_name2, value2):
"""Check weather the two parameters have equal shape."""
if value1.shape != value2.shape:
msg = 'The shape of {0} must equal to the {1}, but got {0} is {2}, {1} is {3}'.\
msg = 'The shape of {0} must equal to the {1}, but got {0} is {2}, {1} is {3}'. \
format(para_name1, para_name2, value1.shape, value2.shape)
LOGGER.error(TAG, msg)
raise ValueError(msg)
@@ -204,7 +204,7 @@ def check_norm_level(norm_level):
"""Check norm_level of regularization."""
if not isinstance(norm_level, (int, str)):
msg = 'Type of norm_level must be in [int, str], but got {}'.format(type(norm_level))
accept_norm = [1, 2, '1', '2', 'l1', 'l2', 'inf', 'linf', np.inf]
accept_norm = [1, 2, '1', '2', 'l1', 'l2', 'inf', 'linf', 'np.inf', np.inf]
if norm_level not in accept_norm:
msg = 'norm_level must be in {}, but got {}'.format(accept_norm, norm_level)
LOGGER.error(TAG, msg)
@@ -224,8 +224,7 @@ def normalize_value(value, norm_level):
numpy.ndarray, normalized value.

Raises:
NotImplementedError: If norm_level is not in [1, 2 , np.inf, '1', '2',
'inf', 'l1', 'l2']
NotImplementedError: If norm_level is not in [1, 2 , np.inf, '1', '2', 'inf', 'l1', 'l2']
"""
norm_level = check_norm_level(norm_level)
ori_shape = value.shape
@@ -237,7 +236,7 @@ def normalize_value(value, norm_level):
elif norm_level in (2, '2', 'l2'):
norm = np.linalg.norm(value_reshape, ord=2, axis=1, keepdims=True) + avoid_zero_div
norm_value = value_reshape / norm
elif norm_level in (np.inf, 'inf'):
elif norm_level in (np.inf, 'inf', 'np.inf', 'linf'):
norm = np.max(abs(value_reshape), axis=1, keepdims=True) + avoid_zero_div
norm_value = value_reshape / norm
else:


+ 5
- 4
tests/ut/python/fuzzing/test_coverage_metrics.py View File

@@ -75,11 +75,11 @@ def test_lenet_mnist_coverage_cpu():
model = Model(net)

# initialize fuzz test with training dataset
training_data = (np.random.random((10000, 10))*20).astype(np.float32)
training_data = (np.random.random((10000, 10)) * 20).astype(np.float32)

# fuzz test with original test data
# get test data
test_data = (np.random.random((2000, 10))*20).astype(np.float32)
test_data = (np.random.random((2000, 10)) * 20).astype(np.float32)
test_labels = np.random.randint(0, 10, 2000).astype(np.int32)

nc = NeuronCoverage(model, threshold=0.1)
@@ -118,6 +118,7 @@ def test_lenet_mnist_coverage_cpu():
print('NC of adv data is: ', nc_metric)
print('TKNC of adv data is: ', tknc_metrics)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@@ -130,11 +131,11 @@ def test_lenet_mnist_coverage_ascend():
model = Model(net)

# initialize fuzz test with training dataset
training_data = (np.random.random((10000, 10))*20).astype(np.float32)
training_data = (np.random.random((10000, 10)) * 20).astype(np.float32)

# fuzz test with original test data
# get test data
test_data = (np.random.random((2000, 10))*20).astype(np.float32)
test_data = (np.random.random((2000, 10)) * 20).astype(np.float32)
nc = NeuronCoverage(model, threshold=0.1)
nc_metric = nc.get_metrics(test_data)



+ 18
- 14
tests/ut/python/fuzzing/test_fuzzer.py View File

@@ -99,15 +99,17 @@ def test_fuzzing_ascend():
model = Model(net)
batch_size = 8
num_classe = 10
mutate_config = [{'method': 'Blur',
'params': {'auto_param': [True]}},
mutate_config = [{'method': 'GaussianBlur',
'params': {'ksize': [1, 2, 3, 5],
'auto_param': [True, False]}},
{'method': 'UniformNoise',
'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}},
{'method': 'Contrast',
'params': {'factor': [2, 1]}},
{'method': 'Translate',
'params': {'x_bias': [0.1, 0.3], 'y_bias': [0.2]}},
'params': {'alpha': [0.5, 1, 1.5], 'beta': [-10, 0, 10], 'auto_param': [False, True]}},
{'method': 'Rotate',
'params': {'angle': [20, 90], 'auto_param': [False, True]}},
{'method': 'FGSM',
'params': {'eps': [0.1, 0.2, 0.3], 'alpha': [0.1]}}
]
'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1], 'bounds': [(0, 1)]}}]

train_images = np.random.rand(32, 1, 32, 32).astype(np.float32)
# fuzz test with original test data
@@ -142,15 +144,17 @@ def test_fuzzing_cpu():
model = Model(net)
batch_size = 8
num_classe = 10
mutate_config = [{'method': 'Blur',
'params': {'auto_param': [True]}},
mutate_config = [{'method': 'GaussianBlur',
'params': {'ksize': [1, 2, 3, 5],
'auto_param': [True, False]}},
{'method': 'UniformNoise',
'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}},
{'method': 'Contrast',
'params': {'factor': [2, 1]}},
{'method': 'Translate',
'params': {'x_bias': [0.1, 0.3], 'y_bias': [0.2]}},
'params': {'alpha': [0.5, 1, 1.5], 'beta': [-10, 0, 10], 'auto_param': [False, True]}},
{'method': 'Rotate',
'params': {'angle': [20, 90], 'auto_param': [False, True]}},
{'method': 'FGSM',
'params': {'eps': [0.1, 0.2, 0.3], 'alpha': [0.1]}}
]
'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1], 'bounds': [(0, 1)]}}]
# initialize fuzz test with training dataset
train_images = np.random.rand(32, 1, 32, 32).astype(np.float32)



+ 0
- 126
tests/ut/python/fuzzing/test_image_transform.py View File

@@ -1,126 +0,0 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Image transform test.
"""
import numpy as np
import pytest

from mindarmour.utils.logger import LogUtil
from mindarmour.fuzz_testing.image_transform import Contrast, Brightness, \
Blur, Noise, Translate, Scale, Shear, Rotate

LOGGER = LogUtil.get_instance()
TAG = 'Image transform test'
LOGGER.set_level('INFO')


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_contrast():
image = (np.random.rand(32, 32)).astype(np.float32)
trans = Contrast()
trans.set_params(auto_param=True)
_ = trans.transform(image)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_brightness():
image = (np.random.rand(32, 32)).astype(np.float32)
trans = Brightness()
trans.set_params(auto_param=True)
_ = trans.transform(image)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_blur():
image = (np.random.rand(32, 32)).astype(np.float32)
trans = Blur()
trans.set_params(auto_param=True)
_ = trans.transform(image)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_noise():
image = (np.random.rand(32, 32)).astype(np.float32)
trans = Noise()
trans.set_params(auto_param=True)
_ = trans.transform(image)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_translate():
image = (np.random.rand(32, 32)).astype(np.float32)
trans = Translate()
trans.set_params(auto_param=True)
_ = trans.transform(image)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_shear():
image = (np.random.rand(32, 32)).astype(np.float32)
trans = Shear()
trans.set_params(auto_param=True)
_ = trans.transform(image)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_scale():
image = (np.random.rand(32, 32)).astype(np.float32)
trans = Scale()
trans.set_params(auto_param=True)
_ = trans.transform(image)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_rotate():
image = (np.random.rand(32, 32)).astype(np.float32)
trans = Rotate()
trans.set_params(auto_param=True)
_ = trans.transform(image)

+ 577
- 0
tests/ut/python/natural_robustness/test_natural_robustness.py View File

@@ -0,0 +1,577 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Example for natural robustness methods."""

import pytest
import numpy as np
from mindspore import context

from mindarmour.natural_robustness import Translate, Curve, Perspective, Scale, Shear, Rotate, SaltAndPepperNoise, \
NaturalNoise, GaussianNoise, UniformNoise, MotionBlur, GaussianBlur, GradientBlur, Contrast, GradientLuminance


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_perspective():
"""
Feature: Test image perspective.
Description: Image will be transform for given perspective projection.
Expectation: success.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
image = np.random.random((32, 32, 3))
ori_pos = [[0, 0], [0, 800], [800, 0], [800, 800]]
dst_pos = [[50, 0], [0, 800], [780, 0], [800, 800]]
trans = Perspective(ori_pos, dst_pos)
dst = trans(image)
print(dst)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_uniform_noise():
"""
Feature: Test image uniform noise.
Description: Add uniform image in image.
Expectation: success.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
image = np.random.random((32, 32, 3))
trans = UniformNoise(factor=0.1)
dst = trans(image)
print(dst)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_gaussian_noise():
"""
Feature: Test image gaussian noise.
Description: Add gaussian image in image.
Expectation: success.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
image = np.random.random((32, 32, 3))
trans = GaussianNoise(factor=0.1)
dst = trans(image)
print(dst)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_contrast():
"""
Feature: Test image contrast.
Description: Adjust image contrast.
Expectation: success.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
image = np.random.random((32, 32, 3))
trans = Contrast(alpha=0.3, beta=0)
dst = trans(image)
print(dst)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_gaussian_blur():
"""
Feature: Test image gaussian blur.
Description: Add gaussian blur to image.
Expectation: success.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
image = np.random.random((32, 32, 3))
trans = GaussianBlur(ksize=5)
dst = trans(image)
print(dst)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_salt_and_pepper_noise():
"""
Feature: Test image salt and pepper noise.
Description: Add salt and pepper to image.
Expectation: success.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
image = np.random.random((32, 32, 3))
trans = SaltAndPepperNoise(factor=0.01)
dst = trans(image)
print(dst)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_translate():
"""
Feature: Test image translate.
Description: Translate an image.
Expectation: success.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
image = np.random.random((32, 32, 3))
trans = Translate(x_bias=0.1, y_bias=0.1)
dst = trans(image)
print(dst)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_scale():
"""
Feature: Test image scale.
Description: Scale an image.
Expectation: success.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
image = np.random.random((32, 32, 3))
trans = Scale(factor_x=0.7, factor_y=0.7)
dst = trans(image)
print(dst)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_shear():
"""
Feature: Test image shear.
Description: Shear an image.
Expectation: success.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
image = np.random.random((32, 32, 3))
trans = Shear(factor=0.2)
dst = trans(image)
print(dst)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_rotate():
"""
Feature: Test image rotate.
Description: Rotate an image.
Expectation: success.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
image = np.random.random((32, 32, 3))
trans = Rotate(angle=20)
dst = trans(image)
print(dst)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_curve():
"""
Feature: Test image curve.
Description: Transform an image with curve.
Expectation: success.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
image = np.random.random((32, 32, 3))
trans = Curve(curves=1.5, depth=1.5, mode='horizontal')
dst = trans(image)
print(dst)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_natural_noise():
"""
Feature: Test natural noise.
Description: Add natural noise to an.
Expectation: success.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
image = np.random.random((32, 32, 3))
trans = NaturalNoise(ratio=0.0001, k_x_range=(1, 30), k_y_range=(1, 10), auto_param=True)
dst = trans(image)
print(dst)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_gradient_luminance():
"""
Feature: Test gradient luminance.
Description: Adjust image luminance.
Expectation: success.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
image = np.random.random((32, 32, 3))
height, width = image.shape[:2]
point = (height // 4, width // 2)
start = (255, 255, 255)
end = (0, 0, 0)
scope = 0.3
bright_rate = 0.4
trans = GradientLuminance(start, end, start_point=point, scope=scope, pattern='dark', bright_rate=bright_rate,
mode='horizontal')
dst = trans(image)
print(dst)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_motion_blur():
"""
Feature: Test motion blur.
Description: Add motion blur to an image.
Expectation: success.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
image = np.random.random((32, 32, 3))
angle = -10.5
i = 3
trans = MotionBlur(degree=i, angle=angle)
dst = trans(image)
print(dst)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_gradient_blur():
"""
Feature: Test gradient blur.
Description: Add gradient blur to an image.
Expectation: success.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
image = np.random.random((32, 32, 3))
number = 10
h, w = image.shape[:2]
point = (int(h / 5), int(w / 5))
center = False
trans = GradientBlur(point, number, center)
dst = trans(image)
print(dst)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_perspective_ascend():
"""
Feature: Test image perspective.
Description: Image will be transform for given perspective projection.
Expectation: success.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
image = np.random.random((32, 32, 3))
ori_pos = [[0, 0], [0, 800], [800, 0], [800, 800]]
dst_pos = [[50, 0], [0, 800], [780, 0], [800, 800]]
trans = Perspective(ori_pos, dst_pos)
dst = trans(image)
print(dst)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_uniform_noise_ascend():
"""
Feature: Test image uniform noise.
Description: Add uniform image in image.
Expectation: success.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
image = np.random.random((32, 32, 3))
trans = UniformNoise(factor=0.1)
dst = trans(image)
print(dst)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_gaussian_noise_ascend():
"""
Feature: Test image gaussian noise.
Description: Add gaussian image in image.
Expectation: success.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
image = np.random.random((32, 32, 3))
trans = GaussianNoise(factor=0.1)
dst = trans(image)
print(dst)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_contrast_ascend():
"""
Feature: Test image contrast.
Description: Adjust image contrast.
Expectation: success.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
image = np.random.random((32, 32, 3))
trans = Contrast(alpha=0.3, beta=0)
dst = trans(image)
print(dst)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_gaussian_blur_ascend():
"""
Feature: Test image gaussian blur.
Description: Add gaussian blur to image.
Expectation: success.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
image = np.random.random((32, 32, 3))
trans = GaussianBlur(ksize=5)
dst = trans(image)
print(dst)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_salt_and_pepper_noise_ascend():
"""
Feature: Test image salt and pepper noise.
Description: Add salt and pepper to image.
Expectation: success.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
image = np.random.random((32, 32, 3))
trans = SaltAndPepperNoise(factor=0.01)
dst = trans(image)
print(dst)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_translate_ascend():
"""
Feature: Test image translate.
Description: Translate an image.
Expectation: success.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
image = np.random.random((32, 32, 3))
trans = Translate(x_bias=0.1, y_bias=0.1)
dst = trans(image)
print(dst)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_ascend_mindarmour
def test_scale_ascend():
"""
Feature: Test image scale.
Description: Scale an image.
Expectation: success.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
image = np.random.random((32, 32, 3))
trans = Scale(factor_x=0.7, factor_y=0.7)
dst = trans(image)
print(dst)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_shear_ascend():
"""
Feature: Test image shear.
Description: Shear an image.
Expectation: success.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
image = np.random.random((32, 32, 3))
trans = Shear(factor=0.2)
dst = trans(image)
print(dst)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_rotate_ascend():
"""
Feature: Test image rotate.
Description: Rotate an image.
Expectation: success.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
image = np.random.random((32, 32, 3))
trans = Rotate(angle=20)
dst = trans(image)
print(dst)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_curve_ascend():
"""
Feature: Test image curve.
Description: Transform an image with curve.
Expectation: success.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
image = np.random.random((32, 32, 3))
trans = Curve(curves=1.5, depth=1.5, mode='horizontal')
dst = trans(image)
print(dst)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_natural_noise_ascend():
"""
Feature: Test natural noise.
Description: Add natural noise to an.
Expectation: success.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
image = np.random.random((32, 32, 3))
trans = NaturalNoise(ratio=0.0001, k_x_range=(1, 30), k_y_range=(1, 10), auto_param=True)
dst = trans(image)
print(dst)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_gradient_luminance_ascend():
"""
Feature: Test gradient luminance.
Description: Adjust image luminance.
Expectation: success.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
image = np.random.random((32, 32, 3))
height, width = image.shape[:2]
point = (height // 4, width // 2)
start = (255, 255, 255)
end = (0, 0, 0)
scope = 0.3
bright_rate = 0.4
trans = GradientLuminance(start, end, start_point=point, scope=scope, pattern='dark', bright_rate=bright_rate,
mode='horizontal')
dst = trans(image)
print(dst)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_motion_blur_ascend():
"""
Feature: Test motion blur.
Description: Add motion blur to an image.
Expectation: success.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
image = np.random.random((32, 32, 3))
angle = -10.5
i = 3
trans = MotionBlur(degree=i, angle=angle)
dst = trans(image)
print(dst)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_gradient_blur_ascend():
"""
Feature: Test gradient blur.
Description: Add gradient blur to an image.
Expectation: success.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
image = np.random.random((32, 32, 3))
number = 10
h, w = image.shape[:2]
point = (int(h / 5), int(w / 5))
center = False
trans = GradientBlur(point, number, center)
dst = trans(image)
print(dst)

Loading…
Cancel
Save