Browse Source

Pre Merge pull request !260 from 高凡启/develop

pull/260/MERGE
高凡启 Gitee 3 years ago
parent
commit
553003d8dd
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
37 changed files with 4191 additions and 1 deletions
  1. +0
    -0
      examples/common/networks/mobilenet_v1/__init__.py
  2. +99
    -0
      examples/common/networks/mobilenet_v1/mobilenet_v1.py
  3. +1
    -0
      examples/natural_robustness/applications/__init__.py
  4. +0
    -0
      examples/natural_robustness/applications/classification/__init__.py
  5. +39
    -0
      examples/natural_robustness/applications/classification/defense_example.py
  6. +44
    -0
      examples/natural_robustness/applications/classification/evaluation_example.py
  7. +93
    -0
      examples/natural_robustness/applications/classification/preparation.py
  8. +1
    -0
      examples/natural_robustness/transform/__init__.py
  9. +54
    -0
      examples/natural_robustness/transform/time_series_transform_example.py
  10. +175
    -0
      mindarmour/natural_robustness/README.md
  11. +17
    -0
      mindarmour/natural_robustness/applications/__init__.py
  12. +26
    -0
      mindarmour/natural_robustness/applications/image_classification/__init__.py
  13. +199
    -0
      mindarmour/natural_robustness/applications/image_classification/defense.py
  14. +151
    -0
      mindarmour/natural_robustness/applications/image_classification/evaluation.py
  15. +238
    -0
      mindarmour/natural_robustness/applications/image_classification/implementation.py
  16. +214
    -0
      mindarmour/natural_robustness/applications/image_classification/sample_data_from_directory.py
  17. +32
    -0
      mindarmour/natural_robustness/base/__init__.py
  18. +52
    -0
      mindarmour/natural_robustness/base/abstract_equal.py
  19. +179
    -0
      mindarmour/natural_robustness/base/abstract_model.py
  20. +45
    -0
      mindarmour/natural_robustness/base/attacks/__init__.py
  21. +718
    -0
      mindarmour/natural_robustness/base/attacks/image_attacks.py
  22. +24
    -0
      mindarmour/natural_robustness/base/attacks/time_series_attack.py
  23. +158
    -0
      mindarmour/natural_robustness/base/common_defense.py
  24. +518
    -0
      mindarmour/natural_robustness/base/common_evaluation.py
  25. +306
    -0
      mindarmour/natural_robustness/base/data_structure.py
  26. +163
    -0
      mindarmour/natural_robustness/base/read_data.py
  27. +2
    -1
      mindarmour/natural_robustness/transform/__init__.py
  28. +23
    -0
      mindarmour/natural_robustness/transform/time_series/__init__.py
  29. +147
    -0
      mindarmour/natural_robustness/transform/time_series/corruption.py
  30. +33
    -0
      mindarmour/natural_robustness/transform/time_series/natural_perturb.py
  31. +13
    -0
      mindarmour/natural_robustness/utils/__init__.py
  32. +36
    -0
      mindarmour/natural_robustness/utils/custom_threading.py
  33. +229
    -0
      mindarmour/natural_robustness/utils/tools.py
  34. +0
    -0
      tests/ut/python/natural_robustness/classification/__init__.py
  35. +60
    -0
      tests/ut/python/natural_robustness/classification/test_defense.py
  36. +56
    -0
      tests/ut/python/natural_robustness/classification/test_evaluation.py
  37. +46
    -0
      tests/ut/python/natural_robustness/test_time_series_transform.py

+ 0
- 0
examples/common/networks/mobilenet_v1/__init__.py View File


+ 99
- 0
examples/common/networks/mobilenet_v1/mobilenet_v1.py View File

@@ -0,0 +1,99 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""mobile net v1."""

import mindspore.nn as nn
from mindspore.ops import operations as P


def conv_bn_relu(in_channel, out_channel, kernel_size, stride, depthwise,
                 activation='relu6'):
    """
    Build a Conv2d -> BatchNorm2d -> (optional) activation block.

    Args:
        in_channel: Input channel count passed to ``nn.Conv2d``.
        out_channel: Output channel count passed to ``nn.Conv2d`` and
            ``nn.BatchNorm2d``.
        kernel_size: Kernel size passed to ``nn.Conv2d``.
        stride: Stride passed to ``nn.Conv2d``.
        depthwise: When truthy the convolution uses ``in_channel`` groups,
            otherwise a single group.
        activation: Name handed to ``nn.get_activation``. A falsy value
            omits the activation layer. Default: 'relu6'.

    Returns:
        nn.SequentialCell, the layers chained in the order listed above.
    """
    groups = in_channel if depthwise else 1
    layers = [
        nn.Conv2d(in_channel, out_channel, kernel_size, stride,
                  pad_mode="same", group=groups),
        nn.BatchNorm2d(out_channel),
    ]
    if activation:
        layers.append(nn.get_activation(activation))
    return nn.SequentialCell(layers)


class MobileNetV1(nn.Cell):
    """
    MobileNet V1 backbone.

    Args:
        class_num: Output size of the final dense layer. Default: 1001.
        features_only: When True the blocks are stored in an ``nn.CellList``,
            no dense layer is created, and ``construct`` returns a tuple with
            the output of every block. Default: False.
    """

    def __init__(self, class_num=1001, features_only=False):
        super().__init__()
        self.features_only = features_only

        # One entry per depthwise-separable stage (Conv1 .. Conv13):
        # (input channels, output channels, stride of the depthwise conv).
        separable_stages = (
            (32, 64, 1),       # Conv1
            (64, 128, 2),      # Conv2
            (128, 128, 1),     # Conv3
            (128, 256, 2),     # Conv4
            (256, 256, 1),     # Conv5
            (256, 512, 2),     # Conv6
            (512, 512, 1),     # Conv7
            (512, 512, 1),     # Conv8
            (512, 512, 1),     # Conv9
            (512, 512, 1),     # Conv10
            (512, 512, 1),     # Conv11
            (512, 1024, 2),    # Conv12
            (1024, 1024, 1),   # Conv13
        )

        blocks = [conv_bn_relu(3, 32, 3, 2, False)]  # Conv0
        for in_ch, out_ch, dw_stride in separable_stages:
            # 3x3 depthwise convolution keeps the channel count ...
            blocks.append(conv_bn_relu(in_ch, in_ch, 3, dw_stride, True))
            # ... and the 1x1 pointwise convolution changes it.
            blocks.append(conv_bn_relu(in_ch, out_ch, 1, 1, False))

        if self.features_only:
            self.network = nn.CellList(blocks)
        else:
            self.network = nn.SequentialCell(blocks)
            self.fc = nn.Dense(1024, class_num)

    def construct(self, x):
        """mobilenet v1 construct."""
        if self.features_only:
            # Feed the input through the blocks one by one and keep every
            # intermediate output.
            current = x
            features = ()
            for block in self.network:
                current = block(current)
                features = features + (current,)
            return features
        logits = self.network(x)
        # Reduce over axes 2 and 3 before the dense layer.
        logits = P.ReduceMean()(logits, (2, 3))
        logits = self.fc(logits)
        return logits


def mobilenet_v1(class_num=1001):
    """Create a MobileNetV1 network whose classifier has `class_num` outputs."""
    network = MobileNetV1(class_num=class_num)
    return network

+ 1
- 0
examples/natural_robustness/applications/__init__.py View File

@@ -0,0 +1 @@


+ 0
- 0
examples/natural_robustness/applications/classification/__init__.py View File


+ 39
- 0
examples/natural_robustness/applications/classification/defense_example.py View File

@@ -0,0 +1,39 @@
# -*- coding: utf-8 -*-
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Robustness improvement example of trained mobilenet v1."""
import pprint

from examples.natural_robustness.applications.classification.preparation import \
create_model, get_file, image_process_ops, postprocess_func, preprocess_func
from mindarmour.natural_robustness.applications.image_classification import \
Defense

if __name__ == '__main__':
    # Load the trained network and locate the dataset directories.
    model = create_model("mobile_net.ckpt")
    train_data_dir, test_data_dir = get_file()

    # Options of the defense pipeline itself.
    defense_options = {
        "strategies": ["Rotate", 'Brightness', 'Contrast'],
        "certificate_number": 10,
        "augmentation_image_dir": "aug_data",
        "name2index": {'dogs': 0, 'wolves': 1},
        "batch_queue_buffer_size": 2,
        "mutate_print": False,
        "workers": 5,
        "thread_workers": 2,
    }
    defense = Defense(**defense_options)

    # Attach the model together with its inference and training settings.
    defense.set_model(
        model,
        predict_batch_size=100,
        train_func=None,
        operations=image_process_ops(True),
        device_target="Ascend",
        preprocess_func=preprocess_func,
        postprocess_func=postprocess_func,
        epoch=70,
    )

    defense.data_reader.from_directory(test_data_dir, train_data_dir, True)
    result = defense.defense(sample_rate=0.2, sample_number=4, iter_number=14)
    pprint.pprint(result)

+ 44
- 0
examples/natural_robustness/applications/classification/evaluation_example.py View File

@@ -0,0 +1,44 @@
# -*- coding: utf-8 -*-
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Robustness evaluation example of trained mobilenet_v1."""

from examples.natural_robustness.applications.classification.preparation import \
create_model, get_file, postprocess_func, preprocess_func
from mindarmour.natural_robustness.applications.image_classification import \
Evaluation

if __name__ == '__main__':
    # Load the trained network and locate the dataset directories.
    model = create_model("./mobile_net.ckpt")
    train_data_dir, test_data_dir = get_file()

    # Pixel-level strategies followed by the non-pixel ones.
    pixel_strategies = ['GaussianBlur', 'Brightness', 'Contrast',
                        'UniformNoise', 'GaussianNoise', 'SaltAndPepper']
    non_pixel_strategies = ['Rotate', 'Scale', 'Shear_x', 'Shear_y',
                            'Translate_x', 'Translate_y']

    evaluator = Evaluation(
        strategies=pixel_strategies + non_pixel_strategies,
        support_data_number=10,
        name2index={'dogs': 0, 'wolves': 1},
        workers=2,
        thread_workers=2,
        batch_queue_buffer_size=2,
        mutate_print=False,
    )
    evaluator.set_model(
        model,
        predict_batch_size=100,
        preprocess_func=preprocess_func,
        postprocess_func=postprocess_func,
    )
    evaluator.data_reader.from_directory(
        test_data_dir, shuffle=True, target_shape=None)

    result = evaluator.evaluate()
    print(evaluator.evaluate_parameter[0])
    print(result)

+ 93
- 0
examples/natural_robustness/applications/classification/preparation.py View File

@@ -0,0 +1,93 @@
# -*- coding: utf-8 -*-
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Preparation of robustness evaluation and improvement for mobilenet v1."""

import functools
import os
import subprocess

import numpy as np
from mindspore import nn, train
from mindspore.dataset.vision import c_transforms as C

from examples.common.networks.mobilenet_v1.mobilenet_v1 import mobilenet_v1


def get_file():
    """
    Download the Canidae dataset archive and unzip it, unless already present.

    The archive is fetched with ``wget`` into ``./dataset`` and extracted with
    ``unzip`` into ``./dataset/Canidae_data``. Download and extraction are
    skipped when the extraction directory already exists; the download alone
    is skipped when the archive file already exists.

    Returns:
        tuple[str, str], paths of the training directory and the validation
        directory inside the extracted dataset.

    Raises:
        subprocess.CalledProcessError: If ``wget`` exits with a non-zero
            status, or ``unzip`` exits with a status other than 0 or 1.
    """
    cache_dir = "./dataset"
    dataset_url = "http://mindspore-website.obs.cn-north-4.myhuaweicloud.com" \
                  "/notebook/datasets/intermediate/Canidae_data.zip"
    # exist_ok avoids the race of a separate "exists?" check before creating.
    os.makedirs(cache_dir, exist_ok=True)
    file_name = os.path.basename(dataset_url)
    zip_file_path = os.path.join(cache_dir, file_name)
    file_path, _ = os.path.splitext(zip_file_path)
    if not os.path.exists(file_path):
        if not os.path.exists(zip_file_path):
            # Fail loudly here instead of returning directories that were
            # never created because the download did not succeed.
            subprocess.check_call(["wget", "-P", cache_dir, dataset_url])
        unzip_cmd = ["unzip", "-d", file_path, zip_file_path]
        unzip_status = subprocess.call(unzip_cmd)
        # unzip reports 1 for warnings where extraction still completed, so
        # only higher exit codes are treated as failures.
        if unzip_status not in (0, 1):
            raise subprocess.CalledProcessError(unzip_status, unzip_cmd)
    train_dir = os.path.join(cache_dir, "Canidae_data/data/Canidae/train")
    test_dir = os.path.join(cache_dir, "Canidae_data/data/Canidae/val/")
    return train_dir, test_dir


def create_model(ckpt_path=None):
    """
    Create the two-class MobileNet V1 model and load its checkpoint.

    Args:
        ckpt_path: Path of the checkpoint file to load. Default: None.

    Returns:
        The ``train.Model`` wrapping the network, a sparse softmax
        cross-entropy loss, a Momentum optimizer and an Accuracy metric.

    Raises:
        RuntimeError: If `ckpt_path` is None.
    """
    learning_rate = 0.001
    momentum = 0.9
    network = mobilenet_v1(2)

    # Guard clause: a checkpoint is mandatory for this example.
    if ckpt_path is None:
        raise RuntimeError("Model weights not loaded.")
    param_dict = train.load_checkpoint(ckpt_path)
    train.load_param_into_net(network, param_dict)
    print("load success")

    loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    optimizer = nn.Momentum(network.trainable_params(), learning_rate, momentum)
    return train.Model(network, loss_fn, optimizer,
                       metrics={"Accuracy": nn.Accuracy()})


def image_process_ops(is_train):
    """
    Build the list of image processing operations.

    Args:
        is_train: When truthy, the list starts with random crop/decode/resize,
            random horizontal flip and random color adjustment; otherwise it
            starts with a plain resize to 299x299.

    Returns:
        list, the selected leading operations followed by rescale, normalize
        and HWC-to-CHW conversion.
    """
    if is_train:
        leading_ops = [
            C.RandomCropDecodeResize(size=(299, 299),
                                     scale=(0.08, 1.0),
                                     ratio=(0.75, 1.333)),
            C.RandomHorizontalFlip(prob=0.5),
            C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4),
        ]
    else:
        leading_ops = [C.Resize(size=(299, 299))]

    # Steps shared by both modes.
    common_ops = [
        C.Rescale(1.0 / 255.0, 0.0),
        C.Normalize(mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]),
        C.HWC2CHW(),
    ]
    return leading_ops + common_ops


def preprocess_func(data):
    """Apply the operations of ``image_process_ops(False)`` to `data`, in order."""
    processed = data
    for operation in image_process_ops(False):
        processed = operation(processed)
    return processed


def postprocess_func(data):
    """Return the argmax of `data` along its last axis, converted with ``tolist()``."""
    best_indices = np.argmax(data, axis=-1)
    return best_indices.tolist()

+ 1
- 0
examples/natural_robustness/transform/__init__.py View File

@@ -0,0 +1 @@


+ 54
- 0
examples/natural_robustness/transform/time_series_transform_example.py View File

@@ -0,0 +1,54 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Example for natural robustness methods."""
import matplotlib.pyplot as plt
import numpy as np

from mindarmour.natural_robustness.transform.time_series import Miss, Noise


def noise_example():
    """Plot a sine wave and the result of applying ``Noise(0.05)`` to it."""
    sample_points = np.linspace(0, 100, 500)
    clean_signal = np.sin(sample_points)
    noisy_signal = Noise(0.05)(clean_signal)

    # Upper panel: the original series.
    plt.subplot(2, 1, 1)
    plt.plot(sample_points, clean_signal, c="r")
    plt.legend(["origin"])

    # Lower panel: the transformed series.
    plt.subplot(2, 1, 2)
    plt.plot(sample_points, noisy_signal)
    plt.legend(["noise"])

    plt.show()


def miss_example():
    """Scatter-plot a sine wave and the result of applying ``Miss(0.2)`` to it."""
    sample_points = np.linspace(0, 20, 100)
    clean_signal = np.sin(sample_points)
    transformed_signal = Miss(0.2)(clean_signal)

    # Upper panel: the original series.
    plt.subplot(2, 1, 1)
    plt.scatter(sample_points, clean_signal, c="r", s=1)
    plt.legend(["origin"])

    # Lower panel: the transformed series.
    plt.subplot(2, 1, 2)
    plt.scatter(sample_points, transformed_signal, s=1)
    plt.legend(["miss"])

    plt.show()


if __name__ == '__main__':
    # Run the demonstrations one after another.
    for run_example in (noise_example, miss_example):
        run_example()

+ 175
- 0
mindarmour/natural_robustness/README.md View File

@@ -0,0 +1,175 @@
# Robustness Evaluation And Improvement of AI Model Based on Data Augmentation

## Introduction

The robustness of an AI model is its ability to keep its prediction unchanged when the input data undergoes a given transformation (interference) of a given degree. Measuring the robustness of an AI model means finding, for a specific transformation (interference), the range of degrees within which transformed inputs still yield the same prediction. Improving the robustness of an AI model means using these measurement results to dynamically expand the dataset according to certain rules, and then retraining the model on the expanded dataset.

### Environment Requirements

- Hardware(CPU/Ascend/GPU)
- Prepare hardware environment with CPU, Ascend or GPU processor.
- Framework
- MindSpore
- For more information, please check the resources below:
- MindSpore Tutorials
- MindSpore Python API

## Quick Start

### 1. Robustness Evaluation

a. Preparatory work (download the dataset, create a model, define the data preprocessing operation flow)

```python
import functools
import os
import subprocess

import numpy as np
from mindspore import nn, train
from mindspore.dataset.vision import c_transforms as C

from examples.common.networks.mobilenet_v1.mobilenet_v1 import mobilenet_v1


def get_file():
"""Download data and Unzip."""
cache_dir = "./dataset"
dataset_url = "http://mindspore-website.obs.cn-north-4.myhuaweicloud.com" \
"/notebook/datasets/intermediate/Canidae_data.zip"
if not os.path.exists(cache_dir):
os.makedirs(cache_dir)
file_name = os.path.basename(dataset_url)
zip_file_path = os.path.join(cache_dir, file_name)
file_path, _ = os.path.splitext(zip_file_path)
if not os.path.exists(file_path):
if not os.path.exists(zip_file_path):
subprocess.call(["wget", "-P", cache_dir, dataset_url])
subprocess.call(["unzip", "-d", file_path, zip_file_path])
train_dir = os.path.join(cache_dir, "Canidae_data/data/Canidae/train")
test_dir = os.path.join(cache_dir, "Canidae_data/data/Canidae/val/")
return train_dir, test_dir


def create_model(ckpt_path=None):
"""Create model and load checkpoints."""
lr = 0.01
momentum = 0.9
net = mobilenet_v1(2)
if ckpt_path is not None:
load_dict = train.load_checkpoint(ckpt_path)
train.load_param_into_net(net, load_dict)
print("load success")
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
net_opt = nn.Momentum(net.trainable_params(), lr, momentum)
mobilenet = train.Model(net, net_loss, net_opt,
metrics={"Accuracy": nn.Accuracy()})
return mobilenet


def image_process_ops(is_train):
"""image process operations."""
if is_train:
trans = [
C.RandomCropDecodeResize(size=(299, 299),
scale=(0.08, 1.0),
ratio=(0.75, 1.333)),
C.RandomHorizontalFlip(prob=0.5),
C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4)]
else:
trans = [C.Resize(size=(299, 299))]
trans += [C.Rescale(1.0 / 255.0, 0.0),
C.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
C.HWC2CHW()]
return trans


def preprocess_func(data):
"""Preprocess function. Only process single function."""
res = functools.reduce(lambda x, y: y(x), [data] + image_process_ops(False))
return res


def postprocess_func(data):
"""Postprocess function."""
data = np.argmax(data, axis=-1).tolist()
return data
```

b. Evaluation

```python
model = create_model("./mobile_net.ckpt")
train_data_dir, test_data_dir = get_file()
evaluator = Evaluation(
strategies=['Blur', 'Brightness', 'Contrast', 'Noise', 'Rotate', 'Scale',
'Shear_x', 'Shear_y', 'Translate_x', 'Translate_y'],
certificate_number=10,
name2index={'dogs': 0, 'wolves': 1},
mutate_print=False)
evaluator.set_model(model, predict_batch_size=100,
preprocess_func=preprocess_func,
postprocess_func=postprocess_func)
evaluator.data_reader.from_directory(train_data_dir, shuffle=True)
result = evaluator.evaluate()
print(result)
```

c. Results

```python
{...,
'Brightness': {'amplitude': 6.116126731447102,
'boundary': (-5, 3),
'robustness_range': [-3.6561253390316844, 2.460001392415418],
'robustness_rate': 0.7645},
...,
'accuracy': 1.0}
```

The dictionary above is an example output. Its keys are the types of **transformation (interference)** plus **accuracy**. For each interference type the value is itself a dictionary, where **amplitude** is the average magnitude of the transformation, **boundary** is the parameter range of the transformation, **robustness_range** is the measured average robustness interval, and **robustness_rate** is the proportion of the entire parameter range covered by the robustness interval.

### 2. Robustness Improvement

a. Create a model, load the trained weights, set the corresponding compilation method, etc. The model needs to be trained directly through the train method of the model. (The code example is the same as 1.a)

b. Defense

```python
model = create_model("mobile_net.ckpt")
train_data_dir, test_data_dir = get_file()
defense = Defense(ori_image_dir=train_data_dir,
strategies=['Blur', 'Brightness', 'Contrast', 'Noise',
'Rotate', 'Scale',
'Shear_x', 'Shear_y', 'Translate_x',
'Translate_y'],
certificate_number=10,
augmentation_image_dir="aug_data",
name2index={'dogs': 0, 'wolves': 1},
batch_queue_buffer_size=2, mutate_print=False,
workers=2, thread_workers=2)
defense.set_model(model, predict_batch_size=100, train_func=None,
operations=image_process_ops(True),
device_target="CPU", preprocess_func=preprocess_func,
postprocess_func=postprocess_func, epoch=2)
defense.data_reader.from_directory(test_data_dir, train_data_dir, True)
result = defense.defense(
decrease_rate=0.2,
sample_number=4,
iter_number=5)
print(result)
```

c. Results

```python
{...,
'Blur': {'amplitude': [8.74074074074074, 9.88888888888889, 10.0],
'robustness_range': [[-4.370370370370369, 4.370370370370369],[-4.944444444444445, 4.944444444444445], [-5.0, 5.0]],
'robustness_rate': [0.8741, 0.9889, 1.0]},
...,
'accuracy': [0.55, 0.4666666666666667, 0.5166666666666667]}
```

The dictionary above is an example output of robustness improvement. Its keys are the types of **transformation (interference)** plus **accuracy**. For each interference type the value is again a dictionary, but each of its values is a list. The keys have the same meaning as in the **evaluation** part, and each list holds multiple results: in increasing index order, **the values are the results measured with no retraining, after retraining once, after retraining twice**, and so on.

+ 17
- 0
mindarmour/natural_robustness/applications/__init__.py View File

@@ -0,0 +1,17 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Application of robustness evaluation and improve instance.
"""

+ 26
- 0
mindarmour/natural_robustness/applications/image_classification/__init__.py View File

@@ -0,0 +1,26 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This is an implementation of robustness evaluation and improvements ability
of AI model in image classification task.
"""

from .defense import Defense
from .evaluation import Evaluation
from .implementation import Attacks, Equal
from .sample_data_from_directory import DataReader

__all__ = ['Attacks', 'Equal', "DataReader", "Evaluation", "Defense", ]


+ 199
- 0
mindarmour/natural_robustness/applications/image_classification/defense.py View File

@@ -0,0 +1,199 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""An implemented method for AI model defense function for image classification task."""

from .evaluation import Evaluation
from .implementation import Attacks, Equal, ModelAdapter
from .sample_data_from_directory import DataReader
from ...base import CommonDefense


class Defense(CommonDefense):
    """
    Robustness improvement of an AI model for image classification.

    The maximum interference amplitude within which the results inferred by
    the model for a particular data item do not change is tested first. Data
    is then sampled according to the robustness evaluation results, and the
    model is retrained with the sampled data in order to improve its
    robustness.

    Args:
        strategies (List[str]): Mutate strategies, covering pixel-level and
            non-pixel-level transforms. A pixel-level transform changes only
            the input image and leaves additional targets such as masks,
            bounding boxes and key points unchanged. Pixel-level transforms:
            ['Blur', 'Brightness', 'Contrast', 'Noise']. A non-pixel-level
            transform changes the input image together with additional
            targets such as masks, bounding boxes and key points.
            Non-pixel-level transforms: ['Rotate', 'Scale', 'Shear_x',
            'Shear_y', 'Translate_x', 'Translate_y'].
        augmentation_image_dir (str): A folder where the augmentation data is
            temporarily stored.
        certificate_number (int): Number of data items generated to evaluate
            the robustness of the AI model.
        name2index (dict): A dict mapping a class name to its index,
            e.g. {"dog": 0, "cat": 1}.
        workers (int): Number of multiprocess workers.
        thread_workers (int): Number of data batches which will be
            preprocessed.
        batch_queue_buffer_size (int): Size of the batch queue.
        mutate_print (bool): Whether to print logs or not.

    Examples:
        >>> import functools
        >>> import os
        >>> import subprocess
        >>>
        >>> import numpy as np
        >>> from mindspore import nn, train
        >>> from mindspore.dataset.vision import c_transforms as C
        >>> from examples.common.networks.mobilenet_v1.mobilenet_v1 import mobilenet_v1
        >>> def image_process_ops(train):
        ...     if train:
        ...         trans = [
        ...             C.RandomCropDecodeResize(size=(299, 299),
        ...                                      scale=(0.08, 1.0),
        ...                                      ratio=(0.75, 1.333)),
        ...             C.RandomHorizontalFlip(prob=0.5),
        ...             C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4)]
        ...     else:
        ...         trans = [C.Resize(size=(299, 299))]
        ...     trans += [C.Rescale(1.0 / 255.0, 0.0),
        ...               C.Normalize(mean=[0.485, 0.456, 0.406],
        ...                           std=[0.229, 0.224, 0.225]),
        ...               C.HWC2CHW()]
        ...     return trans
        >>> def preprocess_func(data):
        ...     res = functools.reduce(lambda x, y: y(x), [data] + image_process_ops(False))
        ...     return res
        >>> def postprocess_func(data):
        ...     data = np.argmax(data, axis=-1).tolist()
        ...     return data
        >>> # `model` is a trained MindSpore Model created beforehand.
        >>> defense = Defense(strategies=["Blur"],
        ...                   certificate_number=10,
        ...                   augmentation_image_dir="aug_data",
        ...                   name2index={'dogs': 0, 'wolves': 1},
        ...                   batch_queue_buffer_size=2, mutate_print=False,
        ...                   workers=2, thread_workers=2)
        >>> defense.set_model(model, predict_batch_size=100, train_func=None,
        ...                   operations=image_process_ops(True),
        ...                   device_target="CPU", preprocess_func=preprocess_func,
        ...                   postprocess_func=postprocess_func, epoch=2)
        >>> defense.data_reader.from_directory("evaluate_data_dir", "defense_data_dir", True)
        >>> result = defense.defense(
        ...     sample_rate=0.2,
        ...     sample_number=4,
        ...     iter_number=5)
        >>> print(result)

    """

    def __init__(self, strategies, augmentation_image_dir, certificate_number, name2index,
                 workers, thread_workers, batch_queue_buffer_size,
                 mutate_print):
        super().__init__(strategies, augmentation_image_dir)
        self._name2index = name2index
        self._mutate_print = mutate_print
        # Evaluator used inside the robustness improvement loop.
        self._evaluate = Evaluation(strategies, certificate_number, name2index,
                                    workers, thread_workers,
                                    batch_queue_buffer_size, mutate_print)
        # The reader receives both the evaluator and this defense object.
        self._data_reader = DataReader(name2index, self._evaluate, self)

    def set_model(self, model, predict_batch_size, train_func=None,
                  val_data_dir=None, operations=None, preprocess_func=None,
                  postprocess_func="default", device_target="CPU",
                  batch_size=100, epoch=50, repeat_size=2,
                  num_parallel_workers=2):
        """
        Set the model and the functions related to it.

        Compared to the set_model method of Evaluation, the parameters of
        this method contain not only prediction parameters but also training
        parameters. The resulting model adapter is also assigned to the
        internal evaluator.

        Args:
            model (Model): MindSpore model.
            predict_batch_size (int): Batch size for model inference.
            train_func (Callable): A callable object which takes exactly 4
                parameters named 'model', 'aug_data_dir', 'anno_path' and
                'logs'. See the 'train' method for more information about the
                definition of this train function. If you only use the
                evaluation ability or decide to use the default train
                strategy, leave this argument as None. Otherwise define a
                function with all four parameters, even if it only uses
                several of them. Default: None.
            val_data_dir (str): Path to the target directory. It should
                contain one subdirectory per class. Default: None.
            operations (Union[list[TensorOp], list[functions]]): List of
                operations to be applied on the dataset. Operations are
                applied in the order they appear in this list. Default: None.
            preprocess_func (Callable): A callable object which takes exactly
                1 parameter. This function is used to process a single image;
                after that, a bunch of data will be batched and sent to the
                model to get inference results. Default: None, documented as
                returning the inputs unchanged.
            postprocess_func (Callable): A callable object which takes exactly
                1 parameter. This function is used to process the model
                inference results and return a list of image classification
                index results. Default: "default".
            device_target (str): The target device to run, support "Ascend",
                "GPU" and "CPU". If device target is not set, the version of
                MindSpore package is used. Default: "CPU".
            batch_size (int or function): The number of rows each batch is
                created with. An int or a callable object which takes exactly
                1 parameter, BatchInfo. Default: 100.
            epoch (int): Generally, total number of iterations on the data per
                epoch. When dataset_sink_mode is set to true and sink_size>0,
                each epoch sinks sink_size steps on the data instead of the
                total number of iterations. Default: 50.
            repeat_size (int): Number of times the dataset is going to be
                repeated. Default: 2.
            num_parallel_workers (int, optional): Number of threads used to
                process the dataset in parallel. Default: 2.
        """
        adapter_options = {
            "model": model,
            "predict_batch_size": predict_batch_size,
            "name2index": self._name2index,
            "preprocess_func": preprocess_func,
            "postprocess_func": postprocess_func,
            "mutate_print": self._mutate_print,
            "train_func": train_func,
            "val_data_dir": val_data_dir,
            "operations": operations,
            "device_target": device_target,
            "batch_size": batch_size,
            "epoch": epoch,
            "repeat_size": repeat_size,
            "num_parallel_workers": num_parallel_workers,
        }
        self._model_adapter = ModelAdapter(**adapter_options)
        # Hand the same adapter to the evaluator.
        self._evaluate.model_adapter = self._model_adapter

    @property
    def evaluate(self):
        """
        Evaluator instance, used to evaluate the robustness of the AI model
        in the robustness improvement cycle.
        """
        return self._evaluate

    @property
    def data_reader(self):
        """DataReader instance created in the constructor."""
        return self._data_reader

    @property
    def equal_method(self):
        """A new Equal instance, created on every access."""
        return Equal()

    @property
    def attack_obj(self):
        """A new Attacks instance, created on every access."""
        return Attacks()

+ 151
- 0
mindarmour/natural_robustness/applications/image_classification/evaluation.py View File

@@ -0,0 +1,151 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This is an implementation of robustness evaluation and improvements ability
of AI model in image classification task.
"""

from .implementation import Attacks, Equal, ModelAdapter
from .sample_data_from_directory import DataReader
from ...base.common_evaluation import CommonEvaluation
from ...base.abstract_model import AbstractModelAdapter


class Evaluation(CommonEvaluation):
    """
    Evaluate the robustness of an AI model by testing the maximum interference
    amplitude within which the results inferred by the model for a particular
    piece of data do not change.

    Args:
        strategies(Union[list[str], tuple[str]]): Mutate strategies, chosen
            from pix_mutate and none_pix_mutate_types.
            Pixel-level transforms change just the input image and leave any
            additional targets such as masks, bounding boxes, and key points
            unchanged. The list of pixel-level transforms (pix_mutate):
            ['GaussianBlur', 'Brightness', 'Contrast', 'UniformNoise',
            'GaussianNoise', 'SaltAndPepper'].
            None-pixel transforms simultaneously change both the input image
            and additional targets such as masks, bounding boxes, and key
            points. The list of none_pix_mutate_types:
            ['Rotate', 'Scale', 'Shear_x', 'Shear_y', 'Translate_x',
            'Translate_y'].
        support_data_number(int): Number of data that will be generated to
            evaluate AI model robustness.
        name2index(dict): A dict to transfer a specific name into an index,
            e.g. {"dog": 0, "cat": 1}.
        workers(int): Number of multiprocess workers.
        thread_workers(int): Number of data batches which will be
            preprocessed.
        batch_queue_buffer_size(int): Size of the batch queue.
        mutate_print(bool): Whether to print logs or not.

    Examples:
        >>> import functools
        >>> import numpy as np
        >>> import mindspore.dataset.vision.c_transforms as C
        >>> def preprocess_func(data):
        >>>     image_process_ops = [C.Resize(size=(299, 299)),
        >>>                          C.Rescale(1.0 / 255.0, 0.0),
        >>>                          C.Normalize(mean=[0.485, 0.456, 0.406],
        >>>                                      std=[0.229, 0.224, 0.225]),
        >>>                          C.HWC2CHW()]
        >>>     result = functools.reduce(lambda x, y: y(x), [data] + image_process_ops)
        >>>     return result
        >>>
        >>> def postprocess_func(data):
        >>>     data = np.argmax(data, axis=-1).tolist()
        >>>     return data
        >>>
        >>> evaluator = Evaluation(strategies=["Rotate"],
        >>>                        support_data_number=10,
        >>>                        name2index={'dogs': 0, 'wolves': 1},
        >>>                        workers=2,
        >>>                        thread_workers=2,
        >>>                        batch_queue_buffer_size=2,
        >>>                        mutate_print=False)
        >>> evaluator.set_model(model, predict_batch_size=100,
        >>>                     preprocess_func=preprocess_func,
        >>>                     postprocess_func=postprocess_func)
        >>> evaluator.data_reader.from_directory("data_dir", shuffle=True)
        >>> result = evaluator.evaluate()
        >>> print(result)
    """

    def __init__(self, strategies, support_data_number,
                 name2index, workers=2, thread_workers=2,
                 batch_queue_buffer_size=2, mutate_print=False):
        super().__init__(strategies, support_data_number, workers,
                         thread_workers, batch_queue_buffer_size,
                         mutate_print)
        self._name2index = name2index
        # No adapter until set_model() or the model_adapter setter is used.
        self._model_adapter = None
        self._equal = Equal()
        self._attack = Attacks()
        # The reader receives this evaluator so it can hand loaded data back.
        self._data_reader = DataReader(self._name2index, self)

    def set_model(self, model, predict_batch_size=None, preprocess_func=None,
                  postprocess_func=None):
        """
        Set the model and the functions related to it.

        Args:
            model(Model): MindSpore model, which has been trained already.
            predict_batch_size(int): Max batch size for model inference. The
                real batch size in model prediction is several times (at
                least 1) the supporting image number multiplied by the number
                of strategies.
            preprocess_func(Callable): A callable object which takes exactly 1
                parameter. It processes a single image; afterwards a bunch of
                data is batched and sent to the model to get inference
                results. If not provided, data is only batched before
                inference.
            postprocess_func(Callable): A callable object which takes exactly
                1 parameter. It processes the model inference results and
                returns a list of image_classification index results.

        Notes:
            preprocess_func processes a single image while postprocess_func
            processes a batch of images.
        """
        adapter_options = {
            "model": model,
            "predict_batch_size": predict_batch_size,
            "name2index": self._name2index,
            "preprocess_func": preprocess_func,
            "postprocess_func": postprocess_func,
            "mutate_print": self.mutate_print,
        }
        self._model_adapter = ModelAdapter(**adapter_options)

    @property
    def model_adapter(self):
        """
        Model adapter currently attached to this evaluator.

        Raises:
            RuntimeError: If no adapter has been set yet.
        """
        adapter = self._model_adapter
        if adapter is not None:
            return adapter
        raise RuntimeError(
            "Please use 'set_model' method to create ModelAdapter instance.")

    @model_adapter.setter
    def model_adapter(self, value):
        """
        Attach a model adapter.

        Raises:
            TypeError: If value is not an AbstractModelAdapter instance.
        """
        if isinstance(value, AbstractModelAdapter):
            self._model_adapter = value
            return
        raise TypeError(
            "Value for model_adapter must inherent from AbstractModelAdapter")

    @property
    def data_reader(self):
        """DataReader instance created in the constructor."""
        return self._data_reader

    @property
    def equal_method(self):
        """Equal instance created in the constructor."""
        return self._equal

    @property
    def attack_obj(self):
        """Attacks instance created in the constructor."""
        return self._attack

+ 238
- 0
mindarmour/natural_robustness/applications/image_classification/implementation.py View File

@@ -0,0 +1,238 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Attack, equal-method and model-adapter implementations for image_classification task."""

import numpy as np
from mindspore import context, dataset as ds, dtype as mstype, Tensor
from mindspore.dataset.transforms import c_transforms as C
from mindspore.train.callback import CheckpointConfig, LossMonitor, \
ModelCheckpoint

from ...base.abstract_equal import AbstractEqual
from ...base.abstract_model import FunctionalModelAdapter
from ...base.attacks import AbstractAttack
from ...base.attacks.image_attacks import none_pix_mutate_types, pix_mutate
from ...base.data_structure import CommonParameter, get_image_data
from ...utils.tools import typed_property


class Parameter(CommonParameter):
    """
    Per-sample record used in the image_classification task.

    Extends CommonParameter with four attributes declared through
    ``typed_property``.
    """
    # NOTE(review): typed_property presumably restricts each attribute to the
    # listed type(s) on assignment -- confirm against utils.tools.
    image_path = typed_property("image_path", str)
    target_shape = typed_property("target_shape", tuple)
    true_label = typed_property("true_label", (str, int))
    pred_label = typed_property("pred_label", (str, int))

    def get_data(self):
        """Return the result of ``get_image_data`` applied to this object."""
        return get_image_data(self)


class Attacks(AbstractAttack):
    """Attack definition for the image_classification task."""

    @property
    def transform_types(self):
        """
        Build the combined transform table.

        Returns:
            dict, a new dict holding every entry of ``pix_mutate`` overlaid
            with every entry of ``none_pix_mutate_types``.
        """
        pixel_level = dict(pix_mutate)
        combined = dict(pixel_level, **none_pix_mutate_types)
        return combined


class Equal(AbstractEqual):
    """Equality rule for classification results."""

    @staticmethod
    def func(x, y):
        """
        Two classification results are accepted as equal when their string
        representations are identical.
        """
        left, right = str(x), str(y)
        return left == right


def default_train_strategy(device_target, batch_size, epoch,
                           repeat_size, num_parallel_workers,
                           operations, name2index):
    """
    Build the default retraining function.

    Args:
        device_target (str): The target device to run, support "Ascend", "GPU",
            and "CPU". If device target is not set, the version of MindSpore
            package is used.
        batch_size (int or function): The number of rows each batch is created
            with. An int or callable object which takes exactly 1 parameter,
            BatchInfo.
        epoch (int): Generally, total number of iterations on the data per
            epoch. When dataset_sink_mode is set to true and sink_size>0, each
            epoch sinks sink_size steps on the data instead of the total
            number of iterations.
        repeat_size (int): Number of times the dataset is going to be
            repeated.
        num_parallel_workers (int, optional): Number of threads used to
            process the dataset in parallel.
        operations (Union[list[TensorOp], list[functions]]):
            List of operations to be applied on the "image" column.
            Operations are applied in the order they appear in this list.
        name2index(dict): A dict to transfer a specific name into an index,
            e.g. {"dog": 0, "cat": 1}.

    Returns:
        function, a train strategy taking (model, aug_data_dir, anno_path,
        logs) that retrains the model.
    """

    def _wrap(model, aug_data_dir, anno_path, logs):
        """Train ``model`` on the image folder at ``aug_data_dir``."""
        print("annotation path is {}.".format(anno_path))
        context.set_context(mode=context.GRAPH_MODE,
                            device_target=device_target)

        # Column name -> operations, mapped in this order: label, then image.
        column_operations = (
            ("label", [C.TypeCast(mstype.int32)]),
            ("image", operations),
        )
        train_set = ds.ImageFolderDataset(dataset_dir=aug_data_dir,
                                          class_indexing=name2index)
        for column, column_ops in column_operations:
            train_set = train_set.map(
                input_columns=column,
                operations=column_ops,
                num_parallel_workers=num_parallel_workers)
        train_set = (train_set.shuffle(buffer_size=1000)
                     .batch(batch_size, False)
                     .repeat(repeat_size))

        # The iteration counter from logs is embedded in the checkpoint prefix.
        iteration = logs["iter_times"]
        checkpoint_config = CheckpointConfig(save_checkpoint_steps=1875,
                                             keep_checkpoint_max=10)
        checkpoint_callback = ModelCheckpoint(
            prefix="checkpoint_mobile_net_itertimes_{}".format(iteration),
            config=checkpoint_config)
        model.train(epoch=epoch, train_dataset=train_set,
                    callbacks=[checkpoint_callback, LossMonitor()],
                    dataset_sink_mode=False)

    return _wrap


class ModelAdapter(FunctionalModelAdapter):
    """
    Implementation of AbstractModelAdapter. Gathers all model-related
    functions to feed the Evaluation class or the Defense class.

    Args:
        model(Model): A MindSpore model which has been trained already and
            has its weights loaded.
        predict_batch_size(int): Batch size for model inference.
        name2index(dict): A dict to transfer a specific name into an index,
            e.g. {"dog": 0, "cat": 1}.
        preprocess_func(Callable): A callable object which takes exactly 1
            parameter. It processes a single image; afterwards a bunch of data
            is batched and sent to the model to get inference results.
            Default is to return the inputs.
        postprocess_func(Callable): A callable object which takes exactly 1
            parameter. It processes the model inference results and returns a
            list of image_classification index results.
        train_func(Callable): A callable object which takes exactly 4
            parameters: model, aug_data_dir, anno_path, logs. See the 'train'
            method for more information. If you only use the Evaluation
            ability or decide to use the default train strategy, leave this
            as None. Otherwise define a function which takes parameters named
            'model', 'aug_data_dir', 'anno_path' and 'logs'. You may use only
            some of these parameters, but all of them must be declared.
        operations (Union[list[TensorOp], list[functions]]):
            List of operations to be applied on the dataset. Operations are
            applied in the order they appear in this list.
        device_target (str): The target device to run, support "Ascend", "GPU",
            and "CPU".
        batch_size (int or function): The number of rows each batch is created
            with. An int or callable object which takes exactly 1 parameter,
            BatchInfo.
        epoch (int): Generally, total number of iterations on the data per
            epoch.
        repeat_size (int): Number of times the dataset is going to be
            repeated.
        num_parallel_workers (int, optional): Number of threads used to
            process the dataset in parallel.
        mutate_print(bool): Forwarded to the parent adapter.
        val_data_dir(str): Path to a data directory containing one
            subdirectory per class. It is only stored on the instance as
            ``val_data_dir``; this class does not read it.
    """

    def __init__(self, model, predict_batch_size, name2index,
                 preprocess_func=None, postprocess_func=None,
                 train_func=None, operations=None, device_target="CPU",
                 batch_size=100, epoch=50, repeat_size=2,
                 num_parallel_workers=2, mutate_print=True,
                 val_data_dir=None):
        super(ModelAdapter, self).__init__(model, preprocess_func,
                                           postprocess_func,
                                           predict_batch_size,
                                           mutate_print)
        # Accepted (as the last, defaulted parameter) so that callers passing
        # the documented ``val_data_dir`` keyword no longer hit a TypeError.
        self.val_data_dir = val_data_dir
        self.train_func = None
        self.set_train_func(train_func, device_target, batch_size, epoch,
                            repeat_size, num_parallel_workers, operations,
                            name2index)

    def set_train_func(self, train_func, device_target, batch_size, epoch,
                       repeat_size, num_parallel_workers, operations, name2index):
        """
        Set the train function. See the class doc for more information.

        When ``train_func`` is None the function returned by
        ``default_train_strategy`` is used; the remaining arguments are only
        consumed in that case.

        Raises:
            ValueError: If train_func is set but not callable.
        """
        if train_func is None:
            self.train_func = default_train_strategy(device_target,
                                                     batch_size,
                                                     epoch, repeat_size,
                                                     num_parallel_workers,
                                                     operations,
                                                     name2index)
        elif callable(train_func):
            self.train_func = train_func
        else:
            raise ValueError("train_func should be callable.")

    def predict(self, data):
        """
        Model predict function. Receives a list of pre-processed image data
        and uses the AI model to get the inference results.

        Returns:
            numpy.ndarray, model outputs converted with ``asnumpy()``.
        """
        data = np.array(data)
        results = self.model.predict(Tensor(data, dtype=mstype.float32))
        results = results.asnumpy()
        return results

    def train(self, aug_data_dir, anno_path, logs):
        """
        Retrain the model by calling the configured train function with the
        keyword arguments model, aug_data_dir, anno_path and logs.

        Args:
            aug_data_dir(str): Path to the aug_data directory. It contains one
                subdirectory per class.
            anno_path(str): A txt file which contains infos of training data.
            logs: A dict like this:
                {"iter_times": i,
                "accuracy": train_accuracy}.

        Returns:
            Whatever the train function returns.

        Raises:
            TypeError: If the train function raises TypeError, e.g. because
                its definition does not accept the keyword arguments above.
                The original exception is attached as ``__cause__``.
        """
        model = self.model
        try:
            return self.train_func(model=model, aug_data_dir=aug_data_dir,
                                   anno_path=anno_path, logs=logs)
        except TypeError as err:
            msg = "Available parameter names include aug_data_dir, anno_path, " \
                  "logs, please check the function definition."
            # Chain explicitly so the real failure stays visible.
            raise TypeError(msg) from err

+ 214
- 0
mindarmour/natural_robustness/applications/image_classification/sample_data_from_directory.py View File

@@ -0,0 +1,214 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A set of classes or functions of data loading for image_classification task."""
import os
import random
import time

from .implementation import Parameter
from ...base import AbstractDataReader
from ...base.data_structure import ParameterList
from ...base.read_data import SampleData4Defense
from ...utils.tools import copytree


def generate_paths_labels(data_dir, shuffle, name2index):
    """
    Convert a standard image folder into paths and labels.

    Args:
        data_dir(str): Path to the target directory. It should contain one
            subdirectory per class. Every file found in the directory tree of
            each subdirectory is included (no filtering by extension is
            performed here).
        shuffle(bool): Whether to shuffle the data. If False, the paths are
            sorted in alphanumeric order.
        name2index(dict): A dict to transfer a specific name into an index,
            e.g. {"dog": 0, "cat": 1}.

    Returns:
        List[str], a list of image paths.
        List[int], a list of label indexes for the image_classification task.

    Raises:
        KeyError: If the class directory of a file is not in name2index.
    """
    image_paths = []
    for root, _, files in os.walk(data_dir):
        for file_name in files:
            image_paths.append(os.path.normpath(os.path.join(root, file_name)))
    if shuffle:
        random.shuffle(image_paths)
    else:
        # os.walk yields entries in arbitrary order; sort to get the
        # deterministic order promised for shuffle=False.
        image_paths.sort()

    base_dir = os.path.normpath(data_dir)

    def _class_name(path):
        """Return the class directory name that labels ``path``."""
        parent = path.split(os.sep)[-2]
        if parent in name2index:
            return parent
        # File sits deeper than its class directory: use the first path
        # component below data_dir instead of the immediate parent.
        return os.path.relpath(path, base_dir).split(os.sep)[0]

    labels = [name2index[_class_name(path)] for path in image_paths]
    return image_paths, labels


def get_parameters(data_dir, shuffle, target_shape, name2index):
    """
    Turn an image_classification directory, which contains several
    subdirectories, into a ParameterList object. The name of each
    subdirectory is the name of a class.

    Args:
        data_dir(str): Path to the target directory. It should contain one
            subdirectory per class.
        shuffle(bool): Whether to shuffle the data; forwarded to
            ``generate_paths_labels``.
        target_shape(tuple): Image target shape, stored on every Parameter.
        name2index(dict): A dict to transfer a specific name into an index,
            e.g. {"dog": 0, "cat": 1}.

    Returns:
        ParameterList, an object containing one Parameter instance per image.
    """

    def _make_parameter(path, label):
        """Create the Parameter describing one labelled image."""
        item = Parameter()
        item.image_path = path
        item.true_label = label
        item.target_shape = target_shape
        return item

    paths, label_indexes = generate_paths_labels(data_dir, shuffle, name2index)
    collected = ParameterList()
    for pair in zip(paths, label_indexes):
        collected.append(_make_parameter(*pair))
    return collected


class Directory4Classification(SampleData4Defense):
    """
    Used in the defense process. Samples data for retraining the AI model.

    Args:
        name2index(dict): A dict to transfer a specific name into an index,
            e.g. {"dog": 0, "cat": 1}.
        save_dir(str): A path where sampled images are stored.
        workers(int): Number of processes used to sample data.
        transform_types(dict): A dict of functions and parameter boundaries.
    """

    def __init__(self, name2index, save_dir, workers, transform_types):
        # Assigned before the parent constructor runs, as in the original.
        self.name2index = name2index
        super().__init__(save_dir, workers, transform_types)

    def generate_anno_path(self, parameters, results, data_dir, train_data_dir):
        """
        Copy the original data directory into the augmented data directory
        and write an annotation file listing every "path,label" pair.

        Returns:
            str, name of the annotation file that was written.
        """
        copytree(train_data_dir, data_dir, dirs_exist_ok=True)
        # File name: last 8 digits of the current time in microseconds.
        anno_path = str(int(time.time() * 1000000))[-8:] + ".txt"
        paths, label_indexes = generate_paths_labels(data_dir, True,
                                                     self.name2index)
        content = "\n".join("{},{}".format(path, label)
                            for path, label in zip(paths, label_indexes))
        with open(anno_path, "w") as anno_file:
            anno_file.write(content + '\n')
        return anno_path

    def sub_directory_func(self, parameter, data_dir):
        """
        Return the sub directory, named after the class of ``parameter``,
        in which its data is stored. Only used in classification tasks.
        """
        return os.path.join(data_dir, self.index2name[parameter.true_label])

    @property
    def index2name(self):
        """A dict whose keys are class indexes and values are class names."""
        return dict((index, name) for name, index in self.name2index.items())


class DataReader(AbstractDataReader):
    """
    Load data from a directory and send it to an Evaluation object or a
    Defense object.

    Args:
        name2index(dict): A dict to transfer a specific name into an index,
            e.g. {"dog": 0, "cat": 1}.
        evaluator(Evaluation): Evaluation object used to calculate the
            robustness of the AI model.
        defense(Defense): Defense object used to improve the robustness of
            the AI model.
    """

    def __init__(self, name2index, evaluator=None, defense=None):
        super().__init__(evaluator, defense)
        self.name2index = name2index
        self.io_method = None

    def from_directory(self, evaluate_data_dir, defense_data_dir=None,
                       shuffle=True, target_shape=None):
        """Set the dataset used by the Evaluation or Defense object.

        Args:
            evaluate_data_dir(str): Path to the target directory.
                It should contain one subdirectory per class.
                Here is an example:
                ::
                    .
                    └── image_folder_dataset_directory
                        ├── class1
                        │    ├── 000000000001.jpg
                        │    ├── 000000000002.jpg
                        │    ├── ...
                        ├── class2
                        │    ├── 000000000001.jpg
                        │    ├── 000000000002.jpg
                        │    ├── ...
                        ├── class3
                        │    ├── 000000000001.jpg
                        │    ├── 000000000002.jpg
                        │    ├── ...
                        ├── classN
                        ├── ...
                You need to pass 'image_folder_dataset_directory'. This
                parameter is used to create the dataset for evaluation.
            defense_data_dir(str): Directory with the same structure as
                evaluate_data_dir. It is used to create the dataset for
                defense (retraining the AI model). Only used in Defense mode.
            shuffle(bool): Whether to shuffle the dataset or not.
            target_shape(tuple): Shape that image data will be resized to
                before interference.

        Examples:
            >>> evaluator = Evaluation(strategies=["Rotate"],
            >>>                        support_data_number=10,
            >>>                        name2index={'dogs': 0, 'wolves': 1},
            >>>                        mutate_print=False)
            >>> evaluator.data_reader.from_directory("data_dir", shuffle=True)
        """
        self.defense_data_dir = defense_data_dir
        self.evaluate_data_dir = evaluate_data_dir

        def _load(directory):
            """Read one directory with the options given to this call."""
            return get_parameters(directory, shuffle, target_shape,
                                  self.name2index)

        self.evaluator.evaluate_parameter = _load(evaluate_data_dir)
        # Defense data is only loaded when both a directory and a defense
        # object are available.
        if not defense_data_dir or self.defense is None:
            return
        self.defense.improve_parameter = _load(defense_data_dir)

    def write(self, save_dir, workers, transform_types, strategies,
              parameters):
        """
        Internal use. Save newly generated images into save_dir. This method
        is called by the Defense object before retraining the AI model.

        Returns:
            tuple, (data_root, anno_path) as produced by the writer.
        """
        sampler = Directory4Classification(self.name2index, save_dir, workers,
                                           transform_types)
        data_root, anno_path = sampler.write(strategies, parameters,
                                             self.defense_data_dir)
        return data_root, anno_path

+ 32
- 0
mindarmour/natural_robustness/base/__init__.py View File

@@ -0,0 +1,32 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Abstract instance of model robustness evaluation and improve."""

from .abstract_equal import AbstractEqual
from .abstract_model import AbstractModelAdapter
from .data_structure import CommonParameter, ParameterList
from .common_defense import CommonDefense
from .common_evaluation import CommonEvaluation
from .read_data import AbstractDataReader

# Names exported by ``from <this package> import *``; each one is imported
# from a sibling module above.
__all__ = [
    "AbstractEqual",
    "AbstractModelAdapter",
    "CommonParameter",
    "ParameterList",
    "CommonDefense",
    "CommonEvaluation",
    "AbstractDataReader",
]

+ 52
- 0
mindarmour/natural_robustness/base/abstract_equal.py View File

@@ -0,0 +1,52 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Equality method used in model robustness evaluation. The comparison logic
changes with the task; this is a common API for different situations.
"""

import abc

import six


class AbstractEqual(metaclass=abc.ABCMeta):
    """
    Abstract Equal. Users need to subclass it and implement ``func``.

    The surrounding code is Python-3-only (keyword-only arguments, function
    annotations), so the metaclass is declared with the native ``metaclass=``
    keyword instead of the ``six`` compatibility decorator.
    """

    @staticmethod
    def func(x, y) -> bool:
        """
        Check whether two single elements are consistent.

        Subclasses must override this method.

        Raises:
            NotImplementedError: Always, in this base implementation, so that
                a missing override is reported instead of silently yielding
                None for every comparison.
        """
        raise NotImplementedError(
            "Subclasses of AbstractEqual must implement 'func'.")

    @classmethod
    def is_equal(cls, x, y, method="one2one"):
        """
        Check whether the elements are consistent, using ``cls.func``.

        Args:
            x: A single element ("one2one") or an iterable of elements
                ("multi2one").
            y: The element to compare against.
            method(str): "one2one" or "multi2one". Defaults to "one2one",
                matching the default of ``_is_equal``.

        Returns:
            The result of ``cls.func(x, y)`` for "one2one"; a list with one
            such result per item of x for "multi2one".

        Raises:
            KeyError: If method is not one of the supported names.
        """
        return cls._is_equal(x, y, method, func=cls.func)

    @staticmethod
    def _is_equal(x, y, method="one2one", *, func):
        """Dispatch the comparison of x and y according to ``method``."""

        def _compare_each(items, target):
            # "multi2one": compare every item of the collection with target.
            return [func(item, target) for item in items]

        dispatch = {
            "one2one": func,
            "multi2one": _compare_each,
        }
        return dispatch[method](x, y)

+ 179
- 0
mindarmour/natural_robustness/base/abstract_model.py View File

@@ -0,0 +1,179 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Abstract model adapter. It is a common API for models which have different
input or output types, different ways of preprocessing data, and different
ways of running inference or training. By implementing AbstractModelAdapter
in different ways, this framework is able to adapt to different
situations.
"""

import abc
import time
from typing import List

import six
import numpy as np

from ..utils.tools import typed_property

# Module-level accumulators, updated through ``global`` by
# AbstractModelAdapter.predict_server: time spent waiting on the input queue
# and time spent in model prediction (differences of time.time() values).
wait_time, predict_time = 0, 0


def default_preprocess_func(data):
    """Default preprocess func: return ``data`` unchanged."""
    return data


def default_postprocess_func(model_prediction: np.ndarray):
    """Default postprocess func: return ``model_prediction`` unchanged."""
    return model_prediction


class AbstractModelAdapter(metaclass=abc.ABCMeta):
    """
    Abstract instance of adapter for AI models.

    Args:
        model(Model): Any AI model written in Python.
        predict_batch_size(int): Batch size used in inference.
        mutate_print(bool): Whether to print logs or not.
    """
    mutate_print = typed_property("mutate_print", bool)

    def __init__(self, model, predict_batch_size, mutate_print=False):
        self.model = model
        self.predict_batch_size = predict_batch_size
        self.mutate_print = mutate_print

    def predict(self, data: List[np.ndarray]):
        """
        AI model inference API. Calls ``self.model.predict(data)``.

        Raises:
            RuntimeError: If the model call raises any exception; the
                original exception is attached as ``__cause__``.
        """
        try:
            res = self.model.predict(data)
            return res
        except Exception as err:
            # Chain explicitly so the model's own error stays visible.
            raise RuntimeError(
                "Your model get {}, witch does not match your model,"
                "you need to overwrite this method,"
                " and write match ones.".format(data)) from err

    def train(self, aug_data_dir: str, anno_path: str, logs: dict):
        """
        In this method, the user needs to implement a training method for the
        model, including but not limited to data transformation, data
        processing, model training, and model parameter file saving.

        This base implementation does nothing.

        Args:
            aug_data_dir(str): Path to the aug_data directory. It contains
                one subdirectory per class.
            anno_path(str): A txt file which contains infos of training data.
            logs: A dict like this:
                {"iter_times": i,
                "accuracy": train_accuracy}.
        """

    def predict_server(self, data_batch_queue, predict_res_queue):
        """
        Use the AI model to run inference in a loop. Input comes from
        data_batch_queue; outputs are sent to predict_res_queue.

        Each queue item is either the string "end", which is forwarded to
        predict_res_queue and stops the loop, or a pair
        (parameter_batch, processed_data), whose prediction is put on
        predict_res_queue as (parameter_batch, result).
        """
        global wait_time, predict_time
        t0, t1, t2 = 0, 0, 0
        # NOTE(review): timing logs are emitted when mutate_print is False,
        # which looks inverted relative to the documented meaning of the
        # flag -- confirm intent before changing.
        while True:
            if not self.mutate_print:
                t0 = time.time()
            data = data_batch_queue.get()
            if not self.mutate_print:
                t1 = time.time()
                wait_time += t1 - t0
                print(" wait to get data cost {}s, "
                      "current number of batches in queue is "
                      "{}.".format(t1 - t0, data_batch_queue.qsize()))
            # Type check first so an array-like payload is never compared
            # element-wise against the sentinel string.
            if isinstance(data, str) and data == "end":
                predict_res_queue.put("end")
                if not self.mutate_print:
                    print("************************************************")
                    print("**model wait data cost {}s totally, model predict"
                          " cost {}s.".format(wait_time, predict_time))
                    print("************************************************")
                break
            else:
                parameter_batch, processed_image_res = data
                res = self.predict(processed_image_res)
                # Drop the reference to the input batch once it is consumed.
                del processed_image_res
                if not self.mutate_print:
                    t2 = time.time()
                    predict_time += t2 - t1
                    print("model prediction cost {}s".format(t2 - t1))
                predict_res_queue.put((parameter_batch, res))


class FunctionalModelAdapter(AbstractModelAdapter):
    """
    Functional-API adapter for AI models: the user defines the needed
    functions first and passes them to an instance of this class.

    Args:
        model(Model): Any AI model written in Python.
        preprocess_func(Callable): A callable object which takes exactly 1
            parameter. It processes a single piece of data; afterwards a
            bunch of data is batched and sent to the model to get inference
            results. If None, ``default_preprocess_func`` is used.
        postprocess_func(Callable): A callable object which takes exactly 1
            parameter. It processes the model inference results and returns a
            list of predictions. If None, ``default_postprocess_func`` is
            used.
        predict_batch_size(int): Batch size used in inference.
        mutate_print(bool): Whether to print logs or not.
    """
    mutate_print = typed_property("mutate_print", bool)

    def __init__(self, model, preprocess_func, postprocess_func,
                 predict_batch_size, mutate_print=False):
        super().__init__(model, predict_batch_size, mutate_print)
        # Fall back to the module defaults when a hook is not supplied.
        self.preprocess = (default_preprocess_func
                           if preprocess_func is None else preprocess_func)
        self.postprocess = (default_postprocess_func
                            if postprocess_func is None else postprocess_func)


class ClassicModelAdapter(AbstractModelAdapter):
    """
    Adapter for AI models in which the preprocess and postprocess steps are
    declared as abstract class methods to be overridden by a subclass.
    """

    @classmethod
    @abc.abstractmethod
    def preprocess(cls, data):
        """
        This method is used to process a single piece of data; afterwards a
        bunch of data will be batched and sent to the model to get inference
        results. This base implementation returns the input unchanged.
        """
        return data

    @classmethod
    @abc.abstractmethod
    def postprocess(cls, data):
        """
        This method is used to process AI model inference results and return
        a list of predictions. This base implementation returns the input
        unchanged.
        """
        return data

+ 45
- 0
mindarmour/natural_robustness/base/attacks/__init__.py View File

@@ -0,0 +1,45 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This module provides several methods to interfere with data."""

import abc


class AbstractAttack:
    """
    Abstract Attack.

    Subclasses provide ``transform_types``: a dict mapping each transform
    name to a dict that contains at least a "boundary" entry holding a
    (lower, upper) pair.
    """

    def __init__(self):
        self._transform_types = None
        self._transform_method_boundary = None
        self._radius_boundary_amplitude = None

    @property
    @abc.abstractmethod
    def transform_types(self):
        """Data mutate methods. To be overridden by subclasses."""

    @transform_types.setter
    def transform_types(self, value):
        self._transform_types = value

    @property
    def transform_method_boundary(self):
        """Dict mapping each transform name to its "boundary" entry."""
        return {name: spec["boundary"] for name, spec
                in self.transform_types.items()}

    @property
    def radius_boundary_amplitude(self):
        """
        Dict mapping each transform name to the width of its boundary
        (upper bound minus lower bound).
        """
        # Read the computed property: the private
        # ``_transform_method_boundary`` attribute is never populated and
        # stays None, so iterating it raised AttributeError.
        return {name: bounds[1] - bounds[0] for name, bounds
                in self.transform_method_boundary.items()}

+ 718
- 0
mindarmour/natural_robustness/base/attacks/image_attacks.py View File

@@ -0,0 +1,718 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Provide several methods to interference image data."""

import math

import cv2
import numpy as np
from mindspore.dataset.vision.py_transforms_util import hwc_to_chw, is_numpy, \
to_pil
from PIL import Image, ImageEnhance

from mindarmour.natural_robustness.transform.image.blur import GaussianBlur
from mindarmour.natural_robustness.transform.image.corruption import \
GaussianNoise, NaturalNoise, SaltAndPepperNoise, UniformNoise
from mindarmour.utils._check_param import check_numpy_param, \
check_param_in_range, check_param_multi_types


def chw_to_hwc(img):
    """
    Reorder the axes of an image from (C, H, W) to (H, W, C).

    Args:
        img (numpy.ndarray): Image to be converted.

    Returns:
        numpy.ndarray, a copy of the image with its axes transposed.

    Raises:
        TypeError: If `img` is not accepted by `is_numpy`.
    """
    if not is_numpy(img):
        raise TypeError('img should be numpy.ndarray. Got {}'.format(type(img)))
    return img.transpose(1, 2, 0).copy()


def is_hwc(img):
    """
    Check if the input image is shape (H, W, C).

    The image is considered (H, W, C) when it has exactly three dimensions,
    the last one equals 3 and the first two are both greater than 3.

    Args:
        img (numpy.ndarray): Image to be checked.

    Returns:
        bool, True if input is shape (H, W, C).

    Raises:
        TypeError: If `img` is not accepted by `is_numpy`.
    """
    if not is_numpy(img):
        raise TypeError('img should be numpy.ndarray. Got {}'.format(type(img)))
    img_shape = np.shape(img)
    # Arrays that are not 3-D (e.g. a 2-D grayscale image) cannot be
    # (H, W, C); answer False instead of failing with IndexError below.
    if len(img_shape) != 3:
        return False
    return img_shape[2] == 3 and img_shape[1] > 3 and img_shape[0] > 3


def is_chw(img):
    """
    Check if the input image is shape (C, H, W).

    The image is considered (C, H, W) when it has exactly three dimensions,
    the first one equals 3 and the last two are both greater than 3.

    Args:
        img (numpy.ndarray): Image to be checked.

    Returns:
        bool, True if input is shape (C, H, W).

    Raises:
        TypeError: If `img` is not accepted by `is_numpy`.
    """
    if not is_numpy(img):
        raise TypeError('img should be numpy.ndarray. Got {}'.format(type(img)))
    img_shape = np.shape(img)
    # Arrays that are not 3-D cannot be (C, H, W); answer False instead of
    # failing with IndexError on the index accesses below.
    if len(img_shape) != 3:
        return False
    return img_shape[0] == 3 and img_shape[1] > 3 and img_shape[2] > 3


def is_rgb(img):
    """
    Tell whether the input image looks like a three-channel (RGB) image.

    An image qualifies when it has three dimensions and either its first
    or its last dimension equals 3.

    Args:
        img (numpy.ndarray): Image to be checked.

    Returns:
        bool, True if input is RGB.

    Raises:
        TypeError: If `img` is not accepted by `is_numpy`.
    """
    if not is_numpy(img):
        raise TypeError('img should be numpy.ndarray. Got {}'.format(type(img)))
    dims = np.shape(img)
    return len(dims) == 3 and (dims[0] == 3 or dims[2] == 3)


def is_normalized(img):
    """
    Tell whether every value of the input image lies between 0 and 1.

    Args:
        img (numpy.ndarray): Image to be checked.

    Returns:
        bool, True if the minimum is at least 0 and the maximum at most 1.

    Raises:
        TypeError: If `img` is not accepted by `is_numpy`.
    """
    if not is_numpy(img):
        raise TypeError('img should be Numpy array. Got {}'.format(type(img)))
    lowest, highest = np.min(img), np.max(img)
    # Convert to a plain bool so callers keep receiving True / False.
    return bool(lowest >= 0 and highest <= 1)


class ImageTransform:
    """
    Common base of the image transform classes.

    It offers two helpers: `_check` converts an input image to a uint8
    array that the transforms operate on, and `_original_format` undoes
    that conversion on the transformed result.
    """

    def __init__(self):
        pass

    def _check(self, image):
        """
        Convert the image to the working format and record what was changed.

        An RGB image of shape (C, H, W) is transposed to (H, W, C). A
        three-dimensional image that is not RGB is reduced to its first
        slice. An image whose values are all within [0, 1] is multiplied
        by 255. The result is cast to uint8.

        Args:
            image (numpy.ndarray): Image to be checked.

        Returns:
            tuple, (rgb, chw, normalized, gray3dim, image): the four flags
            describe the original format, the last item is the uint8 image.
        """
        rgb = is_rgb(image)
        chw = False
        gray3dim = False
        normalized = is_normalized(image)
        if rgb:
            chw = is_chw(image)
            if chw:
                image = chw_to_hwc(image)
        elif len(np.shape(image)) == 3:
            gray3dim = True
            image = image[0]
        if normalized:
            image = image * 255
        return rgb, chw, normalized, gray3dim, np.uint8(image)

    def _original_format(self, image, chw, normalized, gray3dim):
        """
        Bring a transformed image back to the format recorded by `_check`.

        Args:
            image: Transformed image; converted with `np.array` when
                `is_numpy` rejects it.
            chw (bool): Whether the original layout was (C, H, W).
            normalized (bool): Whether the original values were in [0, 1].
            gray3dim (bool): Whether a leading axis has to be restored.

        Returns:
            numpy.ndarray, the image in its original format.
        """
        restored = image if is_numpy(image) else np.array(image)
        if chw:
            restored = hwc_to_chw(restored)
        if normalized:
            restored = restored / 255
        if gray3dim:
            restored = np.expand_dims(restored, 0)
        return restored

    def __call__(self, image):
        pass


class Contrast(ImageTransform):
    """
    Change the contrast of an image.

    Args:
        factor (Union[float, int]): Contrast control passed to
            `ImageEnhance.Contrast.enhance`. If 1.0, gives the original
            image. If 0, gives a gray image. Default: 1.
    """

    def __init__(self, factor=1):
        super().__init__()
        self.set_params(factor)

    def set_params(self, factor=1, auto_param=False):
        """
        Set contrast parameters.

        Args:
            factor (Union[float, int]): Contrast control. If 1.0 gives the
                original image. If 0 gives a gray image. Default: 1.
            auto_param (bool): If True, `factor` is ignored and a value is
                drawn uniformly from (-5, 5). Default: False.
        """
        if not auto_param:
            self.factor = check_param_multi_types('factor', factor,
                                                  [int, float])
            return
        self.factor = np.random.uniform(-5, 5)

    def __call__(self, image):
        """
        Transform the image.

        Args:
            image (numpy.ndarray): Original image to be transformed.

        Returns:
            numpy.ndarray, transformed image with the dtype of the input.
        """
        image = check_numpy_param('image', image)
        ori_dtype = image.dtype
        _, chw, normalized, gray3dim, image = self._check(image)
        enhancer = ImageEnhance.Contrast(to_pil(image))
        enhanced = enhancer.enhance(self.factor)
        restored = self._original_format(enhanced, chw, normalized, gray3dim)
        return restored.astype(ori_dtype)


class Brightness(ImageTransform):
    """
    Change the brightness of an image.

    Args:
        factor (Union[float, int]): Brightness control passed to
            `ImageEnhance.Brightness.enhance`. If 1.0 gives the original
            image. If 0 gives a black image. Default: 1.
    """

    def __init__(self, factor=1):
        super().__init__()
        self.set_params(factor)

    def set_params(self, factor=1, auto_param=False):
        """
        Set brightness parameters.

        Args:
            factor (Union[float, int]): Brightness control. If 1 gives the
                original image. If 0 gives a black image. Default: 1.
            auto_param (bool): If True, `factor` is ignored and a value is
                drawn uniformly from (0, 5). Default: False.
        """
        if not auto_param:
            self.factor = check_param_multi_types('factor', factor,
                                                  [int, float])
            return
        self.factor = np.random.uniform(0, 5)

    def __call__(self, image):
        """
        Transform the image.

        Args:
            image (numpy.ndarray): Original image to be transformed.

        Returns:
            numpy.ndarray, transformed image with the dtype of the input.
        """
        image = check_numpy_param('image', image)
        ori_dtype = image.dtype
        _, chw, normalized, gray3dim, image = self._check(image)
        enhancer = ImageEnhance.Brightness(to_pil(image))
        enhanced = enhancer.enhance(self.factor)
        restored = self._original_format(enhanced, chw, normalized, gray3dim)
        return restored.astype(ori_dtype)


class Translate(ImageTransform):
    """
    Translate an image.

    Args:
        x_bias (Union[int, float]): X-direction translation, expressed as a
            fraction of the image width. Default: 0.
        y_bias (Union[int, float]): Y-direction translation, expressed as a
            fraction of the image height. Default: 0.
    """

    def __init__(self, x_bias=0, y_bias=0):
        super(Translate, self).__init__()
        self.set_params(x_bias, y_bias)

    def set_params(self, x_bias=0, y_bias=0, auto_param=False):
        """
        Set translate parameters.

        Args:
            x_bias (Union[float, int]): X-direction translation, and x_bias
                should be in range of (-1, 1). Default: 0.
            y_bias (Union[float, int]): Y-direction translation, and y_bias
                should be in range of (-1, 1). Default: 0.
            auto_param (bool): If True, both biases are drawn uniformly
                from (-0.3, 0.3); the given values are still range-checked.
                Default: False.
        """
        x_bias = check_param_in_range('x_bias', x_bias, -1, 1)
        y_bias = check_param_in_range('y_bias', y_bias, -1, 1)
        self.auto_param = auto_param
        if auto_param:
            self.x_bias = np.random.uniform(-0.3, 0.3)
            self.y_bias = np.random.uniform(-0.3, 0.3)
        else:
            self.x_bias = check_param_multi_types('x_bias', x_bias,
                                                  [int, float])
            self.y_bias = check_param_multi_types('y_bias', y_bias,
                                                  [int, float])

    def __call__(self, image):
        """
        Transform the image.

        Args:
            image(numpy.ndarray): Original image to be transformed.

        Returns:
            numpy.ndarray, transformed image with the dtype of the input.
        """
        image = check_numpy_param('image', image)
        ori_dtype = image.dtype
        _, chw, normalized, gray3dim, image = self._check(image)
        img = to_pil(image)
        image_shape = np.shape(image)
        # Keep the pixel offsets in locals. Writing them back to
        # self.x_bias / self.y_bias (as was done before) turned the stored
        # fractions into pixel counts, so every further call on the same
        # object multiplied the offset by the image size again.
        x_offset = image_shape[1] * self.x_bias
        y_offset = image_shape[0] * self.y_bias
        trans_image = img.transform(img.size, Image.AFFINE,
                                    (1, 0, x_offset, 0, 1, y_offset))
        trans_image = self._original_format(trans_image, chw, normalized,
                                            gray3dim)
        return trans_image.astype(ori_dtype)


class Scale(ImageTransform):
    """
    Scale an image around its center.

    Args:
        factor_x (Union[float, int]): Rescale in X-direction, x=factor_x*x.
            Default: 1.
        factor_y (Union[float, int]): Rescale in Y-direction, y=factor_y*y.
            Default: 1.
    """

    def __init__(self, factor_x=1, factor_y=1):
        super().__init__()
        self.set_params(factor_x, factor_y)

    def set_params(self, factor_x=1, factor_y=1, auto_param=False):
        """
        Set scale parameters.

        Args:
            factor_x (Union[float, int]): Rescale in X-direction,
                x=factor_x*x. Default: 1.
            factor_y (Union[float, int]): Rescale in Y-direction,
                y=factor_y*y. Default: 1.
            auto_param (bool): If True, the given factors are ignored and
                both are drawn uniformly from (0.7, 3). Default: False.
        """
        if not auto_param:
            self.factor_x = check_param_multi_types('factor_x', factor_x,
                                                    [int, float])
            self.factor_y = check_param_multi_types('factor_y', factor_y,
                                                    [int, float])
            return
        self.factor_x = np.random.uniform(0.7, 3)
        self.factor_y = np.random.uniform(0.7, 3)

    def __call__(self, image):
        """
        Transform the image.

        Args:
            image(numpy.ndarray): Original image to be transformed.

        Returns:
            numpy.ndarray, transformed image with the dtype of the input.
        """
        image = check_numpy_param('image', image)
        ori_dtype = image.dtype
        rgb, chw, normalized, gray3dim, image = self._check(image)
        shape = np.shape(image)
        if rgb:
            height, width, _ = shape
        else:
            height, width = shape
        # Offsets that keep the image center fixed under the scaling.
        shift_x = width / 2 * (1 - self.factor_x)
        shift_y = height / 2 * (1 - self.factor_y)
        pil_image = to_pil(image)
        coefficients = (self.factor_x, 0, shift_x,
                        0, self.factor_y, shift_y)
        scaled = pil_image.transform(pil_image.size, Image.AFFINE,
                                     coefficients)
        restored = self._original_format(scaled, chw, normalized, gray3dim)
        return restored.astype(ori_dtype)


class Shear(ImageTransform):
    """
    Shear an image, for each pixel (x, y) in the sheared image, the new value is
    taken from a position (x+factor_x*y, factor_y*x+y) in the origin image. Then
    the sheared image will be rescaled to fit original size.

    Args:
        factor_x (Union[float, int]): Shear factor of horizontal direction.
            Default: 0.
        factor_y (Union[float, int]): Shear factor of vertical direction.
            Default: 0.

    Raises:
        ValueError: If both `factor_x` and `factor_y` are non-zero.
    """

    def __init__(self, factor_x=0, factor_y=0):
        super(Shear, self).__init__()
        self.set_params(factor_x, factor_y)

    def set_params(self, factor_x=0, factor_y=0, auto_param=False):
        """
        Set shear parameters.

        Args:
            factor_x (Union[float, int]): Shear factor of horizontal direction.
                Default: 0.
            factor_y (Union[float, int]): Shear factor of vertical direction.
                Default: 0.
            auto_param (bool): If True, one direction is picked at random
                and its factor is drawn uniformly from (-2, 2) while the
                other is set to 0. Default: False.

        Raises:
            ValueError: If both `factor_x` and `factor_y` are non-zero.
        """
        if factor_x != 0 and factor_y != 0:
            msg = 'At least one of factor_x and factor_y is zero.'
            raise ValueError(msg)
        if auto_param:
            if np.random.uniform(-1, 1) > 0:
                self.factor_x = np.random.uniform(-2, 2)
                self.factor_y = 0
            else:
                self.factor_x = 0
                self.factor_y = np.random.uniform(-2, 2)
        else:
            # Report the actual argument name so a type error points at
            # the offending parameter (both used to be reported as 'factor').
            self.factor_x = check_param_multi_types('factor_x', factor_x,
                                                    [int, float])
            self.factor_y = check_param_multi_types('factor_y', factor_y,
                                                    [int, float])

    def __call__(self, image):
        """
        Transform the image.

        Args:
            image(numpy.ndarray): Original image to be transformed.

        Returns:
            numpy.ndarray, transformed image with the dtype of the input.
        """
        image = check_numpy_param('image', image)
        ori_dtype = image.dtype
        rgb, chw, normalized, gray3dim, image = self._check(image)
        img = to_pil(image)
        if rgb:
            h, w, _ = np.shape(image)
        else:
            h, w = np.shape(image)
        if self.factor_x != 0:
            # Horizontal extent of the sheared image gives the factor
            # needed to rescale it back to the original width.
            boarder_x = [0, -w, -self.factor_x * h, -w - self.factor_x * h]
            min_x = min(boarder_x)
            max_x = max(boarder_x)
            scale = (max_x - min_x) / w
            move_x_cen = (w - scale * w - scale * h * self.factor_x) / 2
            move_y_cen = h * (1 - scale) / 2
        else:
            # Same computation for the vertical direction.
            boarder_y = [0, -h, -self.factor_y * w, -h - self.factor_y * w]
            min_y = min(boarder_y)
            max_y = max(boarder_y)
            scale = (max_y - min_y) / h
            move_y_cen = (h - scale * h - scale * w * self.factor_y) / 2
            move_x_cen = w * (1 - scale) / 2
        trans_image = img.transform(img.size, Image.AFFINE,
                                    (scale, scale * self.factor_x, move_x_cen,
                                     scale * self.factor_y, scale, move_y_cen))
        trans_image = self._original_format(trans_image, chw, normalized,
                                            gray3dim)
        return trans_image.astype(ori_dtype)


class Rotate(ImageTransform):
    """
    Rotate an image counter clockwise around its center.

    Args:
        angle(Union[float, int]): Degrees counter clockwise. Default: 0.
    """

    def __init__(self, angle=0):
        super().__init__()
        self.set_params(angle)

    def set_params(self, angle=0, auto_param=False):
        """
        Set rotate parameters.

        Args:
            angle(Union[float, int]): Degrees counter clockwise. Default: 0.
            auto_param (bool): If True, `angle` is ignored and a value is
                drawn uniformly from (0, 360). Default: False.
        """
        if not auto_param:
            self.angle = check_param_multi_types('angle', angle, [int, float])
            return
        self.angle = np.random.uniform(0, 360)

    def __call__(self, image):
        """
        Transform the image.

        Args:
            image(numpy.ndarray): Original image to be transformed.

        Returns:
            numpy.ndarray, transformed image with the dtype of the input.
        """
        image = check_numpy_param('image', image)
        ori_dtype = image.dtype
        _, chw, normalized, gray3dim, image = self._check(image)
        # expand=False keeps the output the same size as the input.
        rotated = to_pil(image).rotate(self.angle, expand=False)
        restored = self._original_format(rotated, chw, normalized, gray3dim)
        return restored.astype(ori_dtype)


def rotate(degree):
    """
    Build a Rotate transform for the given angle.

    Args:
        degree(Union[float, int]): Degrees counter clockwise.

    Returns:
        Rotate, transform created with `degree` as its angle.
    """
    transform = Rotate(degree)
    return transform


def brightness(factor):
    """
    Build a Brightness transform from an exponent.

    Args:
        factor (Union[float, int]): Exponent of the brightness factor; the
            transform is created with math.exp(factor). If 0 gives the
            original image. If positive value gives a brighter image.
            If negative value gives the darker image.

    Returns:
        Brightness, transform created with factor math.exp(factor).
    """
    brightness_factor = math.exp(factor)
    return Brightness(brightness_factor)


def gaussian_blur(factor):
    """
    Build a transform that blurs the image using a Gaussian blur filter.

    Args:
        factor (int): Size of gaussian kernel. It is truncated to an
            integer and its absolute value is used.

    Returns:
        Callable, a GaussianBlur transform, or a function returning a copy
        of its input when the resulting kernel size is 0.
    """
    kernel_size = abs(int(factor))
    if kernel_size != 0:
        return GaussianBlur(kernel_size)
    # A kernel of size 0 means "no blur": hand back an unchanged copy.
    return lambda x: np.copy(x)


def uniform_noise(factor):
    """
    Build a transform that adds uniform noise to an image.

    Args:
        factor (float): Noise density, the proportion of noise points per
            unit pixel area. Suggested value range in [0.001, 0.15]. Its
            absolute value is used.

    Returns:
        UniformNoise, transform created with abs(factor).
    """
    density = abs(factor)
    return UniformNoise(density)


def gaussian_noise(factor):
    """
    Build a transform that adds gaussian noise to an image.

    Args:
        factor (float): Noise density, the proportion of noise points per
            unit pixel area. Suggested value range in [0.001, 0.15]. Its
            absolute value is used.

    Returns:
        Callable, a GaussianNoise transform, or a function returning a copy
        of its input when `factor` equals 0.
    """
    if factor != 0:
        return GaussianNoise(abs(factor))
    # Zero density means "no noise": hand back an unchanged copy.
    return lambda x: np.copy(x)


def salt_and_pepper_noise(factor):
    """
    Build a transform that adds salt and pepper noise to an image.

    Args:
        factor (float): Noise density, the proportion of noise points per
            unit pixel area. Suggested value range in [0.001, 0.15]. Its
            absolute value is used.

    Returns:
        SaltAndPepperNoise, transform created with abs(factor).
    """
    density = abs(factor)
    return SaltAndPepperNoise(density)


def nature_noise(factor):
    """
    Build a transform that adds natural noise to an image.

    Args:
        factor (float): Noise density, the proportion of noise blocks per
            unit pixel area. Its absolute value is used.

    Returns:
        NaturalNoise, transform created with abs(factor).
    """
    density = abs(factor)
    return NaturalNoise(density)


def contrast(factor):
    """
    Build a Contrast transform from an exponent.

    Args:
        factor (Union[float, int]): Exponent of the contrast factor; the
            transform is created with math.exp(factor), so 0 gives the
            original image.

    Returns:
        Contrast, transform created with factor math.exp(factor).
    """
    contrast_factor = math.exp(factor)
    return Contrast(contrast_factor)


def scale(factor):
    """
    Build a Scale transform that uses the same factor in both directions.

    Args:
        factor (Union[float, int]): Exponent of the scale factor; the
            transform is created with math.exp(-factor) as both factor_x
            and factor_y, so 0 gives the original image.

    Returns:
        Scale, transform created with math.exp(-factor) for both axes.
    """
    factor_x = math.exp(-factor)
    factor_y = math.exp(-factor)
    return Scale(factor_x, factor_y)


def shear_x(factor):
    """
    Build a horizontal Shear transform. For each pixel (x, y) in the sheared
    image, the new value is taken from a position (x+factor*y, y) in the
    origin image. Then the sheared image will be rescaled to fit original
    size.

    Args:
        factor(Union[float, int]): Shear factor of horizontal direction.

    Returns:
        Shear, transform created with factor_x=factor and factor_y=0.
    """
    transform = Shear(factor, 0)
    return transform


def shear_y(factor):
    """
    Build a vertical Shear transform. For each pixel (x, y) in the sheared
    image, the new value is taken from a position (x, factor*x+y) in the
    origin image. Then the sheared image will be rescaled to fit original
    size.

    Args:
        factor(Union[float, int]): Shear factor of vertical direction.

    Returns:
        Shear, transform created with factor_x=0 and factor_y=factor.
    """
    transform = Shear(0, factor)
    return transform


def translate_x(factor):
    """
    Build a Translate transform acting in the X-direction only.

    Args:
        factor(Union[int, float]): X-direction translation, used as the
            x_bias of Translate.

    Returns:
        Translate, transform created with x_bias=factor and y_bias=0.
    """
    transform = Translate(factor, 0)
    return transform


def translate_y(factor):
    """
    Build a Translate transform acting in the Y-direction only.

    Args:
        factor(Union[int, float]): Y-direction translation, used as the
            y_bias of Translate.

    Returns:
        Translate, transform created with x_bias=0 and y_bias=factor.
    """
    transform = Translate(0, factor)
    return transform


# Registry of the blur / brightness / contrast / noise factories. Each entry
# maps a strategy name to the factory ('func') and the interval of factors
# ('boundary') it is meant to be called with.
pix_mutate = {
    'GaussianBlur': dict(func=gaussian_blur, boundary=(-30, 30)),
    'Brightness': dict(func=brightness, boundary=(-2, 1.5)),
    'Contrast': dict(func=contrast, boundary=(-3, 3)),
    'UniformNoise': dict(func=uniform_noise, boundary=(-1.5, 1.5)),
    'GaussianNoise': dict(func=gaussian_noise, boundary=(-1, 1)),
    'SaltAndPepper': dict(func=salt_and_pepper_noise, boundary=(-0.5, 0.5)),
}

# Registry of the geometric factories (rotate, scale, shear, translate),
# with the same entry layout as above.
none_pix_mutate_types = {
    'Rotate': dict(func=rotate, boundary=(-180, 180)),
    'Scale': dict(func=scale, boundary=(-1.5, 1.5)),
    'Shear_x': dict(func=shear_x, boundary=(-2, 2)),
    'Shear_y': dict(func=shear_y, boundary=(-2, 2)),
    'Translate_x': dict(func=translate_x, boundary=(-0.5, 0.5)),
    'Translate_y': dict(func=translate_y, boundary=(-0.5, 0.5)),
}

if __name__ == '__main__':
    # Demo: apply every registered transform to "0.JPEG" at 9 evenly spaced
    # factors across its boundary and save one strip image per transform.
    image = cv2.imread("0.JPEG")
    # cv2.imread reports a missing or unreadable file by returning None;
    # fail here with a clear message instead of deep inside a transform.
    if image is None:
        raise FileNotFoundError("Cannot read image file: 0.JPEG")
    new_dict = dict(pix_mutate, **none_pix_mutate_types)
    font = cv2.FONT_HERSHEY_SIMPLEX
    for i, (name, info) in enumerate(new_dict.items()):
        func = info.get("func")
        boundary = info.get("boundary")
        degrees = np.linspace(*boundary, num=9)
        print(degrees)
        images = []
        for index, degree in enumerate(degrees):
            new_image = func(degree)(image)
            degree = round(degree, 3)
            # Label each tile with its index and the factor that produced it.
            cv2.putText(new_image, "{}_{}".format(index, degree), (20, 20),
                        font, 0.5, (0, 0, 255), 2)
            images.append(new_image)
        print(len(images))
        # Put the tiles side by side.
        images_ = np.concatenate(images, axis=1)
        cv2.imwrite("{}_{}.jpg".format(i, name), images_)
    print(new_dict.keys())

+ 24
- 0
mindarmour/natural_robustness/base/attacks/time_series_attack.py View File

@@ -0,0 +1,24 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Provide several methods to interfere with time series data."""

from mindarmour.natural_robustness.transform.time_series import Miss, Noise

# Registry of the time series mutations: each strategy name maps to the
# transform ('func') and the interval of factors ('boundary') to use it with.
funcs = {
    'Noise': dict(func=Noise, boundary=(-0.2, 0.2)),
    'Miss': dict(func=Miss, boundary=(-0.2, 0.2)),
}

+ 158
- 0
mindarmour/natural_robustness/base/common_defense.py View File

@@ -0,0 +1,158 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
AI Model robustness defense method. Improve robustness of AI models by using
selective data augmentation and retraining AI model.
"""

import abc
import os
from collections import defaultdict

from .abstract_equal import AbstractEqual
from .abstract_model import AbstractModelAdapter
from .attacks import AbstractAttack
from .common_evaluation import CommonEvaluation
from .data_structure import ParameterList
from .read_data import AbstractDataReader


class CommonDefense:
    """
    Improve robustness of AI models by using selective data augmentation to
    retrain AI model.

    Subclasses provide the `evaluate`, `data_reader`, `equal_method` and
    `attack_obj` properties; `improve_parameter` and `model_adapter` have to
    be assigned before `defense` is called.

    Args:
        strategies(List[str]): Mutate strategies names. They need be defined
            first and then recorded by instancing AbstractAttack
            (natural_robustness.base.attacks.AbstractAttack).
        augmentation_image_dir(str): A folder where the augmentation data
            is temporarily stored.
    """

    def __init__(self, strategies, augmentation_image_dir):
        self.strategies = strategies
        self.augmentation_image_dir = augmentation_image_dir

        self._improve_parameter = None
        self._model_adapter = None

    def defense(self, sample_rate, sample_number, iter_number,
                pre_train=False):
        """
        Boost robustness of the model by retraining it on augmented data
        iteratively.

        Args:
            sample_rate(double): Rate passed to `random_sample` when drawing
                parameters in each iteration.
            sample_number(int): Sample number of each data during
                augmentation.
            iter_number(int): Iterations number of retrain model.
            pre_train(bool): If True, train the model once on the defense
                data before the iterations start. Default: False.

        Returns:
            dict, with keys "train" and "val"; each holds accuracy,
            amplitude, robustness_range and robustness_rate of each mutate
            strategy for each iteration.

        Raises:
            RuntimeError: If `improve_parameter` is empty.
        """
        if pre_train:
            data_dir = self.data_reader.defense_data_dir
            anno_path = self.data_reader.get_anno_path(data_dir)
            self.model_adapter.train(data_dir, anno_path,
                                     logs={"item_numbers": -1})
        all_certi_result4eval = defaultdict(dict)
        all_certi_result4train = defaultdict(dict)
        if self.augmentation_image_dir is None:
            # Fall back to a folder in the user's home directory.
            default_dir = os.path.join("~", "augmentation_image_dir")
            self.augmentation_image_dir = os.path.expanduser(default_dir)
        if self.improve_parameter.is_empty:
            raise RuntimeError("Data Empty.")
        self.improve_parameter.init_sample_number(sample_number)
        # One extra round: the last one evaluates without retraining.
        for round_index in range(iter_number + 1):
            if len(self.improve_parameter) == 0:
                break
            parameters = self.improve_parameter.random_sample(sample_rate)
            parameters = self.evaluate.cal_robustness(parameters, True)
            parameters.modify_sample_number()
            train_result = self.evaluate.statistic(parameters)
            val_result = self.evaluate.evaluate()
            self.improve_parameter.extend(parameters)
            round_tag = "training_{}".format(round_index)
            val_result.update({"training": round_tag})
            train_result.update({"training": round_tag})
            self.update_results(all_certi_result4eval, val_result)
            self.update_results(all_certi_result4train, train_result)
            logs = {
                "iter_times": round_index,
                "accuracy": train_result["accuracy"],
            }
            if round_index >= iter_number:
                continue
            aug_data_dir, anno_path = self.data_reader.write(
                save_dir=self.augmentation_image_dir,
                workers=os.cpu_count(),
                transform_types=self.attack_obj.transform_types,
                strategies=self.strategies, parameters=parameters)
            self.model_adapter.train(aug_data_dir, anno_path, logs)
            # The annotation file is only needed for this training round.
            if os.path.exists(anno_path):
                os.remove(anno_path)
        return {"train": all_certi_result4train, "val": all_certi_result4eval}

    def update_results(self, results, certificate_result):
        """
        Append the figures of one iteration to the accumulated results.

        Args:
            results(dict): Accumulated results, updated in place.
            certificate_result(dict): Result of one iteration; it is read
                at 'accuracy' and, per strategy, at 'amplitude',
                'robustness_range' and 'robustness_rate'.
        """
        results.setdefault("accuracy", []).append(
            certificate_result['accuracy'])
        for strategy in self.strategies:
            for key in ('amplitude', 'robustness_range', 'robustness_rate'):
                results[strategy].setdefault(key, []).append(
                    certificate_result[strategy][key])

    @property
    def improve_parameter(self) -> ParameterList:
        """Data used in defense process."""
        return self._improve_parameter

    @improve_parameter.setter
    def improve_parameter(self, value: ParameterList):
        """Data used in defense process."""
        if not isinstance(value, ParameterList):
            raise ValueError("improve_parameter should be instance of ParameterList.")
        self._improve_parameter = value

    @property
    def model_adapter(self) -> AbstractModelAdapter:
        """Adapter of the model to retrain; raises RuntimeError while unset."""
        if self._model_adapter is None:
            raise RuntimeError("please use 'set_model()' method first.")
        return self._model_adapter

    @model_adapter.setter
    def model_adapter(self, model_adapter: AbstractModelAdapter):
        self._model_adapter = model_adapter

    @property
    @abc.abstractmethod
    def evaluate(self) -> CommonEvaluation:
        """Evaluate tools to used to evaluate model robustness."""

    @property
    @abc.abstractmethod
    def data_reader(self) -> AbstractDataReader:
        """Return a sub instance of AbstractDataReader."""

    @property
    @abc.abstractmethod
    def equal_method(self) -> AbstractEqual:
        """Return a sub instance of AbstractEqual."""

    @property
    @abc.abstractmethod
    def attack_obj(self) -> AbstractAttack:
        """Return a sub instance of AbstractAttack."""

+ 518
- 0
mindarmour/natural_robustness/base/common_evaluation.py View File

@@ -0,0 +1,518 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Evaluate robustness of AI model by testing max interference amplitude within
which results inferred by specific AI model for a particular data will not
change.
"""

import abc
import copy
import math
import multiprocessing
import pprint
import queue
import sys
import traceback
import warnings
from concurrent.futures import ThreadPoolExecutor
from functools import partial, reduce, wraps

import six
import numpy as np

from mindarmour.utils._check_param import check_param_type
from .abstract_equal import AbstractEqual
from .abstract_model import AbstractModelAdapter
from .attacks import AbstractAttack
from .data_structure import ParameterList, Score
from .read_data import AbstractDataReader
from ..utils import custom_threading
from ..utils.tools import calculate_time, catch_exception, linspace

# Module-level values, both initialised to zero.
# NOTE(review): presumably accumulators for time spent waiting on data and
# time spent in model prediction; they are not used in the part of the file
# reviewed here, so confirm against their users.
wait_time = 0
predict_time = 0


def _check_parameter(function):
    """
    Decorator validating the `strategy` and `function_interval` arguments
    of a sampling method.

    The decorated method must belong to an object exposing
    `transform_types`. The strategy has to be one of its keys, and the
    interval, when given, has to lie inside that strategy's "boundary".
    An interval of None is replaced by the boundary itself.

    Raises (from the wrapper):
        ValueError: If the strategy is unknown or the interval exceeds the
            boundary.
    """

    @wraps(function)
    def _wrap(self, strategy, function_interval, *args, **kwargs):
        known_transforms = self.transform_types
        if strategy not in known_transforms:
            raise ValueError(
                "Unknown Perturbation Mode:{}.".format(strategy))
        allowed = known_transforms[strategy]["boundary"]
        if function_interval is None:
            function_interval = allowed
        elif (function_interval[0] < allowed[0]
              or function_interval[1] > allowed[1]):
            raise ValueError(
                "ceil_robustness_interval range exceed! expect {},got {}.".format(
                    allowed, function_interval))
        return function(self, strategy, function_interval, *args, **kwargs)

    return _wrap


class Sampling4Evaluation:
    """
    Sampling data for Evaluation.

    Args:
        transform_types(dict): Maps a strategy name to a dict holding the
            transform factory under "func" and its interval under
            "boundary".
    """

    def __init__(self, transform_types):
        self.transform_types = transform_types

    def _func(self, ori_data, strategy, degree):
        """
        Mutate one data item with one degree; the unit of work run by each
        thread of `sample`.

        Returns:
            tuple, (mutated data, degree).
        """
        mutate = self.transform_types[strategy]["func"](degree)
        return mutate(ori_data), degree

    @_check_parameter
    def sample(self, strategy, function_interval, ori_data,
               mutate_num_per_data, *, workers="auto"):
        """
        From one source data item, generate `mutate_num_per_data` mutated
        items whose degrees are spread over `function_interval`.

        Args:
            strategy(str): Mutate strategy, a key from transform_types.
            function_interval(tuple): Perturbation boundary; None selects
                the whole boundary of the strategy.
            ori_data(np.ndarray): Origin data.
            mutate_num_per_data(int): Number of items to generate.
            workers(Union[int,str]): Number of threads; "auto" uses 35% of
                `mutate_num_per_data`, at least 1.

        Returns:
            List[np.ndarray], a list of generated data.
            List[float], a list of interference amplitude to interfere data.
        """
        if mutate_num_per_data == 0:
            return [], []
        if workers == "auto":
            workers = max(1, int(0.35 * mutate_num_per_data))
        degrees = linspace(*function_interval, num=mutate_num_per_data)
        with ThreadPoolExecutor(max_workers=workers) as pool:
            futures = [pool.submit(self._func, ori_data, strategy, degree)
                       for degree in degrees]
            outcomes = [future.result() for future in futures]
        # Split the (sample, degree) pairs into two parallel lists.
        mutate_samples, degrees = (list(column) for column in zip(*outcomes))
        return mutate_samples, degrees


class SupportImageGenerator:
    """
    Generate support images during model robustness evaluation.

    Args:
        attack_obj: Object whose `transform_types` is handed to the data
            sampler.
        workers(int): Number of processes of the pool used per batch.
        method(str): Which interval is sampled; one of "Rough", "down",
            "up" (the keys produced by `_calculate_boundary`).
        certificate_number(int): Number of support data generated per
            strategy for each original data.
        strategies(List[str]): Mutate strategies names.
        mutate_print: Progress messages are printed only when this is falsy.
    """

    def __init__(self, attack_obj, workers, method, certificate_number,
                 strategies, mutate_print):
        self.workers = workers
        self.method = method
        self.certificate_number = certificate_number
        self.strategies = strategies
        self.mutate_print = mutate_print
        self.data_sampler = Sampling4Evaluation(attack_obj.transform_types)

    @staticmethod
    @catch_exception
    def _calculate_boundary(score_obj: Score, certificate_number: int):
        """
        Update robustness boundary before more precise calculation.

        When `score_obj.deviation` is None it is first set to
        `boundary_amplitude / certificate_number`.

        Returns:
            dict, the sampling interval for each method: "Rough" is the
            whole boundary, "down" and "up" are the intervals one deviation
            wide at the lower and upper end of the current
            `ceil_robustness_interval`, clamped so that they do not cross 0.
        """
        if score_obj.deviation is None:
            score_obj.deviation = score_obj.boundary_amplitude / certificate_number
        if score_obj.ceil_robustness_interval[1] > 0:
            down_ = score_obj.ceil_robustness_interval[1] - score_obj.deviation
        else:
            down_ = score_obj.ceil_robustness_interval[1]
        down_ = max(0, down_)

        if score_obj.ceil_robustness_interval[0] < 0:
            up_ = score_obj.ceil_robustness_interval[0] + score_obj.deviation
        else:
            up_ = score_obj.ceil_robustness_interval[0]
        up_ = min(0, up_)

        radius_dict = {
            "Rough": score_obj.boundary,
            "down": (score_obj.ceil_robustness_interval[0], up_),
            "up": (down_, score_obj.ceil_robustness_interval[1]),
        }
        return radius_dict

    @catch_exception
    def _process_single_parameter(self, single_parameter, batch_index):
        """
        Generate support data from one single data.

        The original data comes first, followed by `certificate_number`
        mutated data per strategy. Where each group starts is recorded in
        `single_parameter.support_image_info`.

        Returns:
            tuple, (single_parameter, list of support data).
        """
        support_data = []
        ori_data = single_parameter.get_data()
        single_parameter.support_image_info = {
            "batch_index": batch_index,
            "origin_locate": len(support_data),
            "mutate": []}
        support_data.append(ori_data)
        for strategy in self.strategies:
            score_obj = single_parameter.robustness_results.get(strategy)
            radius_dict = self._calculate_boundary(
                score_obj,
                certificate_number=self.certificate_number)
            generated_data_list, degrees = self.data_sampler.sample(
                strategy=strategy,
                function_interval=radius_dict[self.method],
                mutate_num_per_data=self.certificate_number,
                ori_data=ori_data)
            single_parameter.support_image_info["mutate"].append({
                "strategy": strategy,
                "locate": len(support_data),
                "degrees": degrees
            })
            support_data.extend(generated_data_list)
        return single_parameter, support_data

    @catch_exception
    def _process_parameter_batch(self, parameter_batch, batch_index,
                                 data_batch_queue, preprocess_func):
        """
        Use a process pool to generate and preprocess the support data of
        one batch, then put (parameters, preprocessed data) on the queue.
        """
        pool = multiprocessing.Pool(self.workers)
        try:
            func = partial(self._process_single_parameter,
                           batch_index=batch_index)
            processed_parameters = pool.map_async(func, parameter_batch)
            processed_parameter_batch, support_image_list = list(
                zip(*processed_parameters.get()))
            # Flatten the per-parameter lists into one list, in order.
            support_data = [item for per_parameter in support_image_list
                            for item in per_parameter]
            processed_images = pool.map_async(preprocess_func, support_data)
        finally:
            # Always release the worker processes, also when a worker
            # raised and the first `.get()` propagated the error.
            pool.close()
            pool.join()
        if not self.mutate_print:
            print("batch_data_process_done")
        data_batch_queue.put(
            (processed_parameter_batch, processed_images.get()))

    @catch_exception
    def generate(self, parameters, batch_size, preprocess_func,
                 thread_workers, data_batch_queue):
        """
        Split the parameters into batches and process them with a thread
        pool, each thread driving its own process pool. After all batches
        are done, put the terminal signal "end" on the queue.
        """
        parameter_batches = [
            parameters[i * batch_size: (i + 1) * batch_size]
            for i in range(math.ceil(len(parameters) / batch_size))]
        print("number_parameters:", len(parameters))
        print("number_batches:", len(parameter_batches))
        # Every parameter yields 1 original plus certificate_number data
        # per strategy; the "+ 1" belongs inside the product (it used to be
        # added once per batch because of operator precedence).
        data_per_parameter = self.certificate_number * len(self.strategies) + 1
        print("batch_size:", batch_size * data_per_parameter)
        with ThreadPoolExecutor(max_workers=thread_workers) as pool:
            for batch_index, parameter_batch in enumerate(
                    parameter_batches):
                pool.submit(self._process_parameter_batch,
                            parameter_batch, batch_index,
                            data_batch_queue, preprocess_func)
        # Leaving the `with` block waits for every batch to finish.
        data_batch_queue.put("end")
        if not self.mutate_print:
            print("***************data_batch end put*****************")


class ParseRobustness:
"""Parse Robustness result."""

def __init__(self, postprocess, certificate_number, method, num_strategies,
             equal_method):
    """
    Store the settings used when parsing prediction results.

    Args:
        postprocess: Callable applied to the raw model prediction results.
        certificate_number(int): Number of support data per strategy.
        method(str): Key selecting which interval is updated.
        num_strategies(int): Number of mutate strategies.
        equal_method: Object whose `is_equal` compares predictions.
    """
    self.postprocess = postprocess
    self.certificate_number = certificate_number
    self.method = method
    self.num_strategies = num_strategies
    self.equal_method = equal_method

@catch_exception
def calculate(self, predict_res_queue):
    """
    Create multiprocess task to calculate robustness of AI model.

    Items are taken from the queue until the terminal signal "end" shows
    up; every other item is handed to `_process_param_batch` in a pool of
    3 processes.

    Args:
        predict_res_queue: Queue delivering prediction results, terminated
            by the string "end".

    Returns:
        ParameterList, the parameters returned for all batches.
    """
    parameters = ParameterList()
    pool = multiprocessing.Pool(3)
    try:
        results = []
        while True:
            data = predict_res_queue.get()
            if data == "end":
                break
            results.append(
                pool.apply_async(self._process_param_batch, args=(data,)))
        for result in results:
            parameters.extend(result.get())
    finally:
        # The pool used to be left open; release the worker processes
        # whether or not a batch failed.
        pool.close()
        pool.join()
    return parameters

@catch_exception
def _process_supports(self, single_parameter, robustness_results,
                      base, pred_results, original_prediction):
    """
    Update the robustness interval of every strategy from the predictions
    made on the support data.

    For each strategy, the predictions of its support data are compared
    with the original prediction. The new lower bound is the last
    negative degree whose prediction differs, the new upper bound the
    first positive one; when there is none, the first / last degree is
    used and the deviation of the strategy is reset to 0.

    Args:
        single_parameter: Parameter whose `support_image_info["mutate"]`
            lists strategy, location and degrees of its support data.
        robustness_results: Mapping from strategy to its score object,
            updated in place.
        base(int): Offset of this parameter's data inside `pred_results`.
        pred_results: Predictions of the whole batch.
        original_prediction: Prediction for the original data.

    Returns:
        The updated `robustness_results`.
    """
    for mutate_info in single_parameter.support_image_info["mutate"]:
        strategy = mutate_info.get("strategy")
        degrees = mutate_info.get("degrees")
        start = mutate_info.get("locate") + base
        mutate_prediction = pred_results[
            start:start + self.certificate_number]
        mutate_corrects = self.equal_method.is_equal(
            mutate_prediction,
            original_prediction,
            method="multi2one")
        score = robustness_results[strategy]
        # Degrees at which the prediction changed, split by sign.
        failed_negative = [degree for same, degree
                           in zip(mutate_corrects, degrees)
                           if not same and degree < 0]
        failed_positive = [degree for same, degree
                           in zip(mutate_corrects, degrees)
                           if not same and degree > 0]
        if failed_negative:
            down = failed_negative[-1]
        else:
            down = degrees[0]
            score.deviation = 0
        if failed_positive:
            up = failed_positive[0]
        else:
            up = degrees[-1]
            score.deviation = 0
        new_radius_dict = {
            "Rough": (down, up),
            "down": (down, score.ceil_robustness_interval[1]),
            "up": (score.ceil_robustness_interval[0], up),
        }
        score.ceil_robustness_interval = new_radius_dict[self.method]
    return robustness_results

@catch_exception
def _process_param_batch(self, data):
"""process batch data."""
parameter_batch, res = data
check_param_type("model_predict_result", res,
(np.ndarray, int, float, list))
pred_results = self.postprocess(res)
parameter_list = ParameterList()
for i, single_parameter in enumerate(parameter_batch):
robustness_results = copy.deepcopy(
single_parameter.robustness_results)
base = i * (1 + self.num_strategies * self.certificate_number)
relative_locate = single_parameter.support_image_info[
"origin_locate"]
absolute_locate = relative_locate + base
original_prediction = pred_results[absolute_locate]
correct = self.equal_method.is_equal(
original_prediction,
single_parameter.true_label,
method="one2one")
single_parameter.pred_state = correct
single_parameter.pred_label = original_prediction
robustness_results = self._process_supports(single_parameter,
robustness_results,
base,
pred_results,
original_prediction)
single_parameter.robustness_results = robustness_results
del single_parameter.support_image_info
parameter_list.append(single_parameter)
return parameter_list


@six.add_metaclass(abc.ABCMeta)
class CommonEvaluation:
    """
    Evaluate robustness of AI model by testing max interference amplitude
    within which results inferred by specific AI model for a particular data
    will not change.

    Args:
        strategies (Union[list[str], tuple[str]]): Mutate strategies, chosen
            from pix_mutate and none_pix_mutate_types.
            Pixel-level transforms change just the input image and leave
            additional targets such as masks, bounding boxes and key points
            unchanged. pix_mutate: ['Blur', 'Brightness', 'Contrast',
            'Noise'].
            None-pixel transforms change the input image as well as
            additional targets such as masks, bounding boxes and key points.
            none_pix_mutate_types: ['Rotate', 'Scale', 'Shear_x', 'Shear_y',
            'Translate_x', 'Translate_y'].
        certificate_number (int): Number of data generated per strategy to
            evaluate AI model robustness.
        workers (int): Number of multiprocess workers handed to the support
            image generator.
        thread_workers (int): Number of threads used to process data batches.
        batch_queue_buffer_size (int): Maximum size of the data-batch queue
            and of the prediction-result queue.
        mutate_print (bool): Flag forwarded to the support image generator
            that controls its progress printing.
    """

    def __init__(self,
                 strategies,
                 certificate_number,
                 workers,
                 thread_workers,
                 batch_queue_buffer_size,
                 mutate_print,
                 ):
        self.strategies = strategies
        self.certificate_number = certificate_number
        self.workers = workers
        self.thread_workers = thread_workers
        self.batch_queue_buffer_size = batch_queue_buffer_size
        self.mutate_print = mutate_print

        self._evaluate_parameter = None

    @calculate_time
    def evaluate(self, fast=True):
        """
        Evaluate model robustness on the loaded data.

        Args:
            fast (bool): If True, only the "Rough" pass is run. If False, the
                additional "down" and "up" passes are run as well, which
                costs more computation. Default: True.

        Returns:
            dict, results of robustness evaluation.

        Raises:
            RuntimeError: If no data has been loaded yet.
        """
        self.evaluate_parameter = self.cal_robustness(self.evaluate_parameter,
                                                      fast)
        results = self.statistic(self.evaluate_parameter)
        return results

    @catch_exception
    def cal_robustness(self, parameters: ParameterList, fast=True):
        """
        Calculate robustness of each data.

        Args:
            parameters (ParameterList): Data to evaluate; initialised with
                the boundaries of ``attack_obj`` before the first pass.
            fast (bool): See ``evaluate``.

        Returns:
            ParameterList, parameters after the last pass.
        """
        parameters.init(self.attack_obj, self.strategies)
        methods = ["Rough"] if fast else ["Rough", "down", "up"]
        for method in methods:
            parameters = self._calculate_acc_boundary(parameters=parameters,
                                                      method=method)
        return parameters

    @catch_exception
    def statistic(self, parameters):
        """
        Aggregate the robustness results of all data.

        Args:
            parameters (ParameterList): Evaluated parameters. Must not be
                empty, otherwise the accuracy division fails.

        Returns:
            dict, with key ``"accuracy"`` and, per strategy, a dict holding
            ``"robustness_range"``, ``"amplitude"``, ``"boundary"``,
            ``"deviation"`` and ``"robustness_rate"``.
        """
        strategies = self.strategies
        robustness_result = {}
        corrects = [item.pred_state for item in parameters]
        accuracy = corrects.count(True) / len(corrects)
        robustness_result["accuracy"] = accuracy
        if accuracy == 0:
            # The literals are joined with explicit spaces; the original
            # message ran the sentences together.
            print("All of your data's prediction is wrong, please check "
                  "your data true_label and pred_label. "
                  "Here is a parameter"
                  " example:\n{}".format(parameters[0]))

        for strategy in strategies:
            scores = [item.robustness_results[strategy]
                      for item in parameters]
            robustness_range = np.mean(
                [score.robustness_interval for score in scores],
                axis=0).tolist()
            amplitude = np.mean([score.amplitude for score in scores])
            deviation = np.mean([score.deviation for score in scores])
            # The boundary is read from the first parameter only.
            boundary = parameters[0].robustness_results[strategy].boundary
            radius_boundary_amplitude = boundary[1] - boundary[0]
            robustness_result[strategy] = {
                "robustness_range": robustness_range,
                "amplitude": amplitude,
                "boundary": boundary,
                "deviation": deviation,
                "robustness_rate": round(amplitude / radius_boundary_amplitude,
                                         4)}
        return robustness_result

    @calculate_time
    @catch_exception
    def _calculate_acc_boundary(self, parameters: ParameterList, method: str):
        """
        Calculate the boundaries of the correct image_classification of data
        points.

        Three threads are chained through two queues: support image
        generation, model prediction and result parsing.

        Args:
            parameters (ParameterList): Data to evaluate.
            method (str): "Rough", "down" or "up".

        Returns:
            The value returned by ``ParseRobustness.calculate``.
        """
        num_strategies = len(self.strategies)
        model_predict_batch = self.model_adapter.predict_batch_size
        number = self.certificate_number
        support_per_sample = number * num_strategies
        batch_size = max(model_predict_batch // support_per_sample, 1)
        # The former check ``batch_size > model_predict_batch`` could never be
        # true. The configured size is exceeded exactly when the clamp to one
        # sample per batch above kicks in, so that is what is tested here.
        if batch_size * support_per_sample > model_predict_batch:
            # Each sample contributes its original image plus its supports.
            real_batch_size = batch_size * (support_per_sample + 1)
            warnings.warn("Please notice that the real batch size is"
                          " {}, which is larger than what you set and may cause"
                          " memory leakage.".format(real_batch_size))
        postprocess_func = self.model_adapter.postprocess
        preprocess_func = self.model_adapter.preprocess
        support_image_generator = SupportImageGenerator(
            self.attack_obj, self.workers, method, number,
            self.strategies, self.mutate_print)
        data_batch_queue = queue.Queue(self.batch_queue_buffer_size)
        predict_res_queue = queue.Queue(self.batch_queue_buffer_size)
        result_parser = ParseRobustness(postprocess_func, number,
                                        method, num_strategies,
                                        self.equal_method)
        batch_process = custom_threading.Thread(
            target=support_image_generator.generate,
            args=(parameters, batch_size, preprocess_func,
                  self.thread_workers, data_batch_queue))
        model_process = custom_threading.Thread(
            target=self.model_adapter.predict_server,
            args=(data_batch_queue, predict_res_queue))
        post_process = custom_threading.Thread(
            target=result_parser.calculate, args=(predict_res_queue,))
        batch_process.start()
        model_process.start()
        post_process.start()
        batch_process.join()
        model_process.join()
        # get_results joins the parsing thread before returning its value.
        parameters = post_process.get_results()
        return parameters

    @property
    def evaluate_parameter(self) -> ParameterList:
        """Loaded evaluation data; raises RuntimeError when nothing is set."""
        if self._evaluate_parameter is None:
            raise RuntimeError("please use methods from data_reader property "
                               "to load data.")
        return self._evaluate_parameter

    @evaluate_parameter.setter
    def evaluate_parameter(self, parameter: ParameterList):
        self._evaluate_parameter = parameter

    def set_model(self, *args, **kwargs) -> None:
        """Modify ModelAdapters."""

    @property
    @abc.abstractmethod
    def data_reader(self) -> AbstractDataReader:
        """Return a instance of AbstractDataReader."""

    @property
    @abc.abstractmethod
    def model_adapter(self) -> AbstractModelAdapter:
        """Return a sub instance of AbstractModelAdapter."""

    @property
    @abc.abstractmethod
    def equal_method(self) -> AbstractEqual:
        """Return a sub instance of AbstractEqual."""

    @property
    @abc.abstractmethod
    def attack_obj(self) -> AbstractAttack:
        """Return a sub instance of AbstractAttack."""

+ 306
- 0
mindarmour/natural_robustness/base/data_structure.py View File

@@ -0,0 +1,306 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Provide Score object and parameter obj to record robustness information of
single data.
"""

import abc
import math
import pprint
import random

import six
import cv2
import numpy as np

from mindarmour.utils._check_param import check_param_type
from ..utils.tools import typed_property


class Score:
    """
    Robustness score of one strategy for one sample.

    Assigning ``ceil_robustness_interval`` recomputes the derived values
    (robustness interval, floor interval, amplitude and score). The first
    interval ever assigned is remembered as the boundary.
    """

    def __init__(self, ceil_robustness_interval):
        # __repr__ walks __dict__, so the creation order below is part of
        # the printed representation.
        self._robustness_score = None
        self._amplitude = None
        self._boundary = None
        self._boundary_amplitude = None
        self._deviation = None
        self._floor_robustness_interval = None
        self.ceil_robustness_interval = ceil_robustness_interval

    @property
    def ceil_robustness_interval(self):
        """The interval as it was last assigned."""
        return self._ceil_robustness_interval

    @ceil_robustness_interval.setter
    def ceil_robustness_interval(self, value):
        """Store the interval and refresh every value derived from it."""
        self._ceil_robustness_interval = value
        lower, upper = value[0], value[1]
        deviation = self._deviation
        if deviation is None:
            interval = floor_interval = value
        else:
            # Shrink by half a deviation on each side for the interval and
            # by a full deviation on each side for the floor interval.
            half = 0.5 * deviation
            interval = (lower + half, upper - half)
            floor_interval = (lower + deviation, upper - deviation)
        if self._boundary is None:
            self._boundary = value
            self._boundary_amplitude = upper - lower
        self._robustness_interval = interval
        self._floor_robustness_interval = floor_interval
        self._amplitude = interval[1] - interval[0]
        self._robustness_score = self._amplitude / self._boundary_amplitude

    @property
    def robustness_interval(self):
        """Interval after the deviation correction."""
        return self._robustness_interval

    @property
    def boundary(self):
        """The first interval assigned to this score."""
        return self._boundary

    @property
    def amplitude(self):
        """Width of ``robustness_interval``."""
        return self._amplitude

    @property
    def boundary_amplitude(self):
        """Width of ``boundary``."""
        return self._boundary_amplitude

    @property
    def robustness_score(self):
        """``amplitude`` divided by ``boundary_amplitude``."""
        return self._robustness_score

    @property
    def deviation(self):
        """Deviation used to shrink the interval; ``None`` until set."""
        return self._deviation

    @deviation.setter
    def deviation(self, value):
        self._deviation = value

    def __repr__(self):
        attr_dict = dict(
            (key.lstrip("_"), value) for key, value in self.__dict__.items())
        rendered = ", ".join(
            "{}:{}".format(key, value) for key, value in attr_dict.items())
        return "Score-> " + rendered


@six.add_metaclass(abc.ABCMeta)
class CommonParameter:
    """
    Base record holding the robustness information of a single sample.

    Subclasses have to provide the ``pred_label`` and ``true_label`` setters.
    """
    pred_state = typed_property("pred_state", bool)
    strategy_sample_numbers = typed_property("sample_number", dict)
    total_sample_number = typed_property("total_sample_number", int)
    init_sample_number = typed_property("init_sample_number", int)
    support_image_info = typed_property("support_image_info", dict)
    inputs = typed_property("inputs", np.ndarray)
    robustness_results = typed_property("robustness_results", dict)

    def __init__(self):
        self._pred_label = None
        self._true_label = None

    def modify_sample_number(self):
        """
        Recompute how many samples to draw per strategy.

        The initial sample number is scaled by ``exp(-score)`` when the
        prediction is correct and by ``exp(score)`` when it is wrong, then
        rounded down to an even number. Wrongly predicted data is therefore
        sampled more often and correctly predicted data less often.
        """
        sign = 1 if self.pred_state else -1
        sample_numbers = {}
        for strategy, score in self.robustness_results.items():
            scale = math.exp(-score.robustness_score * sign)
            scaled = int(self.init_sample_number * scale)
            sample_numbers[strategy] = scaled // 2 * 2
        self.strategy_sample_numbers = sample_numbers
        self.total_sample_number = sum(sample_numbers.values())
        # NOTE(review): image_path is not defined in this class; presumably
        # subclasses provide it - confirm.
        print(
            f"*********{str(sample_numbers)}***{self.average_score}****{self.image_path}*****")

    def get_data(self):
        """Return the data of this sample; must be provided by subclasses."""
        raise NotImplementedError

    def init(self, boundaries, strategies):
        """Create one ``Score`` per strategy from the given boundaries."""
        fresh_scores = {}
        for strategy in strategies:
            fresh_scores[strategy] = Score(boundaries[strategy])
        self.robustness_results = fresh_scores

    @property
    def average_score(self):
        """
        Mean robustness score over all strategies, negated when the
        prediction is wrong.
        """
        total = sum(item.robustness_score for item in
                    self.robustness_results.values())
        mean = total / len(self.robustness_results)
        sign = 1 if self.pred_state else -1
        return mean * sign

    @property
    def pred_label(self):
        """Label predicted by the model."""
        return self._pred_label

    @pred_label.setter
    @abc.abstractmethod
    def pred_label(self, value):
        self._pred_label = value

    @property
    def true_label(self):
        """Ground-truth label."""
        return self._true_label

    @true_label.setter
    @abc.abstractmethod
    def true_label(self, value):
        self._true_label = value

    @support_image_info.deleter
    def support_image_info(self):
        del self._support_image_info

    def __repr__(self):
        attr_dict = dict(
            (key.lstrip("_"), value) for key, value in self.__dict__.items())
        return "Parameter" + pprint.pformat(attr_dict)


class ParameterList:
    """
    List-like container of Parameter objects.

    Args:
        store (list): Initial items. A falsy value gives an empty list.
    """

    def __init__(self, store=None):
        self._storage = store if store else []
        self._data_number = 0

    def __add__(self, other):
        """
        Append the items of ``other`` and return ``self``.

        Unlike ``list.__add__`` this updates the left operand in place; that
        behavior is kept for existing callers. ``other`` may be a list or
        another ``ParameterList`` (the latter used to raise ``TypeError``).
        """
        if isinstance(other, ParameterList):
            other = other.storage
        self._storage = self._storage + other
        return self

    def __getitem__(self, item):
        """Return one item, or a new ``ParameterList`` for a slice."""
        res = self._storage[item]
        if isinstance(res, list):
            obj = self.__class__()
            obj.storage = res
        else:
            obj = res
        return obj

    def __repr__(self):
        return self._storage.__repr__()

    def __len__(self):
        return len(self._storage)

    @property
    def is_empty(self):
        """Is empty."""
        return not self._storage

    @property
    def storage(self):
        """The underlying list of items."""
        return self._storage

    @storage.setter
    def storage(self, iterable):
        """Replace the underlying storage."""
        self._storage = iterable

    def append(self, item):
        """Append object to the end of the list and return ``self``."""
        self._storage.append(item)
        return self

    def extend(self, iterable):
        """Extend list by appending elements from the iterable."""
        if isinstance(iterable, ParameterList):
            iterable = iterable.storage
        self._storage.extend(iterable)
        return self

    def modify_sample_number(self):
        """Call ``modify_sample_number`` on every item."""
        for item in self._storage:
            item.modify_sample_number()

    def init(self, attack, strategies):
        """Init boundary for each item from ``attack``'s boundaries."""
        boundaries = attack.transform_method_boundary
        for item in self._storage:
            item.init(boundaries, strategies)

    def init_sample_number(self, sample_number):
        """Set ``init_sample_number`` of every item."""
        for item in self._storage:
            item.init_sample_number = sample_number

    def random_sample(self, keep_rate):
        """
        Split off a random share of the items.

        Items whose ``total_sample_number`` is not 0 are preferred; they are
        mixed with 20% of the remaining items and ranked before the other
        80%. The first ``int(len(self) * keep_rate)`` entries of that ranking
        are moved into the returned list, the rest stay in ``self``.

        Args:
            keep_rate (float): Share of items to move into the result.

        Returns:
            ParameterList, the selected items.
        """
        k = int(len(self) * keep_rate)
        indexes = []
        unchoose = []
        for index, item in enumerate(self._storage):
            if item.total_sample_number != 0:
                indexes.append(index)
            else:
                unchoose.append(index)
        print("number of robust data is {}.".format(len(unchoose)))
        random.shuffle(indexes)
        random.shuffle(unchoose)
        mixed_in = int(0.2 * len(unchoose))
        indexes = unchoose[:mixed_in] + indexes
        random.shuffle(indexes)
        total = indexes + unchoose[mixed_in:]
        choose = total[:k]
        unchoose = total[k:]
        new_storage = [self.storage[index] for index in choose]
        self.storage = [self.storage[index] for index in unchoose]
        new_obj = self.__class__()
        new_obj.storage = new_storage
        return new_obj


def get_image_data(parameter):
    """
    Return the image of ``parameter`` for classification and object
    detection tasks.

    The in-memory ``inputs`` array is used when present, otherwise the image
    is read from ``image_path``. When ``target_shape`` is set the image is
    resized to it.

    Args:
        parameter: Object exposing ``inputs``, ``image_path`` and
            ``target_shape``.

    Returns:
        numpy.ndarray, the (optionally resized) image.
    """
    image = parameter.inputs
    if image is None:
        image_path = check_param_type("path", parameter.image_path, str)
        image = cv2.imread(image_path)
    image = check_param_type("image_data", image, np.ndarray)
    target_shape = parameter.target_shape
    if target_shape is not None:
        image = cv2.resize(image, dsize=target_shape)
    return image

+ 163
- 0
mindarmour/natural_robustness/base/read_data.py View File

@@ -0,0 +1,163 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Abstract DataReader."""

import abc
import itertools
import multiprocessing
import os
import shutil
from concurrent.futures import ThreadPoolExecutor

import cv2


from ..utils.tools import catch_exception, sinhspace


class SampleData4Defense:
    """
    Used in defense process. Sample data for retraining AI model.

    Args:
        save_dir (str): Directory below which the ``image`` folder is
            created.
        workers (int): Upper bound of worker processes; capped by the CPU
            count.
        transform_types (dict): Maps a strategy name to a dict whose
            ``"func"`` entry builds a transform from a degree.
    """

    def __init__(self, save_dir, workers, transform_types):
        self.save_dir = save_dir
        self.workers = workers
        self.transform_types = transform_types

    def generate_anno_path(self, parameters, results, data_dir,
                           train_data_dir):
        """Create the annotation of the sampled data; for subclasses."""
        raise NotImplementedError

    def write(self, strategies, parameters, defense_data_dir):
        """
        According to the calculated data, calculate and generate a bunch of
        increased sample data sets.

        An existing ``<save_dir>/image`` directory is deleted first.

        Args:
            strategies (list[str]): Mutate strategies, a list of key from
                pix_mutate or none_pix_mutate_types.
            parameters (list[_Parameter]): A list of Parameter instance.
            defense_data_dir (str): A data directory name where store data
                which is used in defense model from natural interference.

        Returns:
            - str, storage location of newly generated data.
            - The value returned by ``generate_anno_path``.
        """
        data_dir = os.path.join(self.save_dir, "image")
        if os.path.exists(data_dir):
            shutil.rmtree(data_dir)
        os.makedirs(data_dir)
        cpu_num = min(multiprocessing.cpu_count(), self.workers)
        pool = multiprocessing.Pool(cpu_num)
        try:
            pending = []
            for strategy, parameter in itertools.product(strategies,
                                                         parameters):
                transform_func = self.transform_types[strategy]["func"]
                sub_directory = self.sub_directory_func(parameter, data_dir)
                pending.append(
                    pool.apply_async(self.sample_single_image,
                                     args=(strategy, parameter,
                                           sub_directory, transform_func)))
            results = [item.get() for item in pending]
            anno_path = self.generate_anno_path(parameters, results, data_dir,
                                                defense_data_dir)
            pool.close()
            pool.join()
        finally:
            # Also reached when a task or the annotation step raises, so the
            # worker processes are never left behind.
            pool.terminate()
        return data_dir, anno_path

    @catch_exception
    def sample_single_image(self, strategy, parameter, data_dir,
                            transform_func):
        """
        Using a specific sample increase strategy to sample data.

        Args:
            strategy (str): Mutate strategies, a key from pix_mutate or
                none_pix_mutate_types.
            parameter (Parameter): Instance of Parameter, contains essential
                Information of single data on each circle.
            data_dir (str): A directory where newly generated image saved.
            transform_func (Callable): Called with a degree, returns the
                transform applied to the image.

        Returns:
            - list[str], paths of the generated images, ordered by degree.
            - list, the degree used for each image, ascending.

            Two empty lists when the strategy's sample number is 0.
        """

        @catch_exception
        def _func(transform_func, strategy, degree, ori_image_path,
                  save_dir):
            """Smallest process unit to sample data."""
            ori_image_name = os.path.basename(ori_image_path)
            prefix, suffix = os.path.splitext(ori_image_name)
            new_name = f"{prefix}_{strategy}_{round(degree, 3)}{suffix}"
            mutate_sample_path = os.path.join(save_dir, new_name)
            # Reading and transforming is skipped when the target file is
            # already there, since the result would be discarded anyway.
            if not os.path.exists(mutate_sample_path):
                ori_image = cv2.imread(ori_image_path)
                transform = transform_func(degree)
                mutate_sample = transform(ori_image)
                cv2.imwrite(mutate_sample_path, mutate_sample)
            return mutate_sample_path, degree

        interval = parameter.robustness_results[
            strategy].ceil_robustness_interval
        sample_number = parameter.strategy_sample_numbers[strategy]
        if sample_number == 0:
            return [], []
        workers = max(1, int(0.35 * sample_number))
        degrees = sinhspace(interval[0], interval[1], num=sample_number)
        # The with block shuts the executor down; no extra shutdown needed.
        with ThreadPoolExecutor(max_workers=workers) as pool:
            futures = [
                pool.submit(_func, transform_func, strategy, degree,
                            parameter.image_path, data_dir)
                for degree in degrees]
            results = [item.result() for item in futures]
        results = sorted(results, key=lambda x: x[1])
        mutate_sample_objs, degrees = list(map(list, zip(*results)))
        return mutate_sample_objs, degrees

    @abc.abstractmethod
    def sub_directory_func(self, parameter, data_dir):
        """
        Functions for build data store structure. Usually, in classification
        task, we use label as sub directory name.
        """


class AbstractDataReader:
    """
    Base class of data readers that load data from different sources.

    All loading methods are empty hooks that return ``None`` here.
    """

    def __init__(self, evaluator, defense):
        self.evaluator, self.defense = evaluator, defense

        # Both directories are unknown until a subclass sets them.
        self.defense_data_dir = self.evaluate_data_dir = None

    def get_anno_path(self, data_dir) -> str:
        """Hook: build an annotation txt file for ``data_dir``."""

    def from_directory(self, evaluate_data_dir, defense_data_dir=None,
                       shuffle=True,
                       target_shape=None) -> None:
        """Hook: create the per-data records from the given directories."""

    def write(self, save_dir, workers, transform_types, strategies,
              parameters):
        """Hook: write data from a specific data directory."""

+ 2
- 1
mindarmour/natural_robustness/transform/__init__.py View File

@@ -12,5 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Image Transform Method for Natural Robustness.
This package include methods to generate natural perturbation samples for different
kind of data.
"""

+ 23
- 0
mindarmour/natural_robustness/transform/time_series/__init__.py View File

@@ -0,0 +1,23 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This package includes methods to generate natural perturbation samples for time
series data.
"""
from .corruption import Miss, Noise

__all__ = [
"Miss",
"Noise"
]

+ 147
- 0
mindarmour/natural_robustness/transform/time_series/corruption.py View File

@@ -0,0 +1,147 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Time series corruption.
"""
import copy

import numpy as np

from mindarmour.utils._check_param import check_param_multi_types
from .natural_perturb import _NaturalPerturb


class Miss(_NaturalPerturb):
    """
    Drop some information from original data by setting points to zero.

    Args:
        factor (float): Approximate ratio of data points that are set to
            zero. If 0 gives the original data. Default 0.

    Examples:
        >>> x = np.linspace(0, 20, 100)
        >>> y = np.sin(x)
        >>> trans = Miss(0.2)
        >>> target = trans(y)
        >>> plt.subplot(2, 1, 1)
        >>> plt.scatter(x, y, c="r", s=1)
        >>> plt.legend(["origin"])
        >>> plt.subplot(2, 1, 2)
        >>> plt.scatter(x, target, s=1)
        >>> plt.legend(["miss"])
        >>> plt.show()
    """

    def __init__(self, factor=0):
        super(Miss, self).__init__()
        self.set_params(factor)

    def set_params(self, factor=0, auto_param=False):
        """
        Set miss parameters.

        Args:
            factor (Union[float, int]): Approximate ratio of data points that
                are set to zero. If 0 gives the original data. The absolute
                value is stored. Default 0.
            auto_param (bool): True if auto generate parameters; ``factor``
                is then drawn uniformly from [0, 1). Default: False.
        """
        if auto_param:
            factor = np.random.uniform(0, 1)
        else:
            factor = check_param_multi_types('factor', factor,
                                             [int, float])
        self.factor = abs(factor)

    def __call__(self, data):
        """
        Transform the time_series_data.

        The positions that are dropped depend only on the shape of ``data``
        and on ``factor``, because a fixed seed is used.

        Args:
            data (numpy.ndarray): Original time series to be transformed.

        Returns:
            numpy.ndarray, transformed time series.
        """
        ori_dtype = data.dtype
        # A private generator with the same fixed seed yields the same mask
        # as before without reseeding the process-wide NumPy random state.
        rng = np.random.RandomState(3)
        noise = rng.uniform(low=-1, high=1, size=np.shape(data))
        trans_data = copy.deepcopy(data)
        threshold = 1 - self.factor
        # Both tails beyond +-threshold are zeroed.
        trans_data[noise < -threshold] = 0
        trans_data[noise > threshold] = 0
        return trans_data.astype(ori_dtype)


class Noise(_NaturalPerturb):
    """
    Corrupt time series data by overwriting random points with -1 or 1.

    Args:
        factor (float): Controls how many points are overwritten. If 0 gives
            the original data. Default 0.

    Examples:
        >>> x = np.linspace(0, 100, 500)
        >>> y = np.sin(x)
        >>> trans = Noise(0.05)
        >>> target = trans(y)
        >>> plt.subplot(2, 1, 1)
        >>> plt.plot(x, y, c="r")
        >>> plt.legend(["origin"])
        >>> plt.subplot(2, 1, 2)
        >>> plt.plot(x, target)
        >>> plt.legend(["noise"])
        >>> plt.show()
    """

    def __init__(self, factor=0):
        super(Noise, self).__init__()
        self.set_params(factor)

    def set_params(self, factor=0, auto_param=False):
        """
        Configure the noise strength.

        Args:
            factor (Union[float, int]): Controls how many points are
                overwritten. If 0 gives the original data. The absolute value
                is stored. Default 0.
            auto_param (bool): When True, ``factor`` is ignored and a value
                drawn uniformly from [0, 1) is used instead. Default: False.
        """
        chosen = (np.random.uniform(0, 1) if auto_param
                  else check_param_multi_types('factor', factor,
                                               [int, float]))
        self.factor = abs(chosen)

    def __call__(self, data):
        """
        Apply the corruption to a time series.

        Args:
            data (numpy.ndarray): Original time series to be transformed.

        Returns:
            numpy.ndarray, transformed data with the dtype of ``data``.
        """
        draws = np.random.uniform(low=-1, high=1, size=np.shape(data))
        limit = 1 - self.factor
        corrupted = copy.deepcopy(data)
        # Low tail becomes -1 first, high tail becomes 1 afterwards.
        low_tail = draws < -limit
        high_tail = draws > limit
        corrupted[low_tail] = -1
        corrupted[high_tail] = 1
        return corrupted.astype(data.dtype)

+ 33
- 0
mindarmour/natural_robustness/transform/time_series/natural_perturb.py View File

@@ -0,0 +1,33 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Base class for time series data natural perturbation.
"""

from mindarmour.utils.logger import LogUtil

LOGGER = LogUtil.get_instance()
TAG = 'Time Series Transformation'


class _NaturalPerturb:
    """
    Common base of all natural perturbation classes.

    It keeps no state; calling an instance does nothing and yields ``None``.
    """

    def __init__(self):
        """Nothing to set up in the base class."""

    def __call__(self, image):
        """Placeholder to be overridden by concrete perturbations."""
        return None

+ 13
- 0
mindarmour/natural_robustness/utils/__init__.py View File

@@ -0,0 +1,13 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

+ 36
- 0
mindarmour/natural_robustness/utils/custom_threading.py View File

@@ -0,0 +1,36 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Custom threading. """

import threading


class Thread(threading.Thread):
    """
    ``threading.Thread`` whose target's return value can be retrieved.

    Call ``get_results`` after ``start`` to obtain the value returned by the
    target. If the target raised, ``get_results`` re-raises that exception in
    the calling thread instead of failing with an unrelated
    ``AttributeError``.
    """

    # Class-level defaults so that get_results also works when the thread
    # had no target, has not finished yet, or the target raised.
    _results = None
    _exception = None

    def run(self):
        """Method representing the thread's activity."""
        try:
            if self._target:
                self._results = self._target(*self._args, **self._kwargs)
        except BaseException as error:
            # Remember the failure for get_results, then let it propagate
            # exactly as before.
            self._exception = error
            raise
        finally:
            # Avoid a refcycle if the thread is running a function with
            # an argument that has a member that points to the thread.
            del self._target, self._args, self._kwargs

    def get_results(self, timeout=None):
        """
        Wait for the thread and return the value returned by its target.

        Args:
            timeout (float): Passed to ``join``. Default: None, wait until
                the thread terminates.

        Returns:
            The target's return value; ``None`` if there was no target or
            the thread has not finished within ``timeout``.

        Raises:
            BaseException: The exception raised by the target, if any.
        """
        self.join(timeout=timeout)
        if self._exception is not None:
            raise self._exception
        return self._results

+ 229
- 0
mindarmour/natural_robustness/utils/tools.py View File

@@ -0,0 +1,229 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""robustness utils."""

import os
import sys
import time
import traceback
from functools import wraps
from shutil import copy2, copy, copystat, Error

import numpy as np

linspace = np.linspace


def sinhspace(start, stop, num=50):
    """
    Generate a deterministic sequence between start and stop whose points are
    evenly spaced after a sinh mapping.

    Both ends are scaled so that the larger absolute value maps to
    ``sinh(5)``, ``num`` evenly spaced points are taken in that space and
    mapped back with ``arcsinh``.

    Args:
        start(Union[int, float]): The starting value of the sequence.
        stop(Union[int, float]): The end value of the sequence.
        num(int): Number of samples to generate. Default is 50.
            Must be non-negative.

    Returns:
        numpy.ndarray, ``num`` values from start to stop. All zeros when
        both ``start`` and ``stop`` are 0.
    """
    amplitude = 5
    max_value = max(abs(start), abs(stop))
    if max_value == 0:
        # Degenerate interval [0, 0]: the scaling below would divide by zero.
        return np.zeros(num)
    start = np.sinh(start / max_value * amplitude)
    stop = np.sinh(stop / max_value * amplitude)
    x = np.linspace(start, stop, num)
    x = np.arcsinh(x) * max_value / amplitude
    return x


def hashable(obj):
    """
    Tell whether an object advertises hashing support.

    Args:
        obj: Any object.

    Returns:
        bool, False when the object's ``__hash__`` attribute is ``None``,
        True otherwise.
    """
    hash_slot = obj.__hash__
    return not (hash_slot is None)


def catch_exception(function):
    """
    Decorator that prints the traceback of any ``Exception`` to stdout and
    then re-raises it.

    Exceptions raised in child threads or subprocesses are often not shown
    on stdout, which makes errors hard to locate; wrapping the function makes
    them visible while leaving the exception itself untouched.
    """

    @wraps(function)
    def _guarded(*args, **kwargs):
        try:
            return function(*args, **kwargs)
        except Exception:
            traceback.print_exc(file=sys.stdout)
            raise

    return _guarded


def label_property4objectdetection(name, self):
    """
    Build a label property stored under ``'_' + name``.

    Unhashable values are stored as ``sorted(value)``; hashable values are
    stored unchanged.

    Args:
        name (str): Public name of the property.
        self: Optional object; when it is not None its backing attribute is
            initialised to None.

    Returns:
        property, the getter/setter pair.
    """
    storage_name = '_' + name
    if self is not None:
        setattr(self, storage_name, None)

    def _read(instance):
        """get property."""
        return getattr(instance, storage_name, None)

    def _write(instance, value):
        """set property."""
        stored = value if hashable(value) else sorted(value)
        setattr(instance, storage_name, stored)

    return property(_read, _write)


def typed_property(name, expected_type):
    """
    Build a property that only accepts values of ``expected_type``.

    The value is stored under ``'_' + name``. Reading an unset property gives
    None. Assigning None (when None is not of ``expected_type``) is ignored
    and keeps the current value; any other value of the wrong type raises
    TypeError.
    """
    storage_name = '_' + name

    def _read(instance):
        return getattr(instance, storage_name, None)

    def _write(instance, value):
        if isinstance(value, expected_type):
            setattr(instance, storage_name, value)
            return
        if value is None:
            # Silently keep whatever is stored.
            return
        raise TypeError(
            '{} must be a {}, but got {} with type {}'.format(name,
                                                              expected_type,
                                                              value,
                                                              type(value)))

    return property(_read, _write)


def validated_property(name):
    """
    Create a read-only property that must be populated before it is read.

    The value is looked up on the instance under ``'_' + name``.

    Args:
        name (str): Public name of the property.

    Returns:
        property, a getter-only property.

    Raises:
        RuntimeError: On read, if the backing attribute is missing or None.
    """
    storage_name = '_' + name

    def _fetch(self):
        value = getattr(self, storage_name, None)
        if value is not None:
            return value
        raise RuntimeError("please use 'compile' method first.")

    return property(_fetch)


def _copytree(entries, src, dst, symlinks, ignore, copy_function,
              ignore_dangling_symlinks, dirs_exist_ok=False):
    """Copy the already-scanned ``entries`` of ``src`` into ``dst``.

    Worker for `copytree`; recurses through `copytree` for sub-directories.

    Args:
        entries (list[os.DirEntry]): Result of ``os.scandir(src)``.
        src (Union[str, os.PathLike]): Source directory.
        dst (Union[str, os.PathLike]): Destination directory.
        symlinks (bool): If True, symlinks are recreated as symlinks;
            otherwise the linked content is copied.
        ignore (Callable): Optional ``ignore(dir, names)`` returning the
            names to skip, or None.
        copy_function (Callable): ``copy_function(src, dst)`` used for files.
        ignore_dangling_symlinks (bool): If True and ``symlinks`` is False,
            symlinks whose target does not exist are skipped.
        dirs_exist_ok (bool): Passed to ``os.makedirs`` as ``exist_ok``.

    Returns:
        The ``dst`` argument.

    Raises:
        Error: After processing every entry, if any copy failed. Its first
            argument is a list of ``(srcname, dstname, reason)`` tuples.
    """
    if ignore is not None:
        ignored_names = ignore(os.fspath(src), [x.name for x in entries])
    else:
        ignored_names = set()

    os.makedirs(dst, exist_ok=dirs_exist_ok)
    errors = []
    # The stock shutil copy functions are handed the DirEntry itself; any
    # custom copy_function gets the plain path instead.
    use_srcentry = copy_function is copy2 or copy_function is copy

    for srcentry in entries:
        if srcentry.name in ignored_names:
            continue
        srcname = os.path.join(src, srcentry.name)
        dstname = os.path.join(dst, srcentry.name)
        srcobj = srcentry if use_srcentry else srcname
        try:
            is_symlink = srcentry.is_symlink()
            if is_symlink and os.name == 'nt':
                # Special check for directory junctions, which appear as
                # symlinks but we want to recurse.
                lstat = srcentry.stat(follow_symlinks=False)
                # 2684354563 == 0xA0000003, the value the standard library
                # exposes as stat.IO_REPARSE_TAG_MOUNT_POINT.
                if lstat.st_reparse_tag == 2684354563:
                    is_symlink = False
            if is_symlink:
                linkto = os.readlink(srcname)
                if symlinks:
                    # We can't just leave it to `copy_function` because legacy
                    # code with a custom `copy_function` may rely on copytree
                    # doing the right thing.
                    os.symlink(linkto, dstname)
                    copystat(srcobj, dstname, follow_symlinks=not symlinks)
                else:
                    # ignore dangling symlink if the flag is on
                    if not os.path.exists(linkto) and ignore_dangling_symlinks:
                        continue
                    # otherwise let the copy occur. copy2 will raise an error
                    if srcentry.is_dir():
                        # Forward ignore_dangling_symlinks so the flag also
                        # applies inside nested directories.
                        copytree(srcobj, dstname, symlinks, ignore,
                                 copy_function, ignore_dangling_symlinks,
                                 dirs_exist_ok)
                    else:
                        copy_function(srcobj, dstname)
            elif srcentry.is_dir():
                # Forward ignore_dangling_symlinks so the flag also applies
                # inside nested directories.
                copytree(srcobj, dstname, symlinks, ignore, copy_function,
                         ignore_dangling_symlinks, dirs_exist_ok)
            else:
                # Will raise a SpecialFileError for unsupported file types
                copy_function(srcobj, dstname)
        # catch the Error from the recursive copytree so that we can
        # continue with other files
        except Error as err:
            errors.extend(err.args[0])
        except OSError as why:
            errors.append((srcname, dstname, str(why)))
    try:
        copystat(src, dst)
    except OSError as why:
        # Copying file access times may fail on Windows
        if getattr(why, 'winerror', None) is None:
            errors.append((src, dst, str(why)))
    if errors:
        raise Error(errors)
    return dst


def copytree(src, dst, symlinks=False, ignore=None, copy_function=copy2,
             ignore_dangling_symlinks=False, dirs_exist_ok=False):
    """Recursively copy a directory tree and return the destination directory.

    Args:
        src (Union[str, os.PathLike]): Directory to copy from.
        dst (Union[str, os.PathLike]): Directory to copy to; created with
            ``os.makedirs(dst, exist_ok=dirs_exist_ok)``.
        symlinks (bool): If True, symlinks in the source tree are recreated
            as symlinks; if False, the content they point to is copied.
        ignore (Callable): Optional ``ignore(dir, names)`` callable that
            returns the names in ``dir`` to skip.
        copy_function (Callable): Called as ``copy_function(src, dst)`` for
            each file. Default is ``shutil.copy2``.
        ignore_dangling_symlinks (bool): If True and ``symlinks`` is False,
            symlinks whose target does not exist are skipped at every level
            of the tree instead of being reported as errors.
        dirs_exist_ok (bool): Whether existing destination directories are
            accepted.

    Returns:
        The ``dst`` argument.

    Raises:
        Error: If one or more entries could not be copied.
    """

    # Materialise the listing so the scandir handle is closed before any
    # copying or recursion starts.
    with os.scandir(src) as itr:
        entries = list(itr)
    return _copytree(entries=entries, src=src, dst=dst, symlinks=symlinks,
                     ignore=ignore, copy_function=copy_function,
                     ignore_dangling_symlinks=ignore_dangling_symlinks,
                     dirs_exist_ok=dirs_exist_ok)


def calculate_time(func):
    """Decorator that reports the execution time.

    After each call that returns normally, the wrapped function's name and
    the elapsed time in seconds are printed to stdout. Nothing is printed
    when the wrapped function raises.

    Args:
        func (Callable): The callable to time.

    Returns:
        Callable, a wrapper that forwards all arguments and the return value.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic and high-resolution, so the measured
        # duration is unaffected by system clock adjustments (unlike
        # time.time).
        start = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start
        print(func.__name__, elapsed)
        return result

    return wrapper

+ 0
- 0
tests/ut/python/natural_robustness/classification/__init__.py View File


+ 60
- 0
tests/ut/python/natural_robustness/classification/test_defense.py View File

@@ -0,0 +1,60 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defense ut."""

import pprint

import pytest

from examples.natural_robustness.applications.classification.preparation import \
create_model, get_file, image_process_ops, postprocess_func, \
preprocess_func
from mindarmour.natural_robustness.applications.image_classification import \
Defense

# Shared fixtures for the test below.
# NOTE(review): both calls execute at import time, i.e. during pytest
# collection, before any test in this module runs.
model = create_model("mobile_net.ckpt")
train_data_dir, test_data_dir = get_file()

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_defense():
    """
    Feature: Test the defense function for an AI model when natural
        interference happens in the original data.
    Description: Test the max interference amplitude within which the results
        inferred by a specific AI model for particular data do not change.
        Then sample data according to the robustness evaluation results and
        retrain the AI model on the sampled data to improve its robustness.
    Expectation: success.
    """
    defense = Defense(
        strategies=['GaussianBlur', 'Brightness', 'Contrast', 'UniformNoise',
                    'GaussianNoise', 'SaltAndPepper', 'Rotate', 'Scale',
                    'Shear_x', 'Shear_y', 'Translate_x', 'Translate_y'],
        certificate_number=10,
        augmentation_image_dir="aug_data",
        name2index={'dogs': 0, 'wolves': 1},
        batch_queue_buffer_size=2, mutate_print=False,
        workers=2, thread_workers=2)
    defense.set_model(model, predict_batch_size=100, train_func=None,
                      operations=image_process_ops(True),
                      device_target="CPU", preprocess_func=preprocess_func,
                      postprocess_func=postprocess_func, epoch=2)
    defense.data_reader.from_directory(test_data_dir, train_data_dir, True)
    result = defense.defense(sample_rate=0.2, sample_number=4, iter_number=2)
    pprint.pprint(result)
    # No return value: pytest warns when a test function returns non-None.

+ 56
- 0
tests/ut/python/natural_robustness/classification/test_evaluation.py View File

@@ -0,0 +1,56 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation ut."""

import pytest

from examples.natural_robustness.applications.classification.preparation import \
create_model, get_file, postprocess_func, preprocess_func
from mindarmour.natural_robustness.applications.image_classification import \
Evaluation



@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_evaluation():
    """
    Feature: Test AI model evaluation.
    Description: Evaluate the robustness of an AI model by testing the max
        interference amplitude within which the results inferred by the model
        for particular data do not change.
    Expectation: success.
    """
    net = create_model("./mobile_net.ckpt")
    # Only the test split is needed here; the train directory is discarded.
    _, test_dir = get_file()

    strategy_names = [
        'GaussianBlur', 'Brightness', 'Contrast', 'UniformNoise',
        'GaussianNoise', 'SaltAndPepper', 'Rotate', 'Scale',
        'Shear_x', 'Shear_y', 'Translate_x', 'Translate_y',
    ]
    label_map = {'dogs': 0, 'wolves': 1}

    evaluator = Evaluation(strategies=strategy_names,
                           support_data_number=10,
                           name2index=label_map,
                           workers=2,
                           thread_workers=2,
                           batch_queue_buffer_size=2,
                           mutate_print=False)
    evaluator.set_model(net,
                        predict_batch_size=100,
                        preprocess_func=preprocess_func,
                        postprocess_func=postprocess_func)
    evaluator.data_reader.from_directory(test_dir, shuffle=True,
                                         target_shape=None)

    result = evaluator.evaluate()
    print(evaluator.evaluate_parameter[0])
    print(result)

+ 46
- 0
tests/ut/python/natural_robustness/test_time_series_transform.py View File

@@ -0,0 +1,46 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for natural robustness methods on time series data."""

import numpy as np
import pytest

from mindarmour.natural_robustness.transform.time_series import Miss, Noise


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_noise_example():
    """
    Feature: Noise interference on time series data.
    Description: Apply ``Noise(0.05)`` to a sampled sine wave.
    Expectation: The transformed series is a numpy.ndarray.
    """
    x = np.linspace(0, 100, 500)
    y = np.sin(x)
    trans = Noise(0.05)
    target = trans(y)
    assert isinstance(target, np.ndarray)


# The function is named test_* because pytest's default discovery only
# collects functions with that prefix; the previous name is kept as an alias
# so existing references keep working.
# NOTE(review): confirm the project does not override `python_functions`.
noise_example = test_noise_example


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_miss_example():
    """
    Feature: Miss interference on time series data.
    Description: Apply ``Miss(0.2)`` to a sampled sine wave.
    Expectation: The transformed series is a numpy.ndarray.
    """
    x = np.linspace(0, 20, 100)
    y = np.sin(x)
    trans = Miss(0.2)
    target = trans(y)
    assert isinstance(target, np.ndarray)


# The function is named test_* because pytest's default discovery only
# collects functions with that prefix; the previous name is kept as an alias
# so existing references keep working.
# NOTE(review): confirm the project does not override `python_functions`.
miss_example = test_miss_example

Loading…
Cancel
Save